samirk08 committed
Commit 8285881 · verified · 1 Parent(s): a9383ff

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +23 -0
  2. .gitignore +14 -0
  3. LICENSE +107 -0
  4. README.md +191 -8
  5. app.py +391 -0
  6. app_flux.py +305 -0
  7. app_p2p.py +567 -0
  8. densepose/__init__.py +22 -0
  9. densepose/config.py +277 -0
  10. densepose/converters/__init__.py +17 -0
  11. densepose/converters/base.py +95 -0
  12. densepose/converters/builtin.py +33 -0
  13. densepose/converters/chart_output_hflip.py +73 -0
  14. densepose/converters/chart_output_to_chart_result.py +190 -0
  15. densepose/converters/hflip.py +36 -0
  16. densepose/converters/segm_to_mask.py +152 -0
  17. densepose/converters/to_chart_result.py +72 -0
  18. densepose/converters/to_mask.py +51 -0
  19. densepose/data/__init__.py +27 -0
  20. densepose/data/build.py +738 -0
  21. densepose/data/combined_loader.py +46 -0
  22. densepose/data/dataset_mapper.py +170 -0
  23. densepose/data/image_list_dataset.py +74 -0
  24. densepose/data/inference_based_loader.py +174 -0
  25. densepose/data/meshes/__init__.py +7 -0
  26. densepose/data/meshes/builtin.py +103 -0
  27. densepose/data/meshes/catalog.py +73 -0
  28. densepose/data/samplers/__init__.py +10 -0
  29. densepose/data/samplers/densepose_base.py +205 -0
  30. densepose/data/samplers/densepose_confidence_based.py +110 -0
  31. densepose/data/samplers/densepose_cse_base.py +141 -0
  32. densepose/data/samplers/densepose_cse_confidence_based.py +121 -0
  33. densepose/data/samplers/densepose_cse_uniform.py +14 -0
  34. densepose/data/samplers/densepose_uniform.py +43 -0
  35. densepose/data/samplers/mask_from_densepose.py +30 -0
  36. densepose/data/samplers/prediction_to_gt.py +100 -0
  37. densepose/data/transform/__init__.py +5 -0
  38. densepose/data/transform/image.py +41 -0
  39. densepose/data/utils.py +40 -0
  40. densepose/data/video/__init__.py +19 -0
  41. densepose/data/video/frame_selector.py +89 -0
  42. densepose/data/video/video_keyframe_dataset.py +304 -0
  43. densepose/engine/__init__.py +5 -0
  44. densepose/engine/trainer.py +260 -0
  45. densepose/evaluation/__init__.py +5 -0
  46. densepose/evaluation/d2_evaluator_adapter.py +52 -0
  47. densepose/evaluation/densepose_coco_evaluation.py +1305 -0
  48. densepose/evaluation/evaluator.py +423 -0
  49. densepose/evaluation/mesh_alignment_evaluator.py +68 -0
  50. densepose/evaluation/tensor_storage.py +241 -0
.gitattributes CHANGED
@@ -33,3 +33,26 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ detectron2/_C.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/condition/overall/21744571_51588794_1000.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/condition/overall/23962182_54027982_1000.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/condition/person/baumu30483223c3_1719437121402_2-0._QL90_UX564_V12524t6_.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/condition/person/mison407622250d_1719258948458_2-0._QL90_UX564_V12524t6_.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/condition/person/mothr22044226e8_1718142523286_2-0._QL90_UX564_V12524t6_.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/condition/upper/21514384_52353349_1000.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/condition/upper/22790049_53294275_1000.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/condition/upper/24083449_54173465_2048.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/person/men/Simon_1.png filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/person/men/Yifeng_0.png filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/person/men/model_5.png filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/person/men/model_7.png filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/person/women/049713_0.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/person/women/1-model_3.png filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/person/women/2-model_4.png filter=lfs diff=lfs merge=lfs -text
+ resource/demo/example/person/women/model_8.png filter=lfs diff=lfs merge=lfs -text
+ resource/img/architecture.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/img/comfyui-1.png filter=lfs diff=lfs merge=lfs -text
+ resource/img/comfyui.png filter=lfs diff=lfs merge=lfs -text
+ resource/img/efficency.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/img/structure.jpg filter=lfs diff=lfs merge=lfs -text
+ resource/img/teaser.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
__pycache__
model/__pycache__
model/DensePose/__pycache__
model/SCHP/__pycache__
model/SCHP/*/__pycache__
resource/demo/output
resource/demo/example/.DS_Store
Models
Datasets
densepose_
.vscode
playground.py
output
model/cloth_masker_segformer.py
LICENSE ADDED
@@ -0,0 +1,107 @@
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International

Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.

Using Creative Commons Public Licenses

Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.

Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. More considerations for licensors: wiki.creativecommons.org/Considerations_for_licensors

Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public: wiki.creativecommons.org/Considerations_for_licensees

Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License

By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.

Section 1 – Definitions.

a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
c. BY-NC-SA Compatible License means a license listed at creativecommons.org/compatiblelicenses, approved by Creative Commons as essentially the equivalent of this Public License.
d. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
e. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
f. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
g. License Elements means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
h. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
i. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
j. Licensor means the individual(s) or entity(ies) granting rights under this Public License.
k. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
l. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
m. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
n. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
Section 2 – Scope.

a. License grant.
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
3. Term. The term of this Public License is specified in Section 6(a).
4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
5. Downstream recipients.
A. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
B. Additional offer from the Licensor – Adapted Material. Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter's License You apply.
C. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this Public License.
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
Section 3 – License Conditions.

Your exercise of the Licensed Rights is expressly made subject to the following conditions.

a. Attribution.
1. If You Share the Licensed Material (including in modified form), You must:
A. retain the following if it is supplied by the Licensor with the Licensed Material:
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of warranties;
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;

B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
b. ShareAlike. In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
1. The Adapter's License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
Section 4 – Sui Generis Database Rights.

Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:

a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
Section 5 – Disclaimer of Warranties and Limitation of Liability.

a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.
b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
Section 6 – Term and Termination.

a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.

c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
Section 7 – Other Terms and Conditions.

a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
Section 8 – Interpretation.

a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.

Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -1,12 +1,195 @@
  ---
- title: Vtontry
- emoji: 🐨
- colorFrom: purple
- colorTo: blue
- sdk: gradio
- sdk_version: 5.35.0
  app_file: app.py
- pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
---
title: vtontry
app_file: app.py
sdk: gradio
sdk_version: 4.41.0
---
# [ICLR 25] 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models

<div style="display: flex; justify-content: center; align-items: center;">
<a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
</a>
<a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
</a>
<a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
</a>
<a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
</a>
<a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
</a>
<a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
</a>
<a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
</a>
</div>

**CatVTON** is a simple and efficient virtual try-on diffusion model with ***1) a Lightweight Network (899.06M parameters in total)***, ***2) Parameter-Efficient Training (49.57M trainable parameters)*** and ***3) Simplified Inference (< 8G VRAM for 1024×768 resolution)***.
<div align="center">
  <img src="resource/img/teaser.jpg" width="100%" height="100%"/>
</div>

## Updates
- **`2025/02/24`**: 🎉 We are excited to announce [**CatV2TON**](https://github.com/Zheng-Chong/CatV2TON), our new DiT-based model that supports both **image and video try-on**! Check it out!
- **`2025/02/20`**: Our [**Paper on ArXiv**](http://arxiv.org/abs/2407.15886) has been updated to v2, which includes more details.
- **`2025/01/24`**: 🥳 CatVTON has been accepted to **ICLR 2025**!
- **`2024/12/20`**: 😄 Code for the Gradio app of **CatVTON-FLUX** has been released! It is not a stable version, but it is a good start!
- **`2024/12/19`**: [**CatVTON-FLUX**](https://huggingface.co/spaces/zhengchong/CatVTON) has been released! It is an extremely lightweight LoRA (only a 37.4M checkpoint) for [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev); the LoRA weights are available in the **[HuggingFace repo](https://huggingface.co/zhengchong/CatVTON/tree/main/flux-lora)**, and the code will be released soon!
- **`2024/11/26`**: Our **unified vision-based model for image and video try-on** will be released soon, bringing a brand-new virtual try-on experience! While our demo page will be temporarily taken offline, [**the demo on HuggingFace Space**](https://huggingface.co/spaces/zhengchong/CatVTON) will remain available for use!
- **`2024/10/17`**: The [**Mask-free version**](https://huggingface.co/zhengchong/CatVTON-MaskFree) 🤗 of CatVTON is released!
- **`2024/10/13`**: We have built a repo, [**Awesome-Try-On-Models**](https://github.com/Zheng-Chong/Awesome-Try-On-Models), that focuses on image, video, and 3D-based try-on models published after 2023, aiming to provide insights into the latest technological trends. If you're interested, feel free to contribute or give it a 🌟 star!
- **`2024/08/13`**: We localized DensePose & SCHP to avoid certain environment issues.
- **`2024/08/10`**: Our 🤗 [**HuggingFace Space**](https://huggingface.co/spaces/zhengchong/CatVTON) is available now! Thanks for the grant from [**ZeroGPU**](https://huggingface.co/zero-gpu-explorers)!
- **`2024/08/09`**: [**Evaluation code**](https://github.com/Zheng-Chong/CatVTON?tab=readme-ov-file#3-calculate-metrics) is provided to calculate metrics 📚.
- **`2024/07/27`**: We provide code and a workflow for deploying CatVTON on [**ComfyUI**](https://github.com/Zheng-Chong/CatVTON?tab=readme-ov-file#comfyui-workflow) 💥.
- **`2024/07/24`**: Our [**Paper on ArXiv**](http://arxiv.org/abs/2407.15886) is available 🥳!
- **`2024/07/22`**: Our [**App Code**](https://github.com/Zheng-Chong/CatVTON/blob/main/app.py) is released; deploy and enjoy CatVTON on your machine 🎉!
- **`2024/07/21`**: Our [**Inference Code**](https://github.com/Zheng-Chong/CatVTON/blob/main/inference.py) and [**Weights** 🤗](https://huggingface.co/zhengchong/CatVTON) are released.
- **`2024/07/11`**: Our [**Online Demo**](https://huggingface.co/spaces/zhengchong/CatVTON) is released 😁.

## Installation

Create a conda environment & install the requirements:
```shell
conda create -n catvton python==3.9.0
conda activate catvton
cd CatVTON-main  # or your path to the CatVTON project dir
pip install -r requirements.txt
```
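
After installation, a quick sanity check (a minimal snippet; it assumes `requirements.txt` pulled in PyTorch) confirms that the environment can see a GPU:

```python
# Minimal environment check: PyTorch imports and CUDA is visible.
import torch

print(f"PyTorch {torch.__version__}; CUDA available: {torch.cuda.is_available()}")
```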

## Deployment
### ComfyUI Workflow
We have modified the main code to enable easy deployment of CatVTON on [ComfyUI](https://github.com/comfyanonymous/ComfyUI). Because the code structures are incompatible, we have released this part in the [Releases](https://github.com/Zheng-Chong/CatVTON/releases/tag/ComfyUI), which includes the code to place under `custom_nodes` of ComfyUI and our workflow JSON files.

To deploy CatVTON to your ComfyUI, follow these steps (a download sketch follows this list):
1. Install all the requirements for both CatVTON and ComfyUI; refer to the [Installation Guide for CatVTON](https://github.com/Zheng-Chong/CatVTON/blob/main/INSTALL.md) and the [Installation Guide for ComfyUI](https://github.com/comfyanonymous/ComfyUI?tab=readme-ov-file#installing).
2. Download [`ComfyUI-CatVTON.zip`](https://github.com/Zheng-Chong/CatVTON/releases/download/ComfyUI/ComfyUI-CatVTON.zip) and unzip it in the `custom_nodes` folder under your ComfyUI project (cloned from [ComfyUI](https://github.com/comfyanonymous/ComfyUI)).
3. Run ComfyUI.
4. Download [`catvton_workflow.json`](https://github.com/Zheng-Chong/CatVTON/releases/download/ComfyUI/catvton_workflow.json), drag it into your ComfyUI webpage, and enjoy 😆!

> For problems under Windows OS, please refer to [issue#8](https://github.com/Zheng-Chong/CatVTON/issues/8).
>
When you run the CatVTON workflow for the first time, the weight files will be downloaded automatically, which usually takes tens of minutes.
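
Steps 2 and 4 can also be scripted. The following is a rough sketch: the release URLs come from the links above, while the location of your `ComfyUI` checkout (here, the current directory) is an assumption:

```python
# Rough sketch of steps 2 and 4: fetch the CatVTON node pack and workflow file.
# Assumes the current directory contains your ComfyUI checkout.
import os
import urllib.request
import zipfile

RELEASE = "https://github.com/Zheng-Chong/CatVTON/releases/download/ComfyUI"

# Step 2: unzip the node pack into ComfyUI/custom_nodes.
zip_path, _ = urllib.request.urlretrieve(f"{RELEASE}/ComfyUI-CatVTON.zip")
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall(os.path.join("ComfyUI", "custom_nodes"))

# Step 4: fetch the workflow JSON, then drag it into the ComfyUI webpage.
urllib.request.urlretrieve(f"{RELEASE}/catvton_workflow.json", "catvton_workflow.json")
```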

<div align="center">
  <img src="resource/img/comfyui-1.png" width="100%" height="100%"/>
</div>

<!-- <div align="center">
  <img src="resource/img/comfyui.png" width="100%" height="100%"/>
</div> -->

### Gradio App

To deploy the Gradio app for CatVTON on your machine, run the following command; checkpoints will be downloaded automatically from HuggingFace.

```shell
CUDA_VISIBLE_DEVICES=0 python app.py \
  --output_dir="resource/demo/output" \
  --mixed_precision="bf16" \
  --allow_tf32
```
When using `bf16` precision, generating results at a resolution of `1024x768` requires only about `8G` of VRAM.
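
If you prefer to call the model from Python rather than through the app, the core of `app.py` (shown in full later in this diff) reduces to roughly the sketch below. The checkpoint IDs and call signature are taken from `app.py`; passing `torch.bfloat16` directly in place of its `utils.init_weight_dtype("bf16")` helper, and the placeholder image files, are assumptions:

```python
# Condensed from app.py in this repo; run from the repo root so model/ imports resolve.
import torch
from PIL import Image
from huggingface_hub import snapshot_download
from model.pipeline import CatVTONPipeline

repo_path = snapshot_download(repo_id="zhengchong/CatVTON")
pipeline = CatVTONPipeline(
    base_ckpt="booksforcharlie/stable-diffusion-inpainting",
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=torch.bfloat16,
    use_tf32=True,
    device="cuda",
)

# Placeholder inputs; app.py sizes them with resize_and_crop / resize_and_padding.
person_image = Image.open("person.jpg").convert("RGB").resize((768, 1024))
cloth_image = Image.open("cloth.jpg").convert("RGB").resize((768, 1024))
mask = Image.open("mask.png").convert("L").resize((768, 1024))

result = pipeline(
    image=person_image,
    condition_image=cloth_image,
    mask=mask,
    num_inference_steps=50,
    guidance_scale=2.5,
)[0]
result.save("result.png")
```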

## Inference
### 1. Data Preparation
Before inference, you need to download the [VITON-HD](https://github.com/shadow2496/VITON-HD) or [DressCode](https://github.com/aimagelab/dress-code) dataset.
Once the datasets are downloaded, the folder structures should look like these:
```
├── VITON-HD
│   ├── test_pairs_unpaired.txt
│   ├── test
│   │   ├── image
│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
│   │   ├── cloth
│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
│   │   ├── agnostic-mask
│   │   │   ├── [000006_00_mask.png | 000008_00.png | ...]
...
```

```
├── DressCode
│   ├── test_pairs_paired.txt
│   ├── test_pairs_unpaired.txt
│   ├── [dresses | lower_body | upper_body]
│   │   ├── test_pairs_paired.txt
│   │   ├── test_pairs_unpaired.txt
│   │   ├── images
│   │   │   ├── [013563_0.jpg | 013563_1.jpg | 013564_0.jpg | 013564_1.jpg | ...]
│   │   ├── agnostic_masks
│   │   │   ├── [013563_0.png | 013564_0.png | ...]
...
```
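
To make the layout concrete, here is a small helper (hypothetical, not the repo's dataloader) that reads one test triple from the VITON-HD tree above:

```python
# Hypothetical helper illustrating the VITON-HD test layout shown above.
import os
from PIL import Image

def load_vitonhd_sample(root: str, name: str):
    """Return (person, cloth, agnostic_mask) for one test image, e.g. name='000006_00.jpg'."""
    test = os.path.join(root, "test")
    person = Image.open(os.path.join(test, "image", name)).convert("RGB")
    cloth = Image.open(os.path.join(test, "cloth", name)).convert("RGB")
    mask_name = name.replace(".jpg", "_mask.png")
    mask = Image.open(os.path.join(test, "agnostic-mask", mask_name)).convert("L")
    return person, cloth, mask
```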
For the DressCode dataset, we provide a script to preprocess agnostic masks; run the following command:
```shell
CUDA_VISIBLE_DEVICES=0 python preprocess_agnostic_mask.py \
  --data_root_path <your_path_to_DressCode>
```

### 2. Inference on VITON-HD/DressCode
To run inference on the DressCode or VITON-HD dataset, run the following command; checkpoints will be downloaded automatically from HuggingFace.

```shell
CUDA_VISIBLE_DEVICES=0 python inference.py \
  --dataset [dresscode | vitonhd] \
  --data_root_path <path> \
  --output_dir <path> \
  --dataloader_num_workers 8 \
  --batch_size 8 \
  --seed 555 \
  --mixed_precision [no | fp16 | bf16] \
  --allow_tf32 \
  --repaint \
  --eval_pair
```
### 3. Calculate Metrics

After obtaining the inference results, calculate the metrics using the following command:

```shell
CUDA_VISIBLE_DEVICES=0 python eval.py \
  --gt_folder <your_path_to_gt_image_folder> \
  --pred_folder <your_path_to_predicted_image_folder> \
  --paired \
  --batch_size=16 \
  --num_workers=16
```

- `--gt_folder` and `--pred_folder` should be folders that contain **only images** (see the check below).
- To evaluate the results in a paired setting, use `--paired`; for an unpaired setting, simply omit it.
- `--batch_size` and `--num_workers` should be adjusted based on your machine.
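
Since `eval.py` expects image-only folders, a quick pre-check such as this (a hypothetical snippet, not part of the repo) can catch stray files before a long evaluation run:

```python
# Hypothetical pre-check: ensure a folder passed to eval.py contains only images.
import os

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}

def assert_images_only(folder: str) -> None:
    stray = [name for name in os.listdir(folder)
             if os.path.splitext(name)[1].lower() not in IMAGE_EXTS]
    if stray:
        raise ValueError(f"Non-image entries in {folder}: {stray[:5]}")

assert_images_only("<your_path_to_gt_image_folder>")
assert_images_only("<your_path_to_predicted_image_folder>")
```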

## Acknowledgement
Our code is modified based on [Diffusers](https://github.com/huggingface/diffusers). We adopt [Stable Diffusion v1.5 inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) as the base model. We use [SCHP](https://github.com/GoGoDuck912/Self-Correction-Human-Parsing/tree/master) and [DensePose](https://github.com/facebookresearch/DensePose) to automatically generate masks in our [Gradio](https://github.com/gradio-app/gradio) App and [ComfyUI](https://github.com/comfyanonymous/ComfyUI) workflow. Thanks to all the contributors!

## License
All the materials, including code, checkpoints, and demo, are made available under the [Creative Commons BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license. You are free to copy, redistribute, remix, transform, and build upon the project for non-commercial purposes, as long as you give appropriate credit and distribute your contributions under the same license.

## Citation

```bibtex
@misc{chong2024catvtonconcatenationneedvirtual,
  title={CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models},
  author={Zheng Chong and Xiao Dong and Haoxiang Li and Shiyue Zhang and Wenqing Zhang and Xujie Zhang and Hanqing Zhao and Xiaodan Liang},
  year={2024},
  eprint={2407.15886},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2407.15886},
}
```
app.py ADDED
@@ -0,0 +1,391 @@
import argparse
import os
from datetime import datetime

import gradio as gr
import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from huggingface_hub import snapshot_download
from PIL import Image

# Set memory growth for MPS to prevent out-of-memory errors
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'

from model.cloth_masker import AutoMasker, vis_mask
from model.pipeline import CatVTONPipeline
from utils import init_weight_dtype, resize_and_crop, resize_and_padding

def parse_args():
    parser = argparse.ArgumentParser(description="Gradio demo for CatVTON virtual try-on.")
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="booksforcharlie/stable-diffusion-inpainting",  # Changed to a copied repo since runwayml deleted the original repo
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help=(
            "The path to the checkpoint of the trained try-on model."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written.",
    )

    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--repaint",
        action="store_true",
        help="Whether to repaint the result image with the original background."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    # Defined so the LOCAL_RANK override below cannot fail on a missing attribute.
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local rank.",
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


args = parse_args()
repo_path = snapshot_download(repo_id=args.resume_path)

# Auto-detect device (CUDA if available, otherwise CPU)
# Note: MPS is disabled due to memory and compatibility issues
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
    print("Note: Running on CPU. This will be slower but more stable.")
print(f"Using device: {device}")

# Pipeline
pipeline = CatVTONPipeline(
    base_ckpt=args.base_model_path,
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32 and torch.cuda.is_available(),  # Only use TF32 if CUDA is available
    device=device
)
# AutoMasker
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
automasker = AutoMasker(
    densepose_ckpt=os.path.join(repo_path, "DensePose"),
    schp_ckpt=os.path.join(repo_path, "SCHP"),
    device=device,
)

def submit_function(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type
):
    # A hand-drawn mask layer takes priority; a single-value (empty) layer means no manual mask.
    person_image, mask = person_image["background"], person_image["layers"][0]
    mask = Image.open(mask).convert("L")
    if len(np.unique(np.array(mask))) == 1:
        mask = None
    else:
        mask = np.array(mask)
        mask[mask > 0] = 255
        mask = Image.fromarray(mask)

    tmp_folder = args.output_dir
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
    if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
        os.makedirs(os.path.join(tmp_folder, date_str[:8]))

    generator = None
    if seed != -1:
        generator = torch.Generator(device=device).manual_seed(seed)

    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")

    # Use default resolution
    target_width = args.width
    target_height = args.height

    person_image = resize_and_crop(person_image, (target_width, target_height))
    cloth_image = resize_and_padding(cloth_image, (target_width, target_height))

    # Process mask
    if mask is not None:
        mask = resize_and_crop(mask, (target_width, target_height))
    else:
        mask = automasker(
            person_image,
            cloth_type
        )['mask']
        mask = mask_processor.blur(mask, blur_factor=9)

    # Inference
    # try:
    result_image = pipeline(
        image=person_image,
        condition_image=cloth_image,
        mask=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    )[0]
    # except Exception as e:
    #     raise gr.Error(
    #         "An error occurred. Please try again later: {}".format(e)
    #     )

    # Post-process
    masked_person = vis_mask(person_image, mask)
    save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
    save_result_image.save(result_save_path)
    if show_type == "result only":
        return result_image
    else:
        width, height = person_image.size
        if show_type == "input & result":
            condition_width = width // 2
            conditions = image_grid([person_image, cloth_image], 2, 1)
        else:
            condition_width = width // 3
            conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
        conditions = conditions.resize((condition_width, height), Image.NEAREST)
        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
        new_result_image.paste(conditions, (0, 0))
        new_result_image.paste(result_image, (condition_width + 5, 0))
        return new_result_image


def person_example_fn(image_path):
    return image_path

HEADER = """
<h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
<div style="display: flex; justify-content: center; align-items: center;">
<a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
</a>
<a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
</a>
<a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
</a>
<a href="http://120.76.142.206:8888" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
</a>
<a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
</a>
<a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
<img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
</a>
<a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
<img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
</a>
</div>
<br>
· This demo and our weights are only for <span>Non-commercial Use</span>. <br>
· You can try CatVTON in our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a> or our <a href="http://120.76.142.206:8888">online demo</a> (run on 3090). <br>
· Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing the A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
· The SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
"""

def app_gradio():
    with gr.Blocks(title="CatVTON") as demo:
        gr.Markdown(HEADER)
        with gr.Row():
            with gr.Column(scale=1, min_width=350):
                with gr.Row():
                    image_path = gr.Image(
                        type="filepath",
                        interactive=True,
                        visible=False,
                    )
                    person_image = gr.ImageEditor(
                        interactive=True, label="Person Image", type="filepath"
                    )

                with gr.Row():
                    with gr.Column(scale=1, min_width=230):
                        cloth_image = gr.Image(
                            interactive=True, label="Condition Image", type="filepath"
                        )
                    with gr.Column(scale=1, min_width=120):
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">Two ways to provide a Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate it automatically </span>'
                        )
                        cloth_type = gr.Radio(
                            label="Try-On Cloth Type",
                            choices=["upper", "lower", "overall"],
                            value="upper",
                        )

                submit = gr.Button("Submit")
                gr.Markdown(
                    '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                )

                gr.Markdown(
                    '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
                )
                with gr.Accordion("Advanced Options", open=False):
                    num_inference_steps = gr.Slider(
                        label="Inference Step", minimum=10, maximum=100, step=5, value=50
                    )
                    # Guidance scale
                    guidance_scale = gr.Slider(
                        label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                    )
                    # Random seed
                    seed = gr.Slider(
                        label="Seed", minimum=-1, maximum=10000, step=1, value=42
                    )
                    show_type = gr.Radio(
                        label="Show Type",
                        choices=["result only", "input & result", "input & mask & result"],
                        value="input & mask & result",
                    )

            with gr.Column(scale=2, min_width=500):
                result_image = gr.Image(interactive=False, label="Result")
                with gr.Row():
                    # Photo examples
                    root_path = "resource/demo/example"
                    with gr.Column():
                        men_exm = gr.Examples(
                            examples=[
                                os.path.join(root_path, "person", "men", _)
                                for _ in os.listdir(os.path.join(root_path, "person", "men"))
                            ],
                            examples_per_page=4,
                            inputs=image_path,
                            label="Person Examples ①",
                        )
                        women_exm = gr.Examples(
                            examples=[
                                os.path.join(root_path, "person", "women", _)
                                for _ in os.listdir(os.path.join(root_path, "person", "women"))
                            ],
                            examples_per_page=4,
                            inputs=image_path,
                            label="Person Examples ②",
                        )
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
                        )
                    with gr.Column():
                        condition_upper_exm = gr.Examples(
                            examples=[
                                os.path.join(root_path, "condition", "upper", _)
                                for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                            ],
                            examples_per_page=4,
                            inputs=cloth_image,
                            label="Condition Upper Examples",
                        )
                        condition_overall_exm = gr.Examples(
                            examples=[
                                os.path.join(root_path, "condition", "overall", _)
                                for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                            ],
                            examples_per_page=4,
                            inputs=cloth_image,
                            label="Condition Overall Examples",
                        )
                        condition_person_exm = gr.Examples(
                            examples=[
                                os.path.join(root_path, "condition", "person", _)
                                for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                            ],
                            examples_per_page=4,
                            inputs=cloth_image,
                            label="Condition Reference Person Examples",
                        )
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                        )

        image_path.change(
            person_example_fn, inputs=image_path, outputs=person_image
        )

        submit.click(
            submit_function,
            [
                person_image,
                cloth_image,
                cloth_type,
                num_inference_steps,
                guidance_scale,
                seed,
                show_type,
            ],
            result_image,
        )
    demo.queue().launch(share=True, show_error=True)


if __name__ == "__main__":
    app_gradio()
app_flux.py ADDED
@@ -0,0 +1,305 @@
import os
import argparse
import gradio as gr
from datetime import datetime

import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from huggingface_hub import snapshot_download
from PIL import Image

from model.cloth_masker import AutoMasker, vis_mask
from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
from utils import resize_and_crop, resize_and_padding

def parse_args():
    parser = argparse.ArgumentParser(description="FLUX Try-On Demo")
    parser.add_argument(
        "--base_model_path",
        type=str,
        # default="black-forest-labs/FLUX.1-Fill-dev",
        default="Models/FLUX.1-Fill-dev",
        help="The path to the base model to use for evaluation."
    )
    parser.add_argument(
        "--resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help="The path to the checkpoint of the trained try-on model."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written."
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help="Whether to use mixed precision."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help="Whether or not to allow TF32 on Ampere GPUs."
    )
    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help="The width of the input image."
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help="The height of the input image."
    )
    return parser.parse_args()

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def submit_function_flux(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type
):

    # Process image editor input
    person_image, mask = person_image["background"], person_image["layers"][0]
    mask = Image.open(mask).convert("L")
    if len(np.unique(np.array(mask))) == 1:
        mask = None
    else:
        mask = np.array(mask)
        mask[mask > 0] = 255
        mask = Image.fromarray(mask)

    # Set random seed
    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    # Process input images
    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")

    # Adjust image sizes
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Process mask
    if mask is not None:
        mask = resize_and_crop(mask, (args.width, args.height))
    else:
        mask = automasker(
            person_image,
            cloth_type
        )['mask']
        mask = mask_processor.blur(mask, blur_factor=9)

    # Inference
    result_image = pipeline_flux(
        image=person_image,
        condition_image=cloth_image,
        mask_image=mask,
        height=args.height,
        width=args.width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    ).images[0]

    # Post-processing
    masked_person = vis_mask(person_image, mask)

    # Return result based on show type
    if show_type == "result only":
        return result_image
    else:
        width, height = person_image.size
        if show_type == "input & result":
            condition_width = width // 2
            conditions = image_grid([person_image, cloth_image], 2, 1)
        else:
            condition_width = width // 3
            conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)

        conditions = conditions.resize((condition_width, height), Image.NEAREST)
        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
        new_result_image.paste(conditions, (0, 0))
        new_result_image.paste(result_image, (condition_width + 5, 0))
        return new_result_image

def person_example_fn(image_path):
    return image_path


def app_gradio():
    with gr.Blocks(title="CatVTON with FLUX.1-Fill-dev") as demo:
        gr.Markdown("# CatVTON with FLUX.1-Fill-dev")
        with gr.Row():
            with gr.Column(scale=1, min_width=350):
                with gr.Row():
                    image_path_flux = gr.Image(
                        type="filepath",
                        interactive=True,
                        visible=False,
                    )
                    person_image_flux = gr.ImageEditor(
                        interactive=True, label="Person Image", type="filepath"
                    )

                with gr.Row():
                    with gr.Column(scale=1, min_width=230):
                        cloth_image_flux = gr.Image(
                            interactive=True, label="Condition Image", type="filepath"
                        )
                    with gr.Column(scale=1, min_width=120):
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">Two ways to provide a Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate it automatically </span>'
                        )
                        cloth_type = gr.Radio(
                            label="Try-On Cloth Type",
                            choices=["upper", "lower", "overall"],
                            value="upper",
                        )

                submit_flux = gr.Button("Submit")
                gr.Markdown(
                    '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                )

                with gr.Accordion("Advanced Options", open=False):
                    num_inference_steps_flux = gr.Slider(
                        label="Inference Step", minimum=10, maximum=100, step=5, value=50
                    )
                    # Guidance scale
                    guidance_scale_flux = gr.Slider(
                        label="CFG Strength", minimum=0.0, maximum=50, step=0.5, value=30
                    )
                    # Random seed
                    seed_flux = gr.Slider(
                        label="Seed", minimum=-1, maximum=10000, step=1, value=42
                    )
                    show_type = gr.Radio(
                        label="Show Type",
                        choices=["result only", "input & result", "input & mask & result"],
                        value="input & mask & result",
                    )

            with gr.Column(scale=2, min_width=500):
                result_image_flux = gr.Image(interactive=False, label="Result")
                with gr.Row():
                    # Photo examples
                    root_path = "resource/demo/example"
                    with gr.Column():
                        gr.Examples(
                            examples=[
                                os.path.join(root_path, "person", "men", _)
                                for _ in os.listdir(os.path.join(root_path, "person", "men"))
                            ],
                            examples_per_page=4,
                            inputs=image_path_flux,
                            label="Person Examples ①",
                        )
                        gr.Examples(
                            examples=[
                                os.path.join(root_path, "person", "women", _)
                                for _ in os.listdir(os.path.join(root_path, "person", "women"))
                            ],
                            examples_per_page=4,
                            inputs=image_path_flux,
                            label="Person Examples ②",
                        )
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
232
+ )
233
+ with gr.Column():
234
+ gr.Examples(
235
+ examples=[
236
+ os.path.join(root_path, "condition", "upper", _)
237
+ for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
238
+ ],
239
+ examples_per_page=4,
240
+ inputs=cloth_image_flux,
241
+ label="Condition Upper Examples",
242
+ )
243
+ gr.Examples(
244
+ examples=[
245
+ os.path.join(root_path, "condition", "overall", _)
246
+ for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
247
+ ],
248
+ examples_per_page=4,
249
+ inputs=cloth_image_flux,
250
+ label="Condition Overall Examples",
251
+ )
252
+ condition_person_exm = gr.Examples(
253
+ examples=[
254
+ os.path.join(root_path, "condition", "person", _)
255
+ for _ in os.listdir(os.path.join(root_path, "condition", "person"))
256
+ ],
257
+ examples_per_page=4,
258
+ inputs=cloth_image_flux,
259
+ label="Condition Reference Person Examples",
260
+ )
261
+ gr.Markdown(
262
+ '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
263
+ )
264
+
265
+
266
+ image_path_flux.change(
267
+ person_example_fn, inputs=image_path_flux, outputs=person_image_flux
268
+ )
269
+
270
+ submit_flux.click(
271
+ submit_function_flux,
272
+ [person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
273
+ result_image_flux,
274
+ )
275
+
276
+
277
+ demo.queue().launch(share=True, show_error=True)
278
+
279
+ # 解析参数
280
+ args = parse_args()
281
+
282
+ # 加载模型
283
+ repo_path = snapshot_download(repo_id=args.resume_path)
284
+ pipeline_flux = FluxTryOnPipeline.from_pretrained(args.base_model_path)
285
+ pipeline_flux.load_lora_weights(
286
+ os.path.join(repo_path, "flux-lora"),
287
+ weight_name='pytorch_lora_weights.safetensors'
288
+ )
289
+ pipeline_flux.to("cuda", torch.bfloat16)
290
+
291
+ # 初始化 AutoMasker
292
+ mask_processor = VaeImageProcessor(
293
+ vae_scale_factor=8,
294
+ do_normalize=False,
295
+ do_binarize=True,
296
+ do_convert_grayscale=True
297
+ )
298
+ automasker = AutoMasker(
299
+ densepose_ckpt=os.path.join(repo_path, "DensePose"),
300
+ schp_ckpt=os.path.join(repo_path, "SCHP"),
301
+ device='cuda'
302
+ )
303
+
304
+ if __name__ == "__main__":
305
+ app_gradio()
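
For reference, a minimal standalone sketch of the mask-handling convention used by `submit_function_flux` above: binarize a drawn ImageEditor layer, or signal a fallback to the AutoMasker when the layer is empty. The file path below is hypothetical; this assumes only Pillow and NumPy are installed.

import numpy as np
from PIL import Image

def drawn_mask_or_none(layer_path: str):
    """Binarize a drawn ImageEditor layer; return None if nothing was drawn."""
    mask = Image.open(layer_path).convert("L")
    arr = np.array(mask)
    if len(np.unique(arr)) == 1:  # a single gray level means an empty layer
        return None
    arr[arr > 0] = 255            # any painted pixel becomes foreground
    return Image.fromarray(arr)

# mask = drawn_mask_or_none("layer0.png")  # hypothetical path
# if mask is None, the demo falls back to automasker(person_image, cloth_type)['mask']
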
app_p2p.py ADDED
@@ -0,0 +1,567 @@
+ import argparse
+ import os
+ from datetime import datetime
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from diffusers.image_processor import VaeImageProcessor
+ from huggingface_hub import snapshot_download
+ from PIL import Image
+
+ from model.cloth_masker import AutoMasker, vis_mask
+ from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
+ from utils import init_weight_dtype, resize_and_crop, resize_and_padding
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+     parser.add_argument(
+         "--p2p_base_model_path",
+         type=str,
+         default="timbrooks/instruct-pix2pix",
+         help=(
+             "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
+         ),
+     )
+     parser.add_argument(
+         "--ip_base_model_path",
+         type=str,
+         default="booksforcharlie/stable-diffusion-inpainting",
+         help=(
+             "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
+         ),
+     )
+     parser.add_argument(
+         "--p2p_resume_path",
+         type=str,
+         default="zhengchong/CatVTON-MaskFree",
+         help=(
+             "The path to the checkpoint of the trained try-on model."
+         ),
+     )
+     parser.add_argument(
+         "--ip_resume_path",
+         type=str,
+         default="zhengchong/CatVTON",
+         help=(
+             "The path to the checkpoint of the trained try-on model."
+         ),
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="resource/demo/output",
+         help="The output directory where the model predictions will be written.",
+     )
+
+     parser.add_argument(
+         "--width",
+         type=int,
+         default=768,
+         help=(
+             "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+             " resolution"
+         ),
+     )
+     parser.add_argument(
+         "--height",
+         type=int,
+         default=1024,
+         help=(
+             "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+             " resolution"
+         ),
+     )
+     parser.add_argument(
+         "--repaint",
+         action="store_true",
+         help="Whether to repaint the result image with the original background."
+     )
+     parser.add_argument(
+         "--allow_tf32",
+         action="store_true",
+         default=True,
+         help=(
+             "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+         ),
+     )
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default="bf16",
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
+             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+         ),
+     )
+     # needed by the LOCAL_RANK override below; the original script read
+     # args.local_rank without ever defining the argument
+     parser.add_argument(
+         "--local_rank",
+         type=int,
+         default=-1,
+         help="For distributed runs: local rank (also read from the LOCAL_RANK env variable)."
+     )
+
+     args = parser.parse_args()
+     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     if env_local_rank != -1 and env_local_rank != args.local_rank:
+         args.local_rank = env_local_rank
+
+     return args
+
+ def image_grid(imgs, rows, cols):
+     assert len(imgs) == rows * cols
+
+     w, h = imgs[0].size
+     grid = Image.new("RGB", size=(cols * w, rows * h))
+
+     for i, img in enumerate(imgs):
+         grid.paste(img, box=(i % cols * w, i // cols * h))
+     return grid
+
+
+ args = parse_args()
+ # Mask-free (pix2pix) pipeline; its attention checkpoint comes from --p2p_resume_path
+ repo_path = snapshot_download(repo_id=args.p2p_resume_path)
+ pipeline_p2p = CatVTONPix2PixPipeline(
+     base_ckpt=args.p2p_base_model_path,
+     attn_ckpt=repo_path,
+     attn_ckpt_version="mix-48k-1024",
+     weight_dtype=init_weight_dtype(args.mixed_precision),
+     use_tf32=args.allow_tf32,
+     device='cuda'
+ )
+
+ # Mask-based (inpainting) pipeline
+ repo_path = snapshot_download(repo_id=args.ip_resume_path)
+ pipeline = CatVTONPipeline(
+     base_ckpt=args.ip_base_model_path,
+     attn_ckpt=repo_path,
+     attn_ckpt_version="mix",
+     weight_dtype=init_weight_dtype(args.mixed_precision),
+     use_tf32=args.allow_tf32,
+     device='cuda'
+ )
+
+ # AutoMasker
+ mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
+ automasker = AutoMasker(
+     densepose_ckpt=os.path.join(repo_path, "DensePose"),
+     schp_ckpt=os.path.join(repo_path, "SCHP"),
+     device='cuda',
+ )
+
+
+ def submit_function_p2p(
+     person_image,
+     cloth_image,
+     num_inference_steps,
+     guidance_scale,
+     seed):
+     person_image = person_image["background"]
+
+     tmp_folder = args.output_dir
+     date_str = datetime.now().strftime("%Y%m%d%H%M%S")
+     result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
+     if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
+         os.makedirs(os.path.join(tmp_folder, date_str[:8]))
+
+     generator = None
+     if seed != -1:
+         generator = torch.Generator(device='cuda').manual_seed(seed)
+
+     person_image = Image.open(person_image).convert("RGB")
+     cloth_image = Image.open(cloth_image).convert("RGB")
+     person_image = resize_and_crop(person_image, (args.width, args.height))
+     cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
+
+     # Inference
+     try:
+         result_image = pipeline_p2p(
+             image=person_image,
+             condition_image=cloth_image,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             generator=generator
+         )[0]
+     except Exception as e:
+         raise gr.Error(
+             "An error occurred. Please try again later: {}".format(e)
+         )
+
+     # Post-process
+     save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
+     save_result_image.save(result_save_path)
+     return result_image
+
+ def submit_function(
+     person_image,
+     cloth_image,
+     cloth_type,
+     num_inference_steps,
+     guidance_scale,
+     seed,
+     show_type
+ ):
+     person_image, mask = person_image["background"], person_image["layers"][0]
+     mask = Image.open(mask).convert("L")
+     if len(np.unique(np.array(mask))) == 1:
+         mask = None
+     else:
+         mask = np.array(mask)
+         mask[mask > 0] = 255
+         mask = Image.fromarray(mask)
+
+     tmp_folder = args.output_dir
+     date_str = datetime.now().strftime("%Y%m%d%H%M%S")
+     result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
+     if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
+         os.makedirs(os.path.join(tmp_folder, date_str[:8]))
+
+     generator = None
+     if seed != -1:
+         generator = torch.Generator(device='cuda').manual_seed(seed)
+
+     person_image = Image.open(person_image).convert("RGB")
+     cloth_image = Image.open(cloth_image).convert("RGB")
+     person_image = resize_and_crop(person_image, (args.width, args.height))
+     cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
+
+     # Process mask
+     if mask is not None:
+         mask = resize_and_crop(mask, (args.width, args.height))
+     else:
+         mask = automasker(
+             person_image,
+             cloth_type
+         )['mask']
+         mask = mask_processor.blur(mask, blur_factor=9)
+
+     # Inference
+     # try:
+     result_image = pipeline(
+         image=person_image,
+         condition_image=cloth_image,
+         mask=mask,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         generator=generator
+     )[0]
+     # except Exception as e:
+     #     raise gr.Error(
+     #         "An error occurred. Please try again later: {}".format(e)
+     #     )
+
+     # Post-process
+     masked_person = vis_mask(person_image, mask)
+     save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
+     save_result_image.save(result_save_path)
+     if show_type == "result only":
+         return result_image
+     else:
+         width, height = person_image.size
+         if show_type == "input & result":
+             condition_width = width // 2
+             conditions = image_grid([person_image, cloth_image], 2, 1)
+         else:
+             condition_width = width // 3
+             conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
+         conditions = conditions.resize((condition_width, height), Image.NEAREST)
+         new_result_image = Image.new("RGB", (width + condition_width + 5, height))
+         new_result_image.paste(conditions, (0, 0))
+         new_result_image.paste(result_image, (condition_width + 5, 0))
+         return new_result_image
+
+
+ def person_example_fn(image_path):
+     return image_path
+
+ HEADER = """
+ <h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
+ <div style="display: flex; justify-content: center; align-items: center;">
+   <a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
+     <img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
+   </a>
+   <a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
+     <img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
+   </a>
+   <a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
+     <img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
+   </a>
+   <a href="http://120.76.142.206:8888" style="margin: 0 2px;">
+     <img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
+   </a>
+   <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
+     <img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
+   </a>
+   <a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
+     <img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
+   </a>
+   <a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
+     <img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
+   </a>
+ </div>
+ <br>
+ · This demo and our weights are only for <span>non-commercial use</span>. <br>
+ · You can try CatVTON in our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a> or our <a href="http://120.76.142.206:8888">online demo</a> (run on a 3090). <br>
+ · Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing an A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
+ · A SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
+ """
+
+ def app_gradio():
+     with gr.Blocks(title="CatVTON") as demo:
+         gr.Markdown(HEADER)
+         with gr.Tab("Mask-based Virtual Try-On"):
+             with gr.Row():
+                 with gr.Column(scale=1, min_width=350):
+                     with gr.Row():
+                         image_path = gr.Image(
+                             type="filepath",
+                             interactive=True,
+                             visible=False,
+                         )
+                         person_image = gr.ImageEditor(
+                             interactive=True, label="Person Image", type="filepath"
+                         )
+
+                     with gr.Row():
+                         with gr.Column(scale=1, min_width=230):
+                             cloth_image = gr.Image(
+                                 interactive=True, label="Condition Image", type="filepath"
+                             )
+                         with gr.Column(scale=1, min_width=120):
+                             gr.Markdown(
+                                 '<span style="color: #808080; font-size: small;">Two ways to provide a mask:<br>1. Upload the person image and use the `🖌️` above to draw the mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate it automatically </span>'
+                             )
+                             cloth_type = gr.Radio(
+                                 label="Try-On Cloth Type",
+                                 choices=["upper", "lower", "overall"],
+                                 value="upper",
+                             )
+
+                     submit = gr.Button("Submit")
+                     gr.Markdown(
+                         '<center><span style="color: #FF0000">!!! Click only once and wait for the delay !!!</span></center>'
+                     )
+
+                     gr.Markdown(
+                         '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
+                     )
+                     with gr.Accordion("Advanced Options", open=False):
+                         num_inference_steps = gr.Slider(
+                             label="Inference Step", minimum=10, maximum=100, step=5, value=50
+                         )
+                         # Guidance scale
+                         guidance_scale = gr.Slider(
+                             label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
+                         )
+                         # Random seed
+                         seed = gr.Slider(
+                             label="Seed", minimum=-1, maximum=10000, step=1, value=42
+                         )
+                         show_type = gr.Radio(
+                             label="Show Type",
+                             choices=["result only", "input & result", "input & mask & result"],
+                             value="input & mask & result",
+                         )
+
+                 with gr.Column(scale=2, min_width=500):
+                     result_image = gr.Image(interactive=False, label="Result")
+                     with gr.Row():
+                         # Photo examples
+                         root_path = "resource/demo/example"
+                         with gr.Column():
+                             men_exm = gr.Examples(
+                                 examples=[
+                                     os.path.join(root_path, "person", "men", _)
+                                     for _ in os.listdir(os.path.join(root_path, "person", "men"))
+                                 ],
+                                 examples_per_page=4,
+                                 inputs=image_path,
+                                 label="Person Examples ①",
+                             )
+                             women_exm = gr.Examples(
+                                 examples=[
+                                     os.path.join(root_path, "person", "women", _)
+                                     for _ in os.listdir(os.path.join(root_path, "person", "women"))
+                                 ],
+                                 examples_per_page=4,
+                                 inputs=image_path,
+                                 label="Person Examples ②",
+                             )
+                             gr.Markdown(
+                                 '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
+                             )
+                         with gr.Column():
+                             condition_upper_exm = gr.Examples(
+                                 examples=[
+                                     os.path.join(root_path, "condition", "upper", _)
+                                     for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
+                                 ],
+                                 examples_per_page=4,
+                                 inputs=cloth_image,
+                                 label="Condition Upper Examples",
+                             )
+                             condition_overall_exm = gr.Examples(
+                                 examples=[
+                                     os.path.join(root_path, "condition", "overall", _)
+                                     for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
+                                 ],
+                                 examples_per_page=4,
+                                 inputs=cloth_image,
+                                 label="Condition Overall Examples",
+                             )
+                             condition_person_exm = gr.Examples(
+                                 examples=[
+                                     os.path.join(root_path, "condition", "person", _)
+                                     for _ in os.listdir(os.path.join(root_path, "condition", "person"))
+                                 ],
+                                 examples_per_page=4,
+                                 inputs=cloth_image,
+                                 label="Condition Reference Person Examples",
+                             )
+                             gr.Markdown(
+                                 '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
+                             )
+
+             image_path.change(
+                 person_example_fn, inputs=image_path, outputs=person_image
+             )
+
+             submit.click(
+                 submit_function,
+                 [
+                     person_image,
+                     cloth_image,
+                     cloth_type,
+                     num_inference_steps,
+                     guidance_scale,
+                     seed,
+                     show_type,
+                 ],
+                 result_image,
+             )
+
+         with gr.Tab("Mask-Free Virtual Try-On"):
+             with gr.Row():
+                 with gr.Column(scale=1, min_width=350):
+                     with gr.Row():
+                         image_path_p2p = gr.Image(
+                             type="filepath",
+                             interactive=True,
+                             visible=False,
+                         )
+                         person_image_p2p = gr.ImageEditor(
+                             interactive=True, label="Person Image", type="filepath"
+                         )
+
+                     with gr.Row():
+                         with gr.Column(scale=1, min_width=230):
+                             cloth_image_p2p = gr.Image(
+                                 interactive=True, label="Condition Image", type="filepath"
+                             )
+
+                     submit_p2p = gr.Button("Submit")
+                     gr.Markdown(
+                         '<center><span style="color: #FF0000">!!! Click only once and wait for the delay !!!</span></center>'
+                     )
+
+                     gr.Markdown(
+                         '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
+                     )
+                     with gr.Accordion("Advanced Options", open=False):
+                         num_inference_steps_p2p = gr.Slider(
+                             label="Inference Step", minimum=10, maximum=100, step=5, value=50
+                         )
+                         # Guidance scale
+                         guidance_scale_p2p = gr.Slider(
+                             label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
+                         )
+                         # Random seed
+                         seed_p2p = gr.Slider(
+                             label="Seed", minimum=-1, maximum=10000, step=1, value=42
+                         )
+                         # show_type = gr.Radio(
+                         #     label="Show Type",
+                         #     choices=["result only", "input & result", "input & mask & result"],
+                         #     value="input & mask & result",
+                         # )
+
+                 with gr.Column(scale=2, min_width=500):
+                     result_image_p2p = gr.Image(interactive=False, label="Result")
+                     with gr.Row():
+                         # Photo examples
+                         root_path = "resource/demo/example"
+                         with gr.Column():
+                             gr.Examples(
+                                 examples=[
+                                     os.path.join(root_path, "person", "men", _)
+                                     for _ in os.listdir(os.path.join(root_path, "person", "men"))
+                                 ],
+                                 examples_per_page=4,
+                                 inputs=image_path_p2p,
+                                 label="Person Examples ①",
+                             )
+                             gr.Examples(
+                                 examples=[
+                                     os.path.join(root_path, "person", "women", _)
+                                     for _ in os.listdir(os.path.join(root_path, "person", "women"))
+                                 ],
+                                 examples_per_page=4,
+                                 inputs=image_path_p2p,
+                                 label="Person Examples ②",
+                             )
+                             gr.Markdown(
+                                 '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
+                             )
+                         with gr.Column():
+                             gr.Examples(
+                                 examples=[
+                                     os.path.join(root_path, "condition", "upper", _)
+                                     for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
+                                 ],
+                                 examples_per_page=4,
+                                 inputs=cloth_image_p2p,
+                                 label="Condition Upper Examples",
+                             )
+                             gr.Examples(
+                                 examples=[
+                                     os.path.join(root_path, "condition", "overall", _)
+                                     for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
+                                 ],
+                                 examples_per_page=4,
+                                 inputs=cloth_image_p2p,
+                                 label="Condition Overall Examples",
+                             )
+                             condition_person_exm = gr.Examples(
+                                 examples=[
+                                     os.path.join(root_path, "condition", "person", _)
+                                     for _ in os.listdir(os.path.join(root_path, "condition", "person"))
+                                 ],
+                                 examples_per_page=4,
+                                 inputs=cloth_image_p2p,
+                                 label="Condition Reference Person Examples",
+                             )
+                             gr.Markdown(
+                                 '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
+                             )
+
+             image_path_p2p.change(
+                 person_example_fn, inputs=image_path_p2p, outputs=person_image_p2p
+             )
+
+             submit_p2p.click(
+                 submit_function_p2p,
+                 [
+                     person_image_p2p,
+                     cloth_image_p2p,
+                     num_inference_steps_p2p,
+                     guidance_scale_p2p,
+                     seed_p2p,
+                 ],
+                 result_image_p2p,
+             )
+
+     demo.queue().launch(share=True, show_error=True)
+
+
+ if __name__ == "__main__":
+     app_gradio()
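
A quick toy run of the `image_grid` helper defined in both demo scripts above, using solid-color tiles instead of model outputs (no weights required):

from PIL import Image

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

tiles = [Image.new("RGB", (64, 96), c) for c in ("red", "green", "blue")]
strip = image_grid(tiles, 1, 3)       # person | cloth | result, side by side
assert strip.size == (3 * 64, 96)
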
densepose/__init__.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+ from .data.datasets import builtin  # just to register data
+ from .converters import builtin as builtin_converters  # register converters
+ from .config import (
+     add_densepose_config,
+     add_densepose_head_config,
+     add_hrnet_config,
+     add_dataset_category_config,
+     add_bootstrap_config,
+     load_bootstrap_config,
+ )
+ from .structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
+ from .evaluation import DensePoseCOCOEvaluator
+ from .modeling.roi_heads import DensePoseROIHeads
+ from .modeling.test_time_augmentation import (
+     DensePoseGeneralizedRCNNWithTTA,
+     DensePoseDatasetMapperTTA,
+ )
+ from .utils.transform import load_from_cfg
+ from .modeling.hrfpn import build_hrfpn_backbone
densepose/config.py ADDED
@@ -0,0 +1,277 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # pyre-ignore-all-errors
+
+ from detectron2.config import CfgNode as CN
+
+
+ def add_dataset_category_config(cfg: CN) -> None:
+     """
+     Add config for additional category-related dataset options
+     - category whitelisting
+     - category mapping
+     """
+     _C = cfg
+     _C.DATASETS.CATEGORY_MAPS = CN(new_allowed=True)
+     _C.DATASETS.WHITELISTED_CATEGORIES = CN(new_allowed=True)
+     # class to mesh mapping
+     _C.DATASETS.CLASS_TO_MESH_NAME_MAPPING = CN(new_allowed=True)
+
+
+ def add_evaluation_config(cfg: CN) -> None:
+     _C = cfg
+     _C.DENSEPOSE_EVALUATION = CN()
+     # evaluator type, possible values:
+     # - "iou": evaluator for models that produce iou data
+     # - "cse": evaluator for models that produce cse data
+     _C.DENSEPOSE_EVALUATION.TYPE = "iou"
+     # storage for DensePose results, possible values:
+     # - "none": no explicit storage, all the results are stored in the
+     #     dictionary with predictions, memory intensive;
+     #     historically the default storage type
+     # - "ram": RAM storage, uses per-process RAM storage, which is
+     #     reduced to a single process storage on later stages,
+     #     less memory intensive
+     # - "file": file storage, uses per-process file-based storage,
+     #     the least memory intensive, but may create bottlenecks
+     #     on file system accesses
+     _C.DENSEPOSE_EVALUATION.STORAGE = "none"
+     # minimum threshold for IOU values: the lower its value is,
+     # the more matches are produced (and the higher the AP score)
+     _C.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD = 0.5
+     # Non-distributed inference is slower (at inference time) but can avoid RAM OOM
+     _C.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE = True
+     # evaluate mesh alignment based on vertex embeddings, only makes sense in CSE context
+     _C.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT = False
+     # meshes to compute mesh alignment for
+     _C.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES = []
+
+
+ def add_bootstrap_config(cfg: CN) -> None:
+     """ """
+     _C = cfg
+     _C.BOOTSTRAP_DATASETS = []
+     _C.BOOTSTRAP_MODEL = CN()
+     _C.BOOTSTRAP_MODEL.WEIGHTS = ""
+     _C.BOOTSTRAP_MODEL.DEVICE = "cuda"
+
+
+ def get_bootstrap_dataset_config() -> CN:
+     _C = CN()
+     _C.DATASET = ""
+     # ratio used to mix data loaders
+     _C.RATIO = 0.1
+     # image loader
+     _C.IMAGE_LOADER = CN(new_allowed=True)
+     _C.IMAGE_LOADER.TYPE = ""
+     _C.IMAGE_LOADER.BATCH_SIZE = 4
+     _C.IMAGE_LOADER.NUM_WORKERS = 4
+     _C.IMAGE_LOADER.CATEGORIES = []
+     _C.IMAGE_LOADER.MAX_COUNT_PER_CATEGORY = 1_000_000
+     _C.IMAGE_LOADER.CATEGORY_TO_CLASS_MAPPING = CN(new_allowed=True)
+     # inference
+     _C.INFERENCE = CN()
+     # batch size for model inputs
+     _C.INFERENCE.INPUT_BATCH_SIZE = 4
+     # batch size to group model outputs
+     _C.INFERENCE.OUTPUT_BATCH_SIZE = 2
+     # sampled data
+     _C.DATA_SAMPLER = CN(new_allowed=True)
+     _C.DATA_SAMPLER.TYPE = ""
+     _C.DATA_SAMPLER.USE_GROUND_TRUTH_CATEGORIES = False
+     # filter
+     _C.FILTER = CN(new_allowed=True)
+     _C.FILTER.TYPE = ""
+     return _C
+
+
+ def load_bootstrap_config(cfg: CN) -> None:
+     """
+     Bootstrap datasets are given as a list of `dict` that are not automatically
+     converted into CfgNode. This method processes all bootstrap dataset entries
+     and ensures that they are in CfgNode format and comply with the specification
+     """
+     if not cfg.BOOTSTRAP_DATASETS:
+         return
+
+     bootstrap_datasets_cfgnodes = []
+     for dataset_cfg in cfg.BOOTSTRAP_DATASETS:
+         _C = get_bootstrap_dataset_config().clone()
+         _C.merge_from_other_cfg(CN(dataset_cfg))
+         bootstrap_datasets_cfgnodes.append(_C)
+     cfg.BOOTSTRAP_DATASETS = bootstrap_datasets_cfgnodes
+
+
+ def add_densepose_head_cse_config(cfg: CN) -> None:
+     """
+     Add configuration options for Continuous Surface Embeddings (CSE)
+     """
+     _C = cfg
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE = CN()
+     # Dimensionality D of the embedding space
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE = 16
+     # Embedder specifications for various mesh IDs
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS = CN(new_allowed=True)
+     # normalization coefficient for embedding distances
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA = 0.01
+     # normalization coefficient for geodesic distances
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA = 0.01
+     # embedding loss weight
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT = 0.6
+     # embedding loss name, currently the following options are supported:
+     # - EmbeddingLoss: cross-entropy on vertex labels
+     # - SoftEmbeddingLoss: cross-entropy on vertex label combined with
+     #     Gaussian penalty on distance between vertices
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME = "EmbeddingLoss"
+     # optimizer hyperparameters
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR = 1.0
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR = 1.0
+     # Shape to shape cycle consistency loss parameters:
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
+     # shape to shape cycle consistency loss weight
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.025
+     # norm type used for loss computation
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
+     # normalization term for embedding similarity matrices
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.TEMPERATURE = 0.05
+     # maximum number of vertices to include into shape to shape cycle loss
+     # if negative or zero, all vertices are considered
+     # if positive, random subset of vertices of given size is considered
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES = 4936
+     # Pixel to shape cycle consistency loss parameters:
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS = CN({"ENABLED": False})
+     # pixel to shape cycle consistency loss weight
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT = 0.0001
+     # norm type used for loss computation
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NORM_P = 2
+     # map images to all meshes and back (if false, use only gt meshes from the batch)
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.USE_ALL_MESHES_NOT_GT_ONLY = False
+     # Randomly select at most this number of pixels from every instance
+     # if negative or zero, all vertices are considered
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.NUM_PIXELS_TO_SAMPLE = 100
+     # normalization factor for pixel to pixel distances (higher value = smoother distribution)
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.PIXEL_SIGMA = 5.0
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_PIXEL_TO_VERTEX = 0.05
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.TEMPERATURE_VERTEX_TO_PIXEL = 0.05
+
+
+ def add_densepose_head_config(cfg: CN) -> None:
+     """
+     Add config for densepose head.
+     """
+     _C = cfg
+
+     _C.MODEL.DENSEPOSE_ON = True
+
+     _C.MODEL.ROI_DENSEPOSE_HEAD = CN()
+     _C.MODEL.ROI_DENSEPOSE_HEAD.NAME = ""
+     _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_STACKED_CONVS = 8
+     # Number of parts used for point labels
+     _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES = 24
+     _C.MODEL.ROI_DENSEPOSE_HEAD.DECONV_KERNEL = 4
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_DIM = 512
+     _C.MODEL.ROI_DENSEPOSE_HEAD.CONV_HEAD_KERNEL = 3
+     _C.MODEL.ROI_DENSEPOSE_HEAD.UP_SCALE = 2
+     _C.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE = 112
+     _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_TYPE = "ROIAlignV2"
+     _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_RESOLUTION = 28
+     _C.MODEL.ROI_DENSEPOSE_HEAD.POOLER_SAMPLING_RATIO = 2
+     _C.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS = 2  # 15 or 2
+     # Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD)
+     _C.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD = 0.7
+     # Loss weights for annotation masks (14 parts)
+     _C.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS = 5.0
+     # Loss weights for surface parts (24 parts)
+     _C.MODEL.ROI_DENSEPOSE_HEAD.PART_WEIGHTS = 1.0
+     # Loss weights for UV regression.
+     _C.MODEL.ROI_DENSEPOSE_HEAD.POINT_REGRESSION_WEIGHTS = 0.01
+     # Coarse segmentation is trained using instance segmentation task data
+     _C.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS = False
+     # For Decoder
+     _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_ON = True
+     _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NUM_CLASSES = 256
+     _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_CONV_DIMS = 256
+     _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_NORM = ""
+     _C.MODEL.ROI_DENSEPOSE_HEAD.DECODER_COMMON_STRIDE = 4
+     # For DeepLab head
+     _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB = CN()
+     _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NORM = "GN"
+     _C.MODEL.ROI_DENSEPOSE_HEAD.DEEPLAB.NONLOCAL_ON = 0
+     # Predictor class name, must be registered in DENSEPOSE_PREDICTOR_REGISTRY
+     # Some registered predictors:
+     # "DensePoseChartPredictor": predicts segmentation and UV coordinates for predefined charts
+     # "DensePoseChartWithConfidencePredictor": predicts segmentation, UV coordinates
+     #     and associated confidences for predefined charts (default)
+     # "DensePoseEmbeddingWithConfidencePredictor": predicts segmentation, embeddings
+     #     and associated confidences for CSE
+     _C.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME = "DensePoseChartWithConfidencePredictor"
+     # Loss class name, must be registered in DENSEPOSE_LOSS_REGISTRY
+     # Some registered losses:
+     # "DensePoseChartLoss": loss for chart-based models that estimate
+     #     segmentation and UV coordinates
+     # "DensePoseChartWithConfidenceLoss": loss for chart-based models that estimate
+     #     segmentation, UV coordinates and the corresponding confidences (default)
+     _C.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME = "DensePoseChartWithConfidenceLoss"
+     # Confidences
+     # Enable learning UV confidences (variances) along with the actual values
+     _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE = CN({"ENABLED": False})
+     # UV confidence lower bound
+     _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.EPSILON = 0.01
+     # Enable learning segmentation confidences (variances) along with the actual values
+     _C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE = CN({"ENABLED": False})
+     # Segmentation confidence lower bound
+     _C.MODEL.ROI_DENSEPOSE_HEAD.SEGM_CONFIDENCE.EPSILON = 0.01
+     # Statistical model type for confidence learning, possible values:
+     # - "iid_iso": statistically independent identically distributed residuals
+     #     with isotropic covariance
+     # - "indep_aniso": statistically independent residuals with anisotropic
+     #     covariances
+     _C.MODEL.ROI_DENSEPOSE_HEAD.UV_CONFIDENCE.TYPE = "iid_iso"
+     # List of angles for rotation in data augmentation during training
+     _C.INPUT.ROTATION_ANGLES = [0]
+     _C.TEST.AUG.ROTATION_ANGLES = ()  # Rotation TTA
+
+     add_densepose_head_cse_config(cfg)
+
+
+ def add_hrnet_config(cfg: CN) -> None:
+     """
+     Add config for HRNet backbone.
+     """
+     _C = cfg
+
+     # For HigherHRNet w32
+     _C.MODEL.HRNET = CN()
+     _C.MODEL.HRNET.STEM_INPLANES = 64
+     _C.MODEL.HRNET.STAGE2 = CN()
+     _C.MODEL.HRNET.STAGE2.NUM_MODULES = 1
+     _C.MODEL.HRNET.STAGE2.NUM_BRANCHES = 2
+     _C.MODEL.HRNET.STAGE2.BLOCK = "BASIC"
+     _C.MODEL.HRNET.STAGE2.NUM_BLOCKS = [4, 4]
+     _C.MODEL.HRNET.STAGE2.NUM_CHANNELS = [32, 64]
+     _C.MODEL.HRNET.STAGE2.FUSE_METHOD = "SUM"
+     _C.MODEL.HRNET.STAGE3 = CN()
+     _C.MODEL.HRNET.STAGE3.NUM_MODULES = 4
+     _C.MODEL.HRNET.STAGE3.NUM_BRANCHES = 3
+     _C.MODEL.HRNET.STAGE3.BLOCK = "BASIC"
+     _C.MODEL.HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
+     _C.MODEL.HRNET.STAGE3.NUM_CHANNELS = [32, 64, 128]
+     _C.MODEL.HRNET.STAGE3.FUSE_METHOD = "SUM"
+     _C.MODEL.HRNET.STAGE4 = CN()
+     _C.MODEL.HRNET.STAGE4.NUM_MODULES = 3
+     _C.MODEL.HRNET.STAGE4.NUM_BRANCHES = 4
+     _C.MODEL.HRNET.STAGE4.BLOCK = "BASIC"
+     _C.MODEL.HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
+     _C.MODEL.HRNET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
+     _C.MODEL.HRNET.STAGE4.FUSE_METHOD = "SUM"
+
+     _C.MODEL.HRNET.HRFPN = CN()
+     _C.MODEL.HRNET.HRFPN.OUT_CHANNELS = 256
+
+
+ def add_densepose_config(cfg: CN) -> None:
+     add_densepose_head_config(cfg)
+     add_hrnet_config(cfg)
+     add_bootstrap_config(cfg)
+     add_dataset_category_config(cfg)
+     add_evaluation_config(cfg)
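
A minimal sketch of how the helpers above are typically wired into a detectron2 config (assumes detectron2 is installed; `get_cfg` is detectron2's stock config factory):

from detectron2.config import get_cfg

from densepose.config import add_densepose_config

cfg = get_cfg()            # base detectron2 config
add_densepose_config(cfg)  # attach all DensePose option groups defined above
print(cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE)  # -> 112, the default set above
# override a default, e.g. switch coarse segmentation to 15 channels:
cfg.merge_from_list(["MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS", 15])
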
densepose/converters/__init__.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from .hflip import HFlipConverter
+ from .to_mask import ToMaskConverter
+ from .to_chart_result import ToChartResultConverter, ToChartResultConverterWithConfidences
+ from .segm_to_mask import (
+     predictor_output_with_fine_and_coarse_segm_to_mask,
+     predictor_output_with_coarse_segm_to_mask,
+     resample_fine_and_coarse_segm_to_bbox,
+ )
+ from .chart_output_to_chart_result import (
+     densepose_chart_predictor_output_to_result,
+     densepose_chart_predictor_output_to_result_with_confidences,
+ )
+ from .chart_output_hflip import densepose_chart_predictor_output_hflip
densepose/converters/base.py ADDED
@@ -0,0 +1,95 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from typing import Any, Tuple, Type
+ import torch
+
+
+ class BaseConverter:
+     """
+     Converter base class to be reused by various converters.
+     Converter allows one to convert data from various source types to a particular
+     destination type. Each source type needs to register its converter. The
+     registration for each source type is valid for all descendants of that type.
+     """
+
+     @classmethod
+     def register(cls, from_type: Type, converter: Any = None):
+         """
+         Registers a converter for the specified type.
+         Can be used as a decorator (if converter is None), or called as a method.
+
+         Args:
+             from_type (type): type to register the converter for;
+                 all instances of this type will use the same converter
+             converter (callable): converter to be registered for the given
+                 type; if None, this method is assumed to be a decorator for the converter
+         """
+
+         if converter is not None:
+             cls._do_register(from_type, converter)
+
+         def wrapper(converter: Any) -> Any:
+             cls._do_register(from_type, converter)
+             return converter
+
+         return wrapper
+
+     @classmethod
+     def _do_register(cls, from_type: Type, converter: Any):
+         cls.registry[from_type] = converter  # pyre-ignore[16]
+
+     @classmethod
+     def _lookup_converter(cls, from_type: Type) -> Any:
+         """
+         Perform recursive lookup for the given type
+         to find registered converter. If a converter was found for some base
+         class, it gets registered for this class to save on further lookups.
+
+         Args:
+             from_type: type for which to find a converter
+         Return:
+             callable or None - registered converter or None
+             if no suitable entry was found in the registry
+         """
+         if from_type in cls.registry:  # pyre-ignore[16]
+             return cls.registry[from_type]
+         for base in from_type.__bases__:
+             converter = cls._lookup_converter(base)
+             if converter is not None:
+                 cls._do_register(from_type, converter)
+                 return converter
+         return None
+
+     @classmethod
+     def convert(cls, instance: Any, *args, **kwargs):
+         """
+         Convert an instance to the destination type using some registered
+         converter. Does recursive lookup for base classes, so there's no need
+         for explicit registration for derived classes.
+
+         Args:
+             instance: source instance to convert to the destination type
+         Return:
+             An instance of the destination type obtained from the source instance
+             Raises KeyError, if no suitable converter found
+         """
+         instance_type = type(instance)
+         converter = cls._lookup_converter(instance_type)
+         if converter is None:
+             if cls.dst_type is None:  # pyre-ignore[16]
+                 output_type_str = "itself"
+             else:
+                 output_type_str = cls.dst_type
+             raise KeyError(f"Could not find converter from {instance_type} to {output_type_str}")
+         return converter(instance, *args, **kwargs)
+
+
+ IntTupleBox = Tuple[int, int, int, int]
+
+
+ def make_int_box(box: torch.Tensor) -> IntTupleBox:
+     int_box = [0, 0, 0, 0]
+     int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist())
+     return int_box[0], int_box[1], int_box[2], int_box[3]
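
To illustrate the registry mechanism of `BaseConverter` above, here is a toy converter; all names below are hypothetical and for illustration only:

from densepose.converters.base import BaseConverter

class ToStringConverter(BaseConverter):
    registry = {}      # each concrete converter type keeps its own registry
    dst_type = str

@ToStringConverter.register(int)   # decorator form: the converter argument is None
def int_to_str(value):
    return f"int:{value}"

class MyInt(int):
    pass

print(ToStringConverter.convert(3))         # -> "int:3"
# MyInt has no direct entry; the recursive lookup walks __bases__, finds the
# `int` converter and caches it under MyInt for future calls:
print(ToStringConverter.convert(MyInt(7)))  # -> "int:7"
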
densepose/converters/builtin.py ADDED
@@ -0,0 +1,33 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from ..structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
+ from . import (
+     HFlipConverter,
+     ToChartResultConverter,
+     ToChartResultConverterWithConfidences,
+     ToMaskConverter,
+     densepose_chart_predictor_output_hflip,
+     densepose_chart_predictor_output_to_result,
+     densepose_chart_predictor_output_to_result_with_confidences,
+     predictor_output_with_coarse_segm_to_mask,
+     predictor_output_with_fine_and_coarse_segm_to_mask,
+ )
+
+ ToMaskConverter.register(
+     DensePoseChartPredictorOutput, predictor_output_with_fine_and_coarse_segm_to_mask
+ )
+ ToMaskConverter.register(
+     DensePoseEmbeddingPredictorOutput, predictor_output_with_coarse_segm_to_mask
+ )
+
+ ToChartResultConverter.register(
+     DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result
+ )
+
+ ToChartResultConverterWithConfidences.register(
+     DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result_with_confidences
+ )
+
+ HFlipConverter.register(DensePoseChartPredictorOutput, densepose_chart_predictor_output_hflip)
densepose/converters/chart_output_hflip.py ADDED
@@ -0,0 +1,73 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+ from dataclasses import fields
+ import torch
+
+ from densepose.structures import DensePoseChartPredictorOutput, DensePoseTransformData
+
+
+ def densepose_chart_predictor_output_hflip(
+     densepose_predictor_output: DensePoseChartPredictorOutput,
+     transform_data: DensePoseTransformData,
+ ) -> DensePoseChartPredictorOutput:
+     """
+     Adjust a DensePose chart predictor output to take a horizontal flip into account.
+     """
+     if len(densepose_predictor_output) > 0:
+
+         PredictorOutput = type(densepose_predictor_output)
+         output_dict = {}
+
+         for field in fields(densepose_predictor_output):
+             field_value = getattr(densepose_predictor_output, field.name)
+             # flip tensors
+             if isinstance(field_value, torch.Tensor):
+                 setattr(densepose_predictor_output, field.name, torch.flip(field_value, [3]))
+
+         densepose_predictor_output = _flip_iuv_semantics_tensor(
+             densepose_predictor_output, transform_data
+         )
+         densepose_predictor_output = _flip_segm_semantics_tensor(
+             densepose_predictor_output, transform_data
+         )
+
+         for field in fields(densepose_predictor_output):
+             output_dict[field.name] = getattr(densepose_predictor_output, field.name)
+
+         return PredictorOutput(**output_dict)
+     else:
+         return densepose_predictor_output
+
+
+ def _flip_iuv_semantics_tensor(
+     densepose_predictor_output: DensePoseChartPredictorOutput,
+     dp_transform_data: DensePoseTransformData,
+ ) -> DensePoseChartPredictorOutput:
+     point_label_symmetries = dp_transform_data.point_label_symmetries
+     uv_symmetries = dp_transform_data.uv_symmetries
+
+     N, C, H, W = densepose_predictor_output.u.shape
+     u_loc = (densepose_predictor_output.u[:, 1:, :, :].clamp(0, 1) * 255).long()
+     v_loc = (densepose_predictor_output.v[:, 1:, :, :].clamp(0, 1) * 255).long()
+     Iindex = torch.arange(C - 1, device=densepose_predictor_output.u.device)[
+         None, :, None, None
+     ].expand(N, C - 1, H, W)
+     densepose_predictor_output.u[:, 1:, :, :] = uv_symmetries["U_transforms"][Iindex, v_loc, u_loc]
+     densepose_predictor_output.v[:, 1:, :, :] = uv_symmetries["V_transforms"][Iindex, v_loc, u_loc]
+
+     for el in ["fine_segm", "u", "v"]:
+         densepose_predictor_output.__dict__[el] = densepose_predictor_output.__dict__[el][
+             :, point_label_symmetries, :, :
+         ]
+     return densepose_predictor_output
+
+
+ def _flip_segm_semantics_tensor(
+     densepose_predictor_output: DensePoseChartPredictorOutput, dp_transform_data
+ ):
+     if densepose_predictor_output.coarse_segm.shape[1] > 2:
+         densepose_predictor_output.coarse_segm = densepose_predictor_output.coarse_segm[
+             :, dp_transform_data.mask_label_symmetries, :, :
+         ]
+     return densepose_predictor_output
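
A self-contained toy of the two flip steps above: mirror the spatial axis, then permute part channels by their left/right symmetries (the permutation below is invented for illustration):

import torch

# toy "fine_segm" logits: batch=1, 3 part channels, 2x4 spatial grid
scores = torch.arange(1 * 3 * 2 * 4, dtype=torch.float32).reshape(1, 3, 2, 4)

# 1) mirror the spatial width axis (dim 3), as done for every tensor field above
flipped = torch.flip(scores, [3])

# 2) remap part channels with a left/right symmetry permutation
#    (hypothetical: parts 1 and 2 are mirror images of each other)
point_label_symmetries = [0, 2, 1]
flipped = flipped[:, point_label_symmetries, :, :]

assert flipped.shape == scores.shape
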
densepose/converters/chart_output_to_chart_result.py ADDED
@@ -0,0 +1,190 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ # pyre-unsafe
4
+
5
+ from typing import Dict
6
+ import torch
7
+ from torch.nn import functional as F
8
+
9
+ from detectron2.structures.boxes import Boxes, BoxMode
10
+
11
+ from ..structures import (
12
+ DensePoseChartPredictorOutput,
13
+ DensePoseChartResult,
14
+ DensePoseChartResultWithConfidences,
15
+ )
16
+ from . import resample_fine_and_coarse_segm_to_bbox
17
+ from .base import IntTupleBox, make_int_box
18
+
19
+
20
+ def resample_uv_tensors_to_bbox(
21
+ u: torch.Tensor,
22
+ v: torch.Tensor,
23
+ labels: torch.Tensor,
24
+ box_xywh_abs: IntTupleBox,
25
+ ) -> torch.Tensor:
26
+ """
27
+ Resamples U and V coordinate estimates for the given bounding box
28
+
29
+ Args:
30
+ u (tensor [1, C, H, W] of float): U coordinates
31
+ v (tensor [1, C, H, W] of float): V coordinates
32
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
33
+ outputs for the given bounding box
34
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
35
+ Return:
36
+ Resampled U and V coordinates - a tensor [2, H, W] of float
37
+ """
38
+ x, y, w, h = box_xywh_abs
39
+ w = max(int(w), 1)
40
+ h = max(int(h), 1)
41
+ u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
42
+ v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
43
+ uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
44
+ for part_id in range(1, u_bbox.size(1)):
45
+ uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id]
46
+ uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id]
47
+ return uv
48
+
49
+
50
+ def resample_uv_to_bbox(
51
+ predictor_output: DensePoseChartPredictorOutput,
52
+ labels: torch.Tensor,
53
+ box_xywh_abs: IntTupleBox,
54
+ ) -> torch.Tensor:
55
+ """
56
+ Resamples U and V coordinate estimates for the given bounding box
57
+
58
+ Args:
59
+ predictor_output (DensePoseChartPredictorOutput): DensePose predictor
60
+ output to be resampled
61
+ labels (tensor [H, W] of long): labels obtained by resampling segmentation
62
+ outputs for the given bounding box
63
+ box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
64
+ Return:
65
+ Resampled U and V coordinates - a tensor [2, H, W] of float
66
+ """
67
+ return resample_uv_tensors_to_bbox(
68
+ predictor_output.u,
69
+ predictor_output.v,
70
+ labels,
71
+ box_xywh_abs,
72
+ )
73
+
74
+
75
+ def densepose_chart_predictor_output_to_result(
76
+    predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
+) -> DensePoseChartResult:
+    """
+    Convert DensePose chart predictor outputs to results
+
+    Args:
+        predictor_output (DensePoseChartPredictorOutput): DensePose predictor
+            output to be converted to results, must contain only 1 output
+        boxes (Boxes): bounding box that corresponds to the predictor output,
+            must contain only 1 bounding box
+    Return:
+        DensePose chart-based result (DensePoseChartResult)
+    """
+    assert len(predictor_output) == 1 and len(boxes) == 1, (
+        f"Predictor output to result conversion can operate only on single outputs"
+        f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
+    )
+
+    boxes_xyxy_abs = boxes.tensor.clone()
+    boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+    box_xywh = make_int_box(boxes_xywh_abs[0])
+
+    labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
+    uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
+    return DensePoseChartResult(labels=labels, uv=uv)
+
+
+def resample_confidences_to_bbox(
+    predictor_output: DensePoseChartPredictorOutput,
+    labels: torch.Tensor,
+    box_xywh_abs: IntTupleBox,
+) -> Dict[str, torch.Tensor]:
+    """
+    Resamples confidences for the given bounding box
+
+    Args:
+        predictor_output (DensePoseChartPredictorOutput): DensePose predictor
+            output to be resampled
+        labels (tensor [H, W] of long): labels obtained by resampling segmentation
+            outputs for the given bounding box
+        box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
+    Return:
+        Resampled confidences - a dict of [H, W] tensors of float
+    """
+    x, y, w, h = box_xywh_abs
+    w = max(int(w), 1)
+    h = max(int(h), 1)
+
+    confidence_names = [
+        "sigma_1",
+        "sigma_2",
+        "kappa_u",
+        "kappa_v",
+        "fine_segm_confidence",
+        "coarse_segm_confidence",
+    ]
+    confidence_results = {key: None for key in confidence_names}
+    confidence_names = [
+        key for key in confidence_names if getattr(predictor_output, key) is not None
+    ]
+    confidence_base = torch.zeros([h, w], dtype=torch.float32, device=predictor_output.u.device)
+
+    # assign data from channels that correspond to the labels
+    for key in confidence_names:
+        resampled_confidence = F.interpolate(
+            getattr(predictor_output, key),
+            (h, w),
+            mode="bilinear",
+            align_corners=False,
+        )
+        result = confidence_base.clone()
+        for part_id in range(1, predictor_output.u.size(1)):
+            if resampled_confidence.size(1) != predictor_output.u.size(1):
+                # confidence is not part-based, don't try to fill it part by part
+                continue
+            result[labels == part_id] = resampled_confidence[0, part_id][labels == part_id]
+
+        if resampled_confidence.size(1) != predictor_output.u.size(1):
+            # confidence is not part-based, fill the data with the first channel
+            # (targeted for segmentation confidences that have only 1 channel)
+            result = resampled_confidence[0, 0]
+
+        confidence_results[key] = result
+
+    return confidence_results  # pyre-ignore[7]
+
+
+def densepose_chart_predictor_output_to_result_with_confidences(
+    predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
+) -> DensePoseChartResultWithConfidences:
+    """
+    Convert DensePose chart predictor outputs to results with confidences
+
+    Args:
+        predictor_output (DensePoseChartPredictorOutput): DensePose predictor
+            output with confidences to be converted to results, must contain only 1 output
+        boxes (Boxes): bounding box that corresponds to the predictor output,
+            must contain only 1 bounding box
+    Return:
+        DensePose chart-based result with confidences (DensePoseChartResultWithConfidences)
+    """
+    assert len(predictor_output) == 1 and len(boxes) == 1, (
+        f"Predictor output to result conversion can operate only on single outputs"
+        f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
+    )
+
+    boxes_xyxy_abs = boxes.tensor.clone()
+    boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+    box_xywh = make_int_box(boxes_xywh_abs[0])
+
+    labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
+    uv = resample_uv_to_bbox(predictor_output, labels, box_xywh)
+    confidences = resample_confidences_to_bbox(predictor_output, labels, box_xywh)
+    return DensePoseChartResultWithConfidences(labels=labels, uv=uv, **confidences)
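A minimal usage sketch for the conversion above (illustrative, not part of this commit): it assumes `instances` is a detectron2 `Instances` object produced by a DensePose model, carrying `pred_densepose` and `pred_boxes` fields.

    # Hypothetical sketch: convert the i-th detection to a chart-based result.
    i = 0
    result = densepose_chart_predictor_output_to_result(
        instances.pred_densepose[i],  # indexing yields a single-output slice
        instances.pred_boxes[i],      # and the matching single box
    )
    labels, uv = result.labels, result.uv  # [H, W] part labels, [2, H, W] UV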
densepose/converters/hflip.py ADDED
@@ -0,0 +1,36 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Any
+
+from .base import BaseConverter
+
+
+class HFlipConverter(BaseConverter):
+    """
+    Performs horizontal flips of various DensePose predictor outputs.
+    Each DensePose predictor output type has to register its conversion strategy.
+    """
+
+    registry = {}
+    dst_type = None
+
+    @classmethod
+    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
+    #  inconsistently.
+    def convert(cls, predictor_outputs: Any, transform_data: Any, *args, **kwargs):
+        """
+        Performs a horizontal flip on DensePose predictor outputs.
+        Does recursive lookup for base classes, so there's no need
+        for explicit registration for derived classes.
+
+        Args:
+            predictor_outputs: DensePose predictor output to be flipped
+            transform_data: Anything useful for the flip
+        Return:
+            An instance of the same type as predictor_outputs
+        """
+        return super(HFlipConverter, cls).convert(
+            predictor_outputs, transform_data, *args, **kwargs
+        )
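A sketch of where this converter could fit (illustrative, not part of this commit): for example, mapping outputs obtained on a horizontally flipped image back to the original orientation. `outputs_flipped` and `transform_data` (typically a `DensePoseTransformData` instance) are assumed to exist.

    # Hypothetical sketch: map predictor outputs of a flipped image back.
    outputs = HFlipConverter.convert(outputs_flipped, transform_data)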
densepose/converters/segm_to_mask.py ADDED
@@ -0,0 +1,152 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Any
+import torch
+from torch.nn import functional as F
+
+from detectron2.structures import BitMasks, Boxes, BoxMode
+
+from .base import IntTupleBox, make_int_box
+from .to_mask import ImageSizeType
+
+
+def resample_coarse_segm_tensor_to_bbox(coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox):
+    """
+    Resample coarse segmentation tensor to the given
+    bounding box and derive labels for each pixel of the bounding box
+
+    Args:
+        coarse_segm: float tensor of shape [1, K, Hout, Wout]
+        box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
+            corner coordinates, width (W) and height (H)
+    Return:
+        Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
+    """
+    x, y, w, h = box_xywh_abs
+    w = max(int(w), 1)
+    h = max(int(h), 1)
+    labels = F.interpolate(coarse_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
+    return labels
+
+
+def resample_fine_and_coarse_segm_tensors_to_bbox(
+    fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
+):
+    """
+    Resample fine and coarse segmentation tensors to the given
+    bounding box and derive labels for each pixel of the bounding box
+
+    Args:
+        fine_segm: float tensor of shape [1, C, Hout, Wout]
+        coarse_segm: float tensor of shape [1, K, Hout, Wout]
+        box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
+            corner coordinates, width (W) and height (H)
+    Return:
+        Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
+    """
+    x, y, w, h = box_xywh_abs
+    w = max(int(w), 1)
+    h = max(int(h), 1)
+    # coarse segmentation
+    coarse_segm_bbox = F.interpolate(
+        coarse_segm,
+        (h, w),
+        mode="bilinear",
+        align_corners=False,
+    ).argmax(dim=1)
+    # combined coarse and fine segmentation
+    labels = (
+        F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
+        * (coarse_segm_bbox > 0).long()
+    )
+    return labels
+
+
+def resample_fine_and_coarse_segm_to_bbox(predictor_output: Any, box_xywh_abs: IntTupleBox):
+    """
+    Resample fine and coarse segmentation outputs from a predictor to the given
+    bounding box and derive labels for each pixel of the bounding box
+
+    Args:
+        predictor_output: DensePose predictor output that contains segmentation
+            results to be resampled
+        box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
+            corner coordinates, width (W) and height (H)
+    Return:
+        Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
+    """
+    return resample_fine_and_coarse_segm_tensors_to_bbox(
+        predictor_output.fine_segm,
+        predictor_output.coarse_segm,
+        box_xywh_abs,
+    )
+
+
+def predictor_output_with_coarse_segm_to_mask(
+    predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
+) -> BitMasks:
+    """
+    Convert predictor output with coarse segmentation to a mask.
+    Assumes that predictor output has the following attributes:
+     - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
+       unnormalized scores for N instances; D is the number of coarse
+       segmentation labels, H and W are the resolution of the estimate
+
+    Args:
+        predictor_output: DensePose predictor output to be converted to mask
+        boxes (Boxes): bounding boxes that correspond to the DensePose
+            predictor outputs
+        image_size_hw (tuple [int, int]): image height Himg and width Wimg
+    Return:
+        BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
+        a mask of the size of the image for each instance
+    """
+    H, W = image_size_hw
+    boxes_xyxy_abs = boxes.tensor.clone()
+    boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+    N = len(boxes_xywh_abs)
+    masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
+    for i in range(len(boxes_xywh_abs)):
+        box_xywh = make_int_box(boxes_xywh_abs[i])
+        box_mask = resample_coarse_segm_tensor_to_bbox(predictor_output[i].coarse_segm, box_xywh)
+        x, y, w, h = box_xywh
+        masks[i, y : y + h, x : x + w] = box_mask
+
+    return BitMasks(masks)
+
+
+def predictor_output_with_fine_and_coarse_segm_to_mask(
+    predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
+) -> BitMasks:
+    """
+    Convert predictor output with coarse and fine segmentation to a mask.
+    Assumes that predictor output has the following attributes:
+     - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
+       unnormalized scores for N instances; D is the number of coarse
+       segmentation labels, H and W are the resolution of the estimate
+     - fine_segm (tensor of size [N, C, H, W]): fine segmentation
+       unnormalized scores for N instances; C is the number of fine
+       segmentation labels, H and W are the resolution of the estimate
+
+    Args:
+        predictor_output: DensePose predictor output to be converted to mask
+        boxes (Boxes): bounding boxes that correspond to the DensePose
+            predictor outputs
+        image_size_hw (tuple [int, int]): image height Himg and width Wimg
+    Return:
+        BitMasks that contain a bool tensor of size [N, Himg, Wimg] with
+        a mask of the size of the image for each instance
+    """
+    H, W = image_size_hw
+    boxes_xyxy_abs = boxes.tensor.clone()
+    boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+    N = len(boxes_xywh_abs)
+    masks = torch.zeros((N, H, W), dtype=torch.bool, device=boxes.tensor.device)
+    for i in range(len(boxes_xywh_abs)):
+        box_xywh = make_int_box(boxes_xywh_abs[i])
+        labels_i = resample_fine_and_coarse_segm_to_bbox(predictor_output[i], box_xywh)
+        x, y, w, h = box_xywh
+        masks[i, y : y + h, x : x + w] = labels_i > 0
+    return BitMasks(masks)
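A usage sketch for the mask conversion above (illustrative, not part of this commit), assuming `instances` holds N DensePose detections and `image_h`, `image_w` are the size of the source image:

    # Hypothetical sketch: rasterize all detections into image-sized bit masks.
    masks = predictor_output_with_fine_and_coarse_segm_to_mask(
        instances.pred_densepose, instances.pred_boxes, (image_h, image_w)
    )
    # masks.tensor is a bool tensor of shape [N, image_h, image_w]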
densepose/converters/to_chart_result.py ADDED
@@ -0,0 +1,72 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Any
+
+from detectron2.structures import Boxes
+
+from ..structures import DensePoseChartResult, DensePoseChartResultWithConfidences
+from .base import BaseConverter
+
+
+class ToChartResultConverter(BaseConverter):
+    """
+    Converts various DensePose predictor outputs to DensePose results.
+    Each DensePose predictor output type has to register its conversion strategy.
+    """
+
+    registry = {}
+    dst_type = DensePoseChartResult
+
+    @classmethod
+    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
+    #  inconsistently.
+    def convert(cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs) -> DensePoseChartResult:
+        """
+        Convert DensePose predictor outputs to DensePoseChartResult using some
+        registered converter. Does recursive lookup for base classes, so there's
+        no need for explicit registration for derived classes.
+
+        Args:
+            predictor_outputs: DensePose predictor output to be
+                converted to a chart-based result
+            boxes (Boxes): bounding boxes that correspond to the DensePose
+                predictor outputs
+        Return:
+            An instance of DensePoseChartResult. If no suitable converter was found, raises KeyError
+        """
+        return super(ToChartResultConverter, cls).convert(predictor_outputs, boxes, *args, **kwargs)
+
+
+class ToChartResultConverterWithConfidences(BaseConverter):
+    """
+    Converts various DensePose predictor outputs to DensePose results with confidences.
+    Each DensePose predictor output type has to register its conversion strategy.
+    """
+
+    registry = {}
+    dst_type = DensePoseChartResultWithConfidences
+
+    @classmethod
+    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
+    #  inconsistently.
+    def convert(
+        cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs
+    ) -> DensePoseChartResultWithConfidences:
+        """
+        Convert DensePose predictor outputs to DensePoseChartResultWithConfidences
+        using some registered converter. Does recursive lookup for base classes,
+        so there's no need for explicit registration for derived classes.
+
+        Args:
+            predictor_outputs: DensePose predictor output with confidences
+                to be converted to a chart-based result
+            boxes (Boxes): bounding boxes that correspond to the DensePose
+                predictor outputs
+        Return:
+            An instance of DensePoseChartResultWithConfidences.
+            If no suitable converter was found, raises KeyError
+        """
+        return super(ToChartResultConverterWithConfidences, cls).convert(
+            predictor_outputs, boxes, *args, **kwargs
+        )
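These converters are thin registry front-ends; the concrete conversion strategies are registered separately by each predictor output type. A per-instance usage sketch (illustrative, not part of this commit; `instances` and `i` as in the earlier sketches):

    # Hypothetical sketch: registry-based conversion of one detection.
    result = ToChartResultConverter.convert(
        instances.pred_densepose[i], instances.pred_boxes[i]
    )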
densepose/converters/to_mask.py ADDED
@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from typing import Any, Tuple
+
+from detectron2.structures import BitMasks, Boxes
+
+from .base import BaseConverter
+
+ImageSizeType = Tuple[int, int]
+
+
+class ToMaskConverter(BaseConverter):
+    """
+    Converts various DensePose predictor outputs to masks
+    in bit mask format (see `BitMasks`). Each DensePose predictor output type
+    has to register its conversion strategy.
+    """
+
+    registry = {}
+    dst_type = BitMasks
+
+    @classmethod
+    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
+    #  inconsistently.
+    def convert(
+        cls,
+        densepose_predictor_outputs: Any,
+        boxes: Boxes,
+        image_size_hw: ImageSizeType,
+        *args,
+        **kwargs
+    ) -> BitMasks:
+        """
+        Convert DensePose predictor outputs to BitMasks using some registered
+        converter. Does recursive lookup for base classes, so there's no need
+        for explicit registration for derived classes.
+
+        Args:
+            densepose_predictor_outputs: DensePose predictor output to be
+                converted to BitMasks
+            boxes (Boxes): bounding boxes that correspond to the DensePose
+                predictor outputs
+            image_size_hw (tuple [int, int]): image height and width
+        Return:
+            An instance of `BitMasks`. If no suitable converter was found, raises KeyError
+        """
+        return super(ToMaskConverter, cls).convert(
+            densepose_predictor_outputs, boxes, image_size_hw, *args, **kwargs
+        )
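A matching sketch for the mask converter (illustrative, not part of this commit); `instances.image_size` is the standard detectron2 `(height, width)` attribute:

    # Hypothetical sketch: convert all detections of one image to BitMasks.
    bitmasks = ToMaskConverter.convert(
        instances.pred_densepose, instances.pred_boxes, instances.image_size
    )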
densepose/data/__init__.py ADDED
@@ -0,0 +1,27 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+from .meshes import builtin
+from .build import (
+    build_detection_test_loader,
+    build_detection_train_loader,
+    build_combined_loader,
+    build_frame_selector,
+    build_inference_based_loaders,
+    has_inference_based_loaders,
+    BootstrapDatasetFactoryCatalog,
+)
+from .combined_loader import CombinedDataLoader
+from .dataset_mapper import DatasetMapper
+from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
+from .image_list_dataset import ImageListDataset
+from .utils import is_relative_local_path, maybe_prepend_base_path
+
+# ensure the builtin datasets are registered
+from . import datasets
+
+# ensure the bootstrap datasets builders are registered
+from . import build
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
densepose/data/build.py ADDED
@@ -0,0 +1,738 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+import itertools
+import logging
+import numpy as np
+from collections import UserDict, defaultdict
+from dataclasses import dataclass
+from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple
+import torch
+from torch.utils.data.dataset import Dataset
+
+from detectron2.config import CfgNode
+from detectron2.data.build import build_detection_test_loader as d2_build_detection_test_loader
+from detectron2.data.build import build_detection_train_loader as d2_build_detection_train_loader
+from detectron2.data.build import (
+    load_proposals_into_dataset,
+    print_instances_class_histogram,
+    trivial_batch_collator,
+    worker_init_reset_seed,
+)
+from detectron2.data.catalog import DatasetCatalog, Metadata, MetadataCatalog
+from detectron2.data.samplers import TrainingSampler
+from detectron2.utils.comm import get_world_size
+
+from densepose.config import get_bootstrap_dataset_config
+from densepose.modeling import build_densepose_embedder
+
+from .combined_loader import CombinedDataLoader, Loader
+from .dataset_mapper import DatasetMapper
+from .datasets.coco import DENSEPOSE_CSE_KEYS_WITHOUT_MASK, DENSEPOSE_IUV_KEYS_WITHOUT_MASK
+from .datasets.dataset_type import DatasetType
+from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
+from .samplers import (
+    DensePoseConfidenceBasedSampler,
+    DensePoseCSEConfidenceBasedSampler,
+    DensePoseCSEUniformSampler,
+    DensePoseUniformSampler,
+    MaskFromDensePoseSampler,
+    PredictionToGroundTruthSampler,
+)
+from .transform import ImageResizeTransform
+from .utils import get_category_to_class_mapping, get_class_to_mesh_name_mapping
+from .video import (
+    FirstKFramesSelector,
+    FrameSelectionStrategy,
+    LastKFramesSelector,
+    RandomKFramesSelector,
+    VideoKeyframeDataset,
+    video_list_from_file,
+)
+
+__all__ = ["build_detection_train_loader", "build_detection_test_loader"]
+
+
+Instance = Dict[str, Any]
+InstancePredicate = Callable[[Instance], bool]
+
+
+def _compute_num_images_per_worker(cfg: CfgNode) -> int:
+    num_workers = get_world_size()
+    images_per_batch = cfg.SOLVER.IMS_PER_BATCH
+    assert (
+        images_per_batch % num_workers == 0
+    ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
+        images_per_batch, num_workers
+    )
+    assert (
+        images_per_batch >= num_workers
+    ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
+        images_per_batch, num_workers
+    )
+    images_per_worker = images_per_batch // num_workers
+    return images_per_worker
+
+
+def _map_category_id_to_contiguous_id(dataset_name: str, dataset_dicts: Iterable[Instance]) -> None:
+    meta = MetadataCatalog.get(dataset_name)
+    for dataset_dict in dataset_dicts:
+        for ann in dataset_dict["annotations"]:
+            ann["category_id"] = meta.thing_dataset_id_to_contiguous_id[ann["category_id"]]
+
+
+@dataclass
+class _DatasetCategory:
+    """
+    Class representing category data in a dataset:
+    - id: category ID, as specified in the dataset annotations file
+    - name: category name, as specified in the dataset annotations file
+    - mapped_id: category ID after applying category maps (DATASETS.CATEGORY_MAPS config option)
+    - mapped_name: category name after applying category maps
+    - dataset_name: dataset in which the category is defined
+
+    For example, when training models in a class-agnostic manner, one could take LVIS 1.0
+    dataset and map the animal categories to the same category as human data from COCO:
+    id = 225
+    name = "cat"
+    mapped_id = 1
+    mapped_name = "person"
+    dataset_name = "lvis_v1_animals_dp_train"
+    """
+
+    id: int
+    name: str
+    mapped_id: int
+    mapped_name: str
+    dataset_name: str
+
+
+_MergedCategoriesT = Dict[int, List[_DatasetCategory]]
+
+
+def _add_category_id_to_contiguous_id_maps_to_metadata(
+    merged_categories: _MergedCategoriesT,
+) -> None:
+    merged_categories_per_dataset = {}
+    for contiguous_cat_id, cat_id in enumerate(sorted(merged_categories.keys())):
+        for cat in merged_categories[cat_id]:
+            if cat.dataset_name not in merged_categories_per_dataset:
+                merged_categories_per_dataset[cat.dataset_name] = defaultdict(list)
+            merged_categories_per_dataset[cat.dataset_name][cat_id].append(
+                (
+                    contiguous_cat_id,
+                    cat,
+                )
+            )
+
+    logger = logging.getLogger(__name__)
+    for dataset_name, merged_categories in merged_categories_per_dataset.items():
+        meta = MetadataCatalog.get(dataset_name)
+        if not hasattr(meta, "thing_classes"):
+            meta.thing_classes = []
+            meta.thing_dataset_id_to_contiguous_id = {}
+            meta.thing_dataset_id_to_merged_id = {}
+        else:
+            meta.thing_classes.clear()
+            meta.thing_dataset_id_to_contiguous_id.clear()
+            meta.thing_dataset_id_to_merged_id.clear()
+        logger.info(f"Dataset {dataset_name}: category ID to contiguous ID mapping:")
+        for _cat_id, categories in sorted(merged_categories.items()):
+            added_to_thing_classes = False
+            for contiguous_cat_id, cat in categories:
+                if not added_to_thing_classes:
+                    meta.thing_classes.append(cat.mapped_name)
+                    added_to_thing_classes = True
+                meta.thing_dataset_id_to_contiguous_id[cat.id] = contiguous_cat_id
+                meta.thing_dataset_id_to_merged_id[cat.id] = cat.mapped_id
+                logger.info(f"{cat.id} ({cat.name}) -> {contiguous_cat_id}")
+
+
+def _maybe_create_general_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
+    def has_annotations(instance: Instance) -> bool:
+        return "annotations" in instance
+
+    def has_only_crowd_annotations(instance: Instance) -> bool:
+        for ann in instance["annotations"]:
+            if ann.get("is_crowd", 0) == 0:
+                return False
+        return True
+
+    def general_keep_instance_predicate(instance: Instance) -> bool:
+        return has_annotations(instance) and not has_only_crowd_annotations(instance)
+
+    if not cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS:
+        return None
+    return general_keep_instance_predicate
+
+
+def _maybe_create_keypoints_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
+
+    min_num_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
+
+    def has_sufficient_num_keypoints(instance: Instance) -> bool:
+        num_kpts = sum(
+            (np.array(ann["keypoints"][2::3]) > 0).sum()
+            for ann in instance["annotations"]
+            if "keypoints" in ann
+        )
+        return num_kpts >= min_num_keypoints
+
+    if cfg.MODEL.KEYPOINT_ON and (min_num_keypoints > 0):
+        return has_sufficient_num_keypoints
+    return None
+
+
+def _maybe_create_mask_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
+    if not cfg.MODEL.MASK_ON:
+        return None
+
+    def has_mask_annotations(instance: Instance) -> bool:
+        return any("segmentation" in ann for ann in instance["annotations"])
+
+    return has_mask_annotations
+
+
+def _maybe_create_densepose_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
+    if not cfg.MODEL.DENSEPOSE_ON:
+        return None
+
+    use_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
+
+    def has_densepose_annotations(instance: Instance) -> bool:
+        for ann in instance["annotations"]:
+            if all(key in ann for key in DENSEPOSE_IUV_KEYS_WITHOUT_MASK) or all(
+                key in ann for key in DENSEPOSE_CSE_KEYS_WITHOUT_MASK
+            ):
+                return True
+            if use_masks and "segmentation" in ann:
+                return True
+        return False
+
+    return has_densepose_annotations
+
+
+def _maybe_create_specific_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
+    specific_predicate_creators = [
+        _maybe_create_keypoints_keep_instance_predicate,
+        _maybe_create_mask_keep_instance_predicate,
+        _maybe_create_densepose_keep_instance_predicate,
+    ]
+    predicates = [creator(cfg) for creator in specific_predicate_creators]
+    predicates = [p for p in predicates if p is not None]
+    if not predicates:
+        return None
+
+    def combined_predicate(instance: Instance) -> bool:
+        return any(p(instance) for p in predicates)
+
+    return combined_predicate
+
+
+def _get_train_keep_instance_predicate(cfg: CfgNode):
+    general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
+    combined_specific_keep_predicate = _maybe_create_specific_keep_instance_predicate(cfg)
+
+    def combined_general_specific_keep_predicate(instance: Instance) -> bool:
+        return general_keep_predicate(instance) and combined_specific_keep_predicate(instance)
+
+    if (general_keep_predicate is None) and (combined_specific_keep_predicate is None):
+        return None
+    if general_keep_predicate is None:
+        return combined_specific_keep_predicate
+    if combined_specific_keep_predicate is None:
+        return general_keep_predicate
+    return combined_general_specific_keep_predicate
+
+
+def _get_test_keep_instance_predicate(cfg: CfgNode):
+    general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
+    return general_keep_predicate
+
+
+def _maybe_filter_and_map_categories(
+    dataset_name: str, dataset_dicts: List[Instance]
+) -> List[Instance]:
+    meta = MetadataCatalog.get(dataset_name)
+    category_id_map = meta.thing_dataset_id_to_contiguous_id
+    filtered_dataset_dicts = []
+    for dataset_dict in dataset_dicts:
+        anns = []
+        for ann in dataset_dict["annotations"]:
+            cat_id = ann["category_id"]
+            if cat_id not in category_id_map:
+                continue
+            ann["category_id"] = category_id_map[cat_id]
+            anns.append(ann)
+        dataset_dict["annotations"] = anns
+        filtered_dataset_dicts.append(dataset_dict)
+    return filtered_dataset_dicts
+
+
+def _add_category_whitelists_to_metadata(cfg: CfgNode) -> None:
+    for dataset_name, whitelisted_cat_ids in cfg.DATASETS.WHITELISTED_CATEGORIES.items():
+        meta = MetadataCatalog.get(dataset_name)
+        meta.whitelisted_categories = whitelisted_cat_ids
+        logger = logging.getLogger(__name__)
+        logger.info(
+            "Whitelisted categories for dataset {}: {}".format(
+                dataset_name, meta.whitelisted_categories
+            )
+        )
+
+
+def _add_category_maps_to_metadata(cfg: CfgNode) -> None:
+    for dataset_name, category_map in cfg.DATASETS.CATEGORY_MAPS.items():
+        category_map = {
+            int(cat_id_src): int(cat_id_dst) for cat_id_src, cat_id_dst in category_map.items()
+        }
+        meta = MetadataCatalog.get(dataset_name)
+        meta.category_map = category_map
+        logger = logging.getLogger(__name__)
+        logger.info("Category maps for dataset {}: {}".format(dataset_name, meta.category_map))
+
+
+def _add_category_info_to_bootstrapping_metadata(dataset_name: str, dataset_cfg: CfgNode) -> None:
+    meta = MetadataCatalog.get(dataset_name)
+    meta.category_to_class_mapping = get_category_to_class_mapping(dataset_cfg)
+    meta.categories = dataset_cfg.CATEGORIES
+    meta.max_count_per_category = dataset_cfg.MAX_COUNT_PER_CATEGORY
+    logger = logging.getLogger(__name__)
+    logger.info(
+        "Category to class mapping for dataset {}: {}".format(
+            dataset_name, meta.category_to_class_mapping
+        )
+    )
+
+
+def _maybe_add_class_to_mesh_name_map_to_metadata(dataset_names: List[str], cfg: CfgNode) -> None:
+    for dataset_name in dataset_names:
+        meta = MetadataCatalog.get(dataset_name)
+        if not hasattr(meta, "class_to_mesh_name"):
+            meta.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
+
+
+def _merge_categories(dataset_names: Collection[str]) -> _MergedCategoriesT:
+    merged_categories = defaultdict(list)
+    category_names = {}
+    for dataset_name in dataset_names:
+        meta = MetadataCatalog.get(dataset_name)
+        whitelisted_categories = meta.get("whitelisted_categories")
+        category_map = meta.get("category_map", {})
+        cat_ids = (
+            whitelisted_categories if whitelisted_categories is not None else meta.categories.keys()
+        )
+        for cat_id in cat_ids:
+            cat_name = meta.categories[cat_id]
+            cat_id_mapped = category_map.get(cat_id, cat_id)
+            if cat_id_mapped == cat_id or cat_id_mapped in cat_ids:
+                category_names[cat_id] = cat_name
+            else:
+                category_names[cat_id] = str(cat_id_mapped)
+            # assign temporary mapped category name, this name can be changed
+            # during the second pass, since mapped ID can correspond to a category
+            # from a different dataset
+            cat_name_mapped = meta.categories[cat_id_mapped]
+            merged_categories[cat_id_mapped].append(
+                _DatasetCategory(
+                    id=cat_id,
+                    name=cat_name,
+                    mapped_id=cat_id_mapped,
+                    mapped_name=cat_name_mapped,
+                    dataset_name=dataset_name,
+                )
+            )
+    # second pass to assign proper mapped category names
+    for cat_id, categories in merged_categories.items():
+        for cat in categories:
+            if cat_id in category_names and cat.mapped_name != category_names[cat_id]:
+                cat.mapped_name = category_names[cat_id]
+
+    return merged_categories
+
+
+def _warn_if_merged_different_categories(merged_categories: _MergedCategoriesT) -> None:
+    logger = logging.getLogger(__name__)
+    for cat_id in merged_categories:
+        merged_categories_i = merged_categories[cat_id]
+        first_cat_name = merged_categories_i[0].name
+        if len(merged_categories_i) > 1 and not all(
+            cat.name == first_cat_name for cat in merged_categories_i[1:]
+        ):
+            cat_summary_str = ", ".join(
+                [f"{cat.id} ({cat.name}) from {cat.dataset_name}" for cat in merged_categories_i]
+            )
+            logger.warning(
+                f"Merged category {cat_id} corresponds to the following categories: "
+                f"{cat_summary_str}"
+            )
+
+
+def combine_detection_dataset_dicts(
+    dataset_names: Collection[str],
+    keep_instance_predicate: Optional[InstancePredicate] = None,
+    proposal_files: Optional[Collection[str]] = None,
+) -> List[Instance]:
+    """
+    Load and prepare dataset dicts for training / testing
+
+    Args:
+        dataset_names (Collection[str]): a list of dataset names
+        keep_instance_predicate (Callable: Dict[str, Any] -> bool): predicate
+            applied to instance dicts which defines whether to keep the instance
+        proposal_files (Collection[str]): if given, a list of object proposal files
+            that match each dataset in `dataset_names`.
+    """
+    assert len(dataset_names)
+    if proposal_files is None:
+        proposal_files = [None] * len(dataset_names)
+    assert len(dataset_names) == len(proposal_files)
+    # load datasets and metadata
+    dataset_name_to_dicts = {}
+    for dataset_name in dataset_names:
+        dataset_name_to_dicts[dataset_name] = DatasetCatalog.get(dataset_name)
+        assert len(dataset_name_to_dicts[dataset_name]), f"Dataset '{dataset_name}' is empty!"
+    # merge categories, requires category metadata to be loaded
+    # cat_id -> [(orig_cat_id, cat_name, dataset_name)]
+    merged_categories = _merge_categories(dataset_names)
+    _warn_if_merged_different_categories(merged_categories)
+    merged_category_names = [
+        merged_categories[cat_id][0].mapped_name for cat_id in sorted(merged_categories)
+    ]
+    # map to contiguous category IDs
+    _add_category_id_to_contiguous_id_maps_to_metadata(merged_categories)
+    # load annotations and dataset metadata
+    for dataset_name, proposal_file in zip(dataset_names, proposal_files):
+        dataset_dicts = dataset_name_to_dicts[dataset_name]
+        assert len(dataset_dicts), f"Dataset '{dataset_name}' is empty!"
+        if proposal_file is not None:
+            dataset_dicts = load_proposals_into_dataset(dataset_dicts, proposal_file)
+        dataset_dicts = _maybe_filter_and_map_categories(dataset_name, dataset_dicts)
+        print_instances_class_histogram(dataset_dicts, merged_category_names)
+        dataset_name_to_dicts[dataset_name] = dataset_dicts
+
+    if keep_instance_predicate is not None:
+        all_datasets_dicts_plain = [
+            d
+            for d in itertools.chain.from_iterable(dataset_name_to_dicts.values())
+            if keep_instance_predicate(d)
+        ]
+    else:
+        all_datasets_dicts_plain = list(
+            itertools.chain.from_iterable(dataset_name_to_dicts.values())
+        )
+    return all_datasets_dicts_plain
+
+
+def build_detection_train_loader(cfg: CfgNode, mapper=None):
+    """
+    A data loader is created in a way similar to that of Detectron2.
+    The main difference is:
+    - it allows combining datasets with different but compatible object category sets
+
+    The data loader is created by the following steps:
+    1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
+    2. Start workers to work on the dicts. Each worker will:
+        * Map each metadata dict into another format to be consumed by the model.
+        * Batch them by simply putting dicts into a list.
+    The batched ``list[mapped_dict]`` is what this dataloader will return.
+
+    Args:
+        cfg (CfgNode): the config
+        mapper (callable): a callable which takes a sample (dict) from dataset and
+            returns the format to be consumed by the model.
+            By default it will be `DatasetMapper(cfg, True)`.
+
+    Returns:
+        an infinite iterator of training data
+    """
+
+    _add_category_whitelists_to_metadata(cfg)
+    _add_category_maps_to_metadata(cfg)
+    _maybe_add_class_to_mesh_name_map_to_metadata(cfg.DATASETS.TRAIN, cfg)
+    dataset_dicts = combine_detection_dataset_dicts(
+        cfg.DATASETS.TRAIN,
+        keep_instance_predicate=_get_train_keep_instance_predicate(cfg),
+        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
+    )
+    if mapper is None:
+        mapper = DatasetMapper(cfg, True)
+    return d2_build_detection_train_loader(cfg, dataset=dataset_dicts, mapper=mapper)
+
+
+def build_detection_test_loader(cfg, dataset_name, mapper=None):
+    """
+    Similar to `build_detection_train_loader`.
+    But this function uses the given `dataset_name` argument (instead of the names in cfg),
+    and uses batch size 1.
+
+    Args:
+        cfg: a detectron2 CfgNode
+        dataset_name (str): a name of the dataset that's available in the DatasetCatalog
+        mapper (callable): a callable which takes a sample (dict) from dataset
+            and returns the format to be consumed by the model.
+            By default it will be `DatasetMapper(cfg, False)`.
+
+    Returns:
+        DataLoader: a torch DataLoader, that loads the given detection
+        dataset, with test-time transformation and batching.
+    """
+    _add_category_whitelists_to_metadata(cfg)
+    _add_category_maps_to_metadata(cfg)
+    _maybe_add_class_to_mesh_name_map_to_metadata([dataset_name], cfg)
+    dataset_dicts = combine_detection_dataset_dicts(
+        [dataset_name],
+        keep_instance_predicate=_get_test_keep_instance_predicate(cfg),
+        proposal_files=(
+            [cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]]
+            if cfg.MODEL.LOAD_PROPOSALS
+            else None
+        ),
+    )
+    sampler = None
+    if not cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE:
+        sampler = torch.utils.data.SequentialSampler(dataset_dicts)
+    if mapper is None:
+        mapper = DatasetMapper(cfg, False)
+    return d2_build_detection_test_loader(
+        dataset_dicts, mapper=mapper, num_workers=cfg.DATALOADER.NUM_WORKERS, sampler=sampler
+    )
+
+
+def build_frame_selector(cfg: CfgNode):
+    strategy = FrameSelectionStrategy(cfg.STRATEGY)
+    if strategy == FrameSelectionStrategy.RANDOM_K:
+        frame_selector = RandomKFramesSelector(cfg.NUM_IMAGES)
+    elif strategy == FrameSelectionStrategy.FIRST_K:
+        frame_selector = FirstKFramesSelector(cfg.NUM_IMAGES)
+    elif strategy == FrameSelectionStrategy.LAST_K:
+        frame_selector = LastKFramesSelector(cfg.NUM_IMAGES)
+    elif strategy == FrameSelectionStrategy.ALL:
+        frame_selector = None
+    # pyre-fixme[61]: `frame_selector` may not be initialized here.
+    return frame_selector
+
+
+def build_transform(cfg: CfgNode, data_type: str):
+    if cfg.TYPE == "resize":
+        if data_type == "image":
+            return ImageResizeTransform(cfg.MIN_SIZE, cfg.MAX_SIZE)
+    raise ValueError(f"Unknown transform {cfg.TYPE} for data type {data_type}")
+
+
+def build_combined_loader(cfg: CfgNode, loaders: Collection[Loader], ratios: Sequence[float]):
+    images_per_worker = _compute_num_images_per_worker(cfg)
+    return CombinedDataLoader(loaders, images_per_worker, ratios)
+
+
+def build_bootstrap_dataset(dataset_name: str, cfg: CfgNode) -> Sequence[torch.Tensor]:
+    """
+    Build dataset that provides data to bootstrap on
+
+    Args:
+        dataset_name (str): Name of the dataset, needs to have associated metadata
+            to load the data
+        cfg (CfgNode): bootstrapping config
+    Returns:
+        Sequence[Tensor] - dataset that provides image batches, Tensors of size
+        [N, C, H, W] of type float32
+    """
+    logger = logging.getLogger(__name__)
+    _add_category_info_to_bootstrapping_metadata(dataset_name, cfg)
+    meta = MetadataCatalog.get(dataset_name)
+    factory = BootstrapDatasetFactoryCatalog.get(meta.dataset_type)
+    dataset = None
+    if factory is not None:
+        dataset = factory(meta, cfg)
+    if dataset is None:
+        logger.warning(f"Failed to create dataset {dataset_name} of type {meta.dataset_type}")
+    return dataset
+
+
+def build_data_sampler(cfg: CfgNode, sampler_cfg: CfgNode, embedder: Optional[torch.nn.Module]):
+    if sampler_cfg.TYPE == "densepose_uniform":
+        data_sampler = PredictionToGroundTruthSampler()
+        # transform densepose pred -> gt
+        data_sampler.register_sampler(
+            "pred_densepose",
+            "gt_densepose",
+            DensePoseUniformSampler(count_per_class=sampler_cfg.COUNT_PER_CLASS),
+        )
+        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
+        return data_sampler
+    elif sampler_cfg.TYPE == "densepose_UV_confidence":
+        data_sampler = PredictionToGroundTruthSampler()
+        # transform densepose pred -> gt
+        data_sampler.register_sampler(
+            "pred_densepose",
+            "gt_densepose",
+            DensePoseConfidenceBasedSampler(
+                confidence_channel="sigma_2",
+                count_per_class=sampler_cfg.COUNT_PER_CLASS,
+                search_proportion=0.5,
+            ),
+        )
+        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
+        return data_sampler
+    elif sampler_cfg.TYPE == "densepose_fine_segm_confidence":
+        data_sampler = PredictionToGroundTruthSampler()
+        # transform densepose pred -> gt
+        data_sampler.register_sampler(
+            "pred_densepose",
+            "gt_densepose",
+            DensePoseConfidenceBasedSampler(
+                confidence_channel="fine_segm_confidence",
+                count_per_class=sampler_cfg.COUNT_PER_CLASS,
+                search_proportion=0.5,
+            ),
+        )
+        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
+        return data_sampler
+    elif sampler_cfg.TYPE == "densepose_coarse_segm_confidence":
+        data_sampler = PredictionToGroundTruthSampler()
+        # transform densepose pred -> gt
+        data_sampler.register_sampler(
+            "pred_densepose",
+            "gt_densepose",
+            DensePoseConfidenceBasedSampler(
+                confidence_channel="coarse_segm_confidence",
+                count_per_class=sampler_cfg.COUNT_PER_CLASS,
+                search_proportion=0.5,
+            ),
+        )
+        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
+        return data_sampler
+    elif sampler_cfg.TYPE == "densepose_cse_uniform":
+        assert embedder is not None
+        data_sampler = PredictionToGroundTruthSampler()
+        # transform densepose pred -> gt
+        data_sampler.register_sampler(
+            "pred_densepose",
+            "gt_densepose",
+            DensePoseCSEUniformSampler(
+                cfg=cfg,
+                use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
+                embedder=embedder,
+                count_per_class=sampler_cfg.COUNT_PER_CLASS,
+            ),
+        )
+        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
+        return data_sampler
+    elif sampler_cfg.TYPE == "densepose_cse_coarse_segm_confidence":
+        assert embedder is not None
+        data_sampler = PredictionToGroundTruthSampler()
+        # transform densepose pred -> gt
+        data_sampler.register_sampler(
+            "pred_densepose",
+            "gt_densepose",
+            DensePoseCSEConfidenceBasedSampler(
+                cfg=cfg,
+                use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
+                embedder=embedder,
+                confidence_channel="coarse_segm_confidence",
+                count_per_class=sampler_cfg.COUNT_PER_CLASS,
+                search_proportion=0.5,
+            ),
+        )
+        data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
+        return data_sampler
+
+    raise ValueError(f"Unknown data sampler type {sampler_cfg.TYPE}")
+
+
+def build_data_filter(cfg: CfgNode):
+    if cfg.TYPE == "detection_score":
+        min_score = cfg.MIN_VALUE
+        return ScoreBasedFilter(min_score=min_score)
+    raise ValueError(f"Unknown data filter type {cfg.TYPE}")
+
+
+def build_inference_based_loader(
+    cfg: CfgNode,
+    dataset_cfg: CfgNode,
+    model: torch.nn.Module,
+    embedder: Optional[torch.nn.Module] = None,
+) -> InferenceBasedLoader:
+    """
+    Constructs data loader based on inference results of a model.
+    """
+    dataset = build_bootstrap_dataset(dataset_cfg.DATASET, dataset_cfg.IMAGE_LOADER)
+    meta = MetadataCatalog.get(dataset_cfg.DATASET)
+    training_sampler = TrainingSampler(len(dataset))
+    data_loader = torch.utils.data.DataLoader(
+        dataset,  # pyre-ignore[6]
+        batch_size=dataset_cfg.IMAGE_LOADER.BATCH_SIZE,
+        sampler=training_sampler,
+        num_workers=dataset_cfg.IMAGE_LOADER.NUM_WORKERS,
+        collate_fn=trivial_batch_collator,
+        worker_init_fn=worker_init_reset_seed,
+    )
+    return InferenceBasedLoader(
+        model,
+        data_loader=data_loader,
+        data_sampler=build_data_sampler(cfg, dataset_cfg.DATA_SAMPLER, embedder),
+        data_filter=build_data_filter(dataset_cfg.FILTER),
+        shuffle=True,
+        batch_size=dataset_cfg.INFERENCE.OUTPUT_BATCH_SIZE,
+        inference_batch_size=dataset_cfg.INFERENCE.INPUT_BATCH_SIZE,
+        category_to_class_mapping=meta.category_to_class_mapping,
+    )
+
+
+def has_inference_based_loaders(cfg: CfgNode) -> bool:
+    """
+    Returns True if at least one inference-based loader must
+    be instantiated for training
+    """
+    return len(cfg.BOOTSTRAP_DATASETS) > 0
+
+
+def build_inference_based_loaders(
+    cfg: CfgNode, model: torch.nn.Module
+) -> Tuple[List[InferenceBasedLoader], List[float]]:
+    loaders = []
+    ratios = []
+    embedder = build_densepose_embedder(cfg).to(device=model.device)  # pyre-ignore[16]
+    for dataset_spec in cfg.BOOTSTRAP_DATASETS:
+        dataset_cfg = get_bootstrap_dataset_config().clone()
+        dataset_cfg.merge_from_other_cfg(CfgNode(dataset_spec))
+        loader = build_inference_based_loader(cfg, dataset_cfg, model, embedder)
+        loaders.append(loader)
+        ratios.append(dataset_cfg.RATIO)
+    return loaders, ratios
+
+
+def build_video_list_dataset(meta: Metadata, cfg: CfgNode):
+    video_list_fpath = meta.video_list_fpath
+    video_base_path = meta.video_base_path
+    category = meta.category
+    if cfg.TYPE == "video_keyframe":
+        frame_selector = build_frame_selector(cfg.SELECT)
+        transform = build_transform(cfg.TRANSFORM, data_type="image")
+        video_list = video_list_from_file(video_list_fpath, video_base_path)
+        keyframe_helper_fpath = getattr(cfg, "KEYFRAME_HELPER", None)
+        return VideoKeyframeDataset(
+            video_list, category, frame_selector, transform, keyframe_helper_fpath
+        )
+
+
+class _BootstrapDatasetFactoryCatalog(UserDict):
+    """
+    A global dictionary that stores information about bootstrapped dataset creation functions
+    from metadata and config, for diverse DatasetType
+    """
+
+    def register(self, dataset_type: DatasetType, factory: Callable[[Metadata, CfgNode], Dataset]):
+        """
+        Args:
+            dataset_type (DatasetType): a DatasetType e.g. DatasetType.VIDEO_LIST
+            factory (Callable[Metadata, CfgNode]): a callable which takes Metadata and cfg
+                arguments and returns a dataset object.
+        """
+        assert dataset_type not in self, "Dataset '{}' is already registered!".format(dataset_type)
+        self[dataset_type] = factory
+
+
+BootstrapDatasetFactoryCatalog = _BootstrapDatasetFactoryCatalog()
+BootstrapDatasetFactoryCatalog.register(DatasetType.VIDEO_LIST, build_video_list_dataset)
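The factory catalog is the extension point for new bootstrap data sources. A sketch of registering another factory (illustrative, not part of this commit; `DatasetType.IMAGE_LIST` and the `meta.image_list` / `meta.category` metadata fields used here are assumptions):

    # Hypothetical sketch: register a bootstrap factory for image lists,
    # reusing ImageListDataset from densepose.data.image_list_dataset.
    def build_image_list_dataset(meta: Metadata, cfg: CfgNode):
        transform = build_transform(cfg.TRANSFORM, data_type="image")
        return ImageListDataset(meta.image_list, meta.category, transform)

    BootstrapDatasetFactoryCatalog.register(DatasetType.IMAGE_LIST, build_image_list_dataset)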
densepose/data/combined_loader.py ADDED
@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+import random
+from collections import deque
+from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence
+
+Loader = Iterable[Any]
+
+
+def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]):
+    if not pool:
+        pool.extend(next(iterator))
+    return pool.popleft()
+
+
+class CombinedDataLoader:
+    """
+    Combines data loaders using the provided sampling ratios
+    """
+
+    BATCH_COUNT = 100
+
+    def __init__(self, loaders: Collection[Loader], batch_size: int, ratios: Sequence[float]):
+        self.loaders = loaders
+        self.batch_size = batch_size
+        self.ratios = ratios
+
+    def __iter__(self) -> Iterator[List[Any]]:
+        iters = [iter(loader) for loader in self.loaders]
+        indices = []
+        # one pool per loader; `[deque()] * len(iters)` would alias a single
+        # shared deque and mix samples from different loaders
+        pool = [deque() for _ in iters]
+        # infinite iterator, as in D2
+        while True:
+            if not indices:
+                # just a buffer of indices, its size doesn't matter
+                # as long as it's a multiple of batch_size
+                k = self.batch_size * self.BATCH_COUNT
+                indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
+            try:
+                batch = [_pooled_next(iters[i], pool[i]) for i in indices[: self.batch_size]]
+            except StopIteration:
+                break
+            indices = indices[self.batch_size :]
+            yield batch
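A usage sketch (illustrative, not part of this commit): combine two loaders so that on average 80% of the samples in each batch come from the first one; `loader_a`, `loader_b`, and `train_step` are placeholders.

    # Hypothetical sketch: 80/20 mixture of two loaders, batches of 8.
    combined = CombinedDataLoader([loader_a, loader_b], batch_size=8, ratios=[0.8, 0.2])
    for batch in combined:  # infinite unless one of the loaders is exhausted
        train_step(batch)   # placeholder for the actual training step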
densepose/data/dataset_mapper.py ADDED
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# pyre-unsafe
+
+import copy
+import logging
+from typing import Any, Dict, List, Tuple
+import torch
+
+from detectron2.data import MetadataCatalog
+from detectron2.data import detection_utils as utils
+from detectron2.data import transforms as T
+from detectron2.layers import ROIAlign
+from detectron2.structures import BoxMode
+from detectron2.utils.file_io import PathManager
+
+from densepose.structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
+
+
+def build_augmentation(cfg, is_train):
+    logger = logging.getLogger(__name__)
+    result = utils.build_augmentation(cfg, is_train)
+    if is_train:
+        random_rotation = T.RandomRotation(
+            cfg.INPUT.ROTATION_ANGLES, expand=False, sample_style="choice"
+        )
+        result.append(random_rotation)
+        logger.info("DensePose-specific augmentation used in training: " + str(random_rotation))
+    return result
+
+
+class DatasetMapper:
+    """
+    A customized version of `detectron2.data.DatasetMapper`
+    """
+
+    def __init__(self, cfg, is_train=True):
+        self.augmentation = build_augmentation(cfg, is_train)
+
+        # fmt: off
+        self.img_format = cfg.INPUT.FORMAT
+        self.mask_on = (
+            cfg.MODEL.MASK_ON or (
+                cfg.MODEL.DENSEPOSE_ON
+                and cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS)
+        )
+        self.keypoint_on = cfg.MODEL.KEYPOINT_ON
+        self.densepose_on = cfg.MODEL.DENSEPOSE_ON
+        assert not cfg.MODEL.LOAD_PROPOSALS, "not supported yet"
+        # fmt: on
+        if self.keypoint_on and is_train:
+            # Flip only makes sense in training
+            self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
+        else:
+            self.keypoint_hflip_indices = None
+
+        if self.densepose_on:
+            densepose_transform_srcs = [
+                MetadataCatalog.get(ds).densepose_transform_src
+                for ds in cfg.DATASETS.TRAIN + cfg.DATASETS.TEST
+            ]
+            assert len(densepose_transform_srcs) > 0
+            # TODO: check that DensePose transformation data is the same for
+            # all the datasets. Otherwise one would have to pass DB ID with
+            # each entry to select proper transformation data. For now, since
+            # all DensePose annotated data uses the same data semantics, we
+            # omit this check.
+            densepose_transform_data_fpath = PathManager.get_local_path(densepose_transform_srcs[0])
+            self.densepose_transform_data = DensePoseTransformData.load(
+                densepose_transform_data_fpath
+            )
+
+        self.is_train = is_train
+
+    def __call__(self, dataset_dict):
+        """
+        Args:
+            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+        Returns:
+            dict: a format that builtin models in detectron2 accept
+        """
+        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
+        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
+        utils.check_image_size(dataset_dict, image)
+
+        image, transforms = T.apply_transform_gens(self.augmentation, image)
+        image_shape = image.shape[:2]  # h, w
+        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
+
+        if not self.is_train:
+            dataset_dict.pop("annotations", None)
+            return dataset_dict
+
+        for anno in dataset_dict["annotations"]:
+            if not self.mask_on:
+                anno.pop("segmentation", None)
+            if not self.keypoint_on:
+                anno.pop("keypoints", None)
+
+        # USER: Implement additional transformations if you have other types of data
+        # USER: Don't call transpose_densepose if you don't need it
+        annos = [
+            self._transform_densepose(
+                utils.transform_instance_annotations(
+                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
+                ),
+                transforms,
+            )
+            for obj in dataset_dict.pop("annotations")
+            if obj.get("iscrowd", 0) == 0
+        ]
+
+        if self.mask_on:
+            self._add_densepose_masks_as_segmentation(annos, image_shape)
+
+        instances = utils.annotations_to_instances(annos, image_shape, mask_format="bitmask")
+        densepose_annotations = [obj.get("densepose") for obj in annos]
+        if densepose_annotations and not all(v is None for v in densepose_annotations):
+            instances.gt_densepose = DensePoseList(
+                densepose_annotations, instances.gt_boxes, image_shape
+            )
+
+        dataset_dict["instances"] = instances[instances.gt_boxes.nonempty()]
+        return dataset_dict
+
+    def _transform_densepose(self, annotation, transforms):
+        if not self.densepose_on:
+            return annotation
+
+        # Handle densepose annotations
+        is_valid, reason_not_valid = DensePoseDataRelative.validate_annotation(annotation)
+        if is_valid:
+            densepose_data = DensePoseDataRelative(annotation, cleanup=True)
+            densepose_data.apply_transform(transforms, self.densepose_transform_data)
+            annotation["densepose"] = densepose_data
+        else:
+            # logger = logging.getLogger(__name__)
+            # logger.debug("Could not load DensePose annotation: {}".format(reason_not_valid))
+            DensePoseDataRelative.cleanup_annotation(annotation)
+            # NOTE: annotations for certain instances may be unavailable.
+            # 'None' is accepted by the DensePoseList data structure.
+            annotation["densepose"] = None
+        return annotation
+
+    def _add_densepose_masks_as_segmentation(
+        self, annotations: List[Dict[str, Any]], image_shape_hw: Tuple[int, int]
+    ):
+        for obj in annotations:
+            if ("densepose" not in obj) or ("segmentation" in obj):
+                continue
+            # DP segmentation: torch.Tensor [S, S] of float32, S=256
+            segm_dp = torch.zeros_like(obj["densepose"].segm)
+            segm_dp[obj["densepose"].segm > 0] = 1
+            segm_h, segm_w = segm_dp.shape
+            bbox_segm_dp = torch.tensor((0, 0, segm_h - 1, segm_w - 1), dtype=torch.float32)
+            # image bbox
+            x0, y0, x1, y1 = (
+                v.item() for v in BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS)
+            )
+            segm_aligned = (
+                ROIAlign((y1 - y0, x1 - x0), 1.0, 0, aligned=True)
+                .forward(segm_dp.view(1, 1, *segm_dp.shape), bbox_segm_dp)
+                .squeeze()
+            )
+            image_mask = torch.zeros(*image_shape_hw, dtype=torch.float32)
+            image_mask[y0:y1, x0:x1] = segm_aligned
+            # segmentation for BitMask: np.array [H, W] of bool
+            obj["segmentation"] = image_mask >= 0.5
densepose/data/image_list_dataset.py ADDED
@@ -0,0 +1,74 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ import logging
+ import numpy as np
+ from typing import Any, Callable, Dict, List, Optional, Union
+ import torch
+ from torch.utils.data.dataset import Dataset
+
+ from detectron2.data.detection_utils import read_image
+
+ ImageTransform = Callable[[torch.Tensor], torch.Tensor]
+
+
+ class ImageListDataset(Dataset):
+     """
+     Dataset that provides images from a list.
+     """
+
+     _EMPTY_IMAGE = torch.empty((0, 3, 1, 1))
+
+     def __init__(
+         self,
+         image_list: List[str],
+         category_list: Union[str, List[str], None] = None,
+         transform: Optional[ImageTransform] = None,
+     ):
+         """
+         Args:
+             image_list (List[str]): list of paths to image files
+             category_list (Union[str, List[str], None]): list of animal categories for
+                 each image. If it is a string, or None, this applies to all images
+         """
+         if type(category_list) is list:
+             self.category_list = category_list
+         else:
+             self.category_list = [category_list] * len(image_list)
+         assert len(image_list) == len(
+             self.category_list
+         ), "length of image and category lists must be equal"
+         self.image_list = image_list
+         self.transform = transform
+
+     def __getitem__(self, idx: int) -> Dict[str, Any]:
+         """
+         Gets selected images from the list
+
+         Args:
+             idx (int): image index in the image list
+         Returns:
+             A dictionary containing two keys:
+                 images (torch.Tensor): tensor of size [N, 3, H, W] (N = 1, or 0 for _EMPTY_IMAGE)
+                 categories (List[str]): categories of the image
+         """
+         categories = [self.category_list[idx]]
+         fpath = self.image_list[idx]
+         transform = self.transform
+
+         try:
+             image = torch.from_numpy(np.ascontiguousarray(read_image(fpath, format="BGR")))
+             image = image.permute(2, 0, 1).unsqueeze(0).float()  # HWC -> NCHW
+             if transform is not None:
+                 image = transform(image)
+             return {"images": image, "categories": categories}
+         except (OSError, RuntimeError) as e:
+             logger = logging.getLogger(__name__)
+             logger.warning(f"Error opening image file {fpath}: {e}")
+
+         return {"images": self._EMPTY_IMAGE, "categories": []}
+
+     def __len__(self):
+         return len(self.image_list)
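Hypothetical usage sketch for the dataset above (paths and the category are made up); each item is a ready-to-use dict, so the dataset can be indexed directly or wrapped in a standard DataLoader:

from densepose.data.image_list_dataset import ImageListDataset

dataset = ImageListDataset(
    image_list=["img_0001.jpg", "img_0002.jpg"],  # hypothetical paths
    category_list="dog",                          # one category applied to all images
)
item = dataset[0]
print(item["images"].shape, item["categories"])   # [1, 3, H, W] and ["dog"], or empty on read error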
densepose/data/inference_based_loader.py ADDED
@@ -0,0 +1,174 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ import random
+ from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
+ import torch
+ from torch import nn
+
+ SampledData = Any
+ ModelOutput = Any
+
+
+ def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]:
+     """
+     Group elements of an iterable by chunks of size `n`, e.g.
+     grouper(range(9), 4) ->
+         (0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)
+     """
+     it = iter(iterable)
+     while True:
+         values = []
+         for _ in range(n):
+             try:
+                 value = next(it)
+             except StopIteration:
+                 if values:
+                     values.extend([fillvalue] * (n - len(values)))
+                     yield tuple(values)
+                 return
+             values.append(value)
+         yield tuple(values)
+
+
+ class ScoreBasedFilter:
+     """
+     Filters entries in model output based on their scores
+     Discards all entries with score less than the specified minimum
+     """
+
+     def __init__(self, min_score: float = 0.8):
+         self.min_score = min_score
+
+     def __call__(self, model_output: ModelOutput) -> ModelOutput:
+         for model_output_i in model_output:
+             instances = model_output_i["instances"]
+             if not instances.has("scores"):
+                 continue
+             instances_filtered = instances[instances.scores >= self.min_score]
+             model_output_i["instances"] = instances_filtered
+         return model_output
+
+
+ class InferenceBasedLoader:
+     """
+     Data loader based on results inferred by a model. Consists of:
+      - a data loader that provides batches of images
+      - a model that is used to infer the results
+      - a data sampler that converts inferred results to annotations
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         data_loader: Iterable[List[Dict[str, Any]]],
+         data_sampler: Optional[Callable[[ModelOutput], List[SampledData]]] = None,
+         data_filter: Optional[Callable[[ModelOutput], ModelOutput]] = None,
+         shuffle: bool = True,
+         batch_size: int = 4,
+         inference_batch_size: int = 4,
+         drop_last: bool = False,
+         category_to_class_mapping: Optional[dict] = None,
+     ):
+         """
+         Constructor
+
+         Args:
+             model (torch.nn.Module): model used to produce data
+             data_loader (Iterable[List[Dict[str, Any]]]): iterable that provides
+                 dictionaries with "images" and "categories" fields to perform inference on
+             data_sampler (Callable: ModelOutput -> SampledData): functor
+                 that produces annotation data from inference results;
+                 (optional, default: None)
+             data_filter (Callable: ModelOutput -> ModelOutput): filter
+                 that selects model outputs for further processing
+                 (optional, default: None)
+             shuffle (bool): if True, the input images get shuffled
+             batch_size (int): batch size for the produced annotation data
+             inference_batch_size (int): batch size for input images
+             drop_last (bool): if True, drop the last batch if it is undersized
+             category_to_class_mapping (dict): category to class mapping
+         """
+         self.model = model
+         self.model.eval()
+         self.data_loader = data_loader
+         self.data_sampler = data_sampler
+         self.data_filter = data_filter
+         self.shuffle = shuffle
+         self.batch_size = batch_size
+         self.inference_batch_size = inference_batch_size
+         self.drop_last = drop_last
+         if category_to_class_mapping is not None:
+             self.category_to_class_mapping = category_to_class_mapping
+         else:
+             self.category_to_class_mapping = {}
+
+     def __iter__(self) -> Iterator[List[SampledData]]:
+         for batch in self.data_loader:
+             # batch : List[Dict[str: Tensor[N, C, H, W], str: Optional[str]]]
+             # images_batch : Tensor[N, C, H, W]
+             # image : Tensor[C, H, W]
+             images_and_categories = [
+                 {"image": image, "category": category}
+                 for element in batch
+                 for image, category in zip(element["images"], element["categories"])
+             ]
+             if not images_and_categories:
+                 continue
+             if self.shuffle:
+                 random.shuffle(images_and_categories)
+             yield from self._produce_data(images_and_categories)  # pyre-ignore[6]
+
+     def _produce_data(
+         self, images_and_categories: List[Tuple[torch.Tensor, Optional[str]]]
+     ) -> Iterator[List[SampledData]]:
+         """
+         Produce batches of data from images
+
+         Args:
+             images_and_categories (List[Tuple[torch.Tensor, Optional[str]]]):
+                 list of images and corresponding categories to process
+
+         Returns:
+             Iterator over batches of data sampled from model outputs
+         """
+         data_batches: List[SampledData] = []
+         category_to_class_mapping = self.category_to_class_mapping
+         batched_images_and_categories = _grouper(images_and_categories, self.inference_batch_size)
+         for batch in batched_images_and_categories:
+             batch = [
+                 {
+                     "image": image_and_category["image"].to(self.model.device),
+                     "category": image_and_category["category"],
+                 }
+                 for image_and_category in batch
+                 if image_and_category is not None
+             ]
+             if not batch:
+                 continue
+             with torch.no_grad():
+                 model_output = self.model(batch)
+             for model_output_i, batch_i in zip(model_output, batch):
+                 assert len(batch_i["image"].shape) == 3
+                 model_output_i["image"] = batch_i["image"]
+                 instance_class = category_to_class_mapping.get(batch_i["category"], 0)
+                 model_output_i["instances"].dataset_classes = torch.tensor(
+                     [instance_class] * len(model_output_i["instances"])
+                 )
+             model_output_filtered = (
+                 model_output if self.data_filter is None else self.data_filter(model_output)
+             )
+             data = (
+                 model_output_filtered
+                 if self.data_sampler is None
+                 else self.data_sampler(model_output_filtered)
+             )
+             for data_i in data:
+                 if len(data_i["instances"]):
+                     data_batches.append(data_i)
+             if len(data_batches) >= self.batch_size:
+                 yield data_batches[: self.batch_size]
+                 data_batches = data_batches[self.batch_size :]
+         if not self.drop_last and data_batches:
+             yield data_batches
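The `_grouper` helper drives the inference batching above; a runnable sketch of its padding behavior (the loader itself needs a detectron2 model, an image loader and a sampler wired together, so it is not reproduced here):

from densepose.data.inference_based_loader import _grouper

print(list(_grouper(range(9), 4)))
# [(0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None)]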
densepose/data/meshes/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+ # pyre-unsafe
+
+ from . import builtin
+
+ __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
densepose/data/meshes/builtin.py ADDED
@@ -0,0 +1,103 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+ # pyre-unsafe
+
+ from .catalog import MeshInfo, register_meshes
+
+ DENSEPOSE_MESHES_DIR = "https://dl.fbaipublicfiles.com/densepose/meshes/"
+
+ MESHES = [
+     MeshInfo(
+         name="smpl_27554",
+         data="smpl_27554.pkl",
+         geodists="geodists/geodists_smpl_27554.pkl",
+         symmetry="symmetry/symmetry_smpl_27554.pkl",
+         texcoords="texcoords/texcoords_smpl_27554.pkl",
+     ),
+     MeshInfo(
+         name="chimp_5029",
+         data="chimp_5029.pkl",
+         geodists="geodists/geodists_chimp_5029.pkl",
+         symmetry="symmetry/symmetry_chimp_5029.pkl",
+         texcoords="texcoords/texcoords_chimp_5029.pkl",
+     ),
+     MeshInfo(
+         name="cat_5001",
+         data="cat_5001.pkl",
+         geodists="geodists/geodists_cat_5001.pkl",
+         symmetry="symmetry/symmetry_cat_5001.pkl",
+         texcoords="texcoords/texcoords_cat_5001.pkl",
+     ),
+     MeshInfo(
+         name="cat_7466",
+         data="cat_7466.pkl",
+         geodists="geodists/geodists_cat_7466.pkl",
+         symmetry="symmetry/symmetry_cat_7466.pkl",
+         texcoords="texcoords/texcoords_cat_7466.pkl",
+     ),
+     MeshInfo(
+         name="sheep_5004",
+         data="sheep_5004.pkl",
+         geodists="geodists/geodists_sheep_5004.pkl",
+         symmetry="symmetry/symmetry_sheep_5004.pkl",
+         texcoords="texcoords/texcoords_sheep_5004.pkl",
+     ),
+     MeshInfo(
+         name="zebra_5002",
+         data="zebra_5002.pkl",
+         geodists="geodists/geodists_zebra_5002.pkl",
+         symmetry="symmetry/symmetry_zebra_5002.pkl",
+         texcoords="texcoords/texcoords_zebra_5002.pkl",
+     ),
+     MeshInfo(
+         name="horse_5004",
+         data="horse_5004.pkl",
+         geodists="geodists/geodists_horse_5004.pkl",
+         symmetry="symmetry/symmetry_horse_5004.pkl",
+         texcoords="texcoords/texcoords_zebra_5002.pkl",
+     ),
+     MeshInfo(
+         name="giraffe_5002",
+         data="giraffe_5002.pkl",
+         geodists="geodists/geodists_giraffe_5002.pkl",
+         symmetry="symmetry/symmetry_giraffe_5002.pkl",
+         texcoords="texcoords/texcoords_giraffe_5002.pkl",
+     ),
+     MeshInfo(
+         name="elephant_5002",
+         data="elephant_5002.pkl",
+         geodists="geodists/geodists_elephant_5002.pkl",
+         symmetry="symmetry/symmetry_elephant_5002.pkl",
+         texcoords="texcoords/texcoords_elephant_5002.pkl",
+     ),
+     MeshInfo(
+         name="dog_5002",
+         data="dog_5002.pkl",
+         geodists="geodists/geodists_dog_5002.pkl",
+         symmetry="symmetry/symmetry_dog_5002.pkl",
+         texcoords="texcoords/texcoords_dog_5002.pkl",
+     ),
+     MeshInfo(
+         name="dog_7466",
+         data="dog_7466.pkl",
+         geodists="geodists/geodists_dog_7466.pkl",
+         symmetry="symmetry/symmetry_dog_7466.pkl",
+         texcoords="texcoords/texcoords_dog_7466.pkl",
+     ),
+     MeshInfo(
+         name="cow_5002",
+         data="cow_5002.pkl",
+         geodists="geodists/geodists_cow_5002.pkl",
+         symmetry="symmetry/symmetry_cow_5002.pkl",
+         texcoords="texcoords/texcoords_cow_5002.pkl",
+     ),
+     MeshInfo(
+         name="bear_4936",
+         data="bear_4936.pkl",
+         geodists="geodists/geodists_bear_4936.pkl",
+         symmetry="symmetry/symmetry_bear_4936.pkl",
+         texcoords="texcoords/texcoords_bear_4936.pkl",
+     ),
+ ]
+
+ register_meshes(MESHES, DENSEPOSE_MESHES_DIR)
densepose/data/meshes/catalog.py ADDED
@@ -0,0 +1,73 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+ # pyre-unsafe
+
+ import logging
+ from collections import UserDict
+ from dataclasses import dataclass
+ from typing import Iterable, Optional
+
+ from ..utils import maybe_prepend_base_path
+
+
+ @dataclass
+ class MeshInfo:
+     name: str
+     data: str
+     geodists: Optional[str] = None
+     symmetry: Optional[str] = None
+     texcoords: Optional[str] = None
+
+
+ class _MeshCatalog(UserDict):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.mesh_ids = {}
+         self.mesh_names = {}
+         self.max_mesh_id = -1
+
+     def __setitem__(self, key, value):
+         if key in self:
+             logger = logging.getLogger(__name__)
+             logger.warning(
+                 f"Overwriting mesh catalog entry '{key}': old value {self[key]}"
+                 f", new value {value}"
+             )
+             mesh_id = self.mesh_ids[key]
+         else:
+             self.max_mesh_id += 1
+             mesh_id = self.max_mesh_id
+         super().__setitem__(key, value)
+         self.mesh_ids[key] = mesh_id
+         self.mesh_names[mesh_id] = key
+
+     def get_mesh_id(self, shape_name: str) -> int:
+         return self.mesh_ids[shape_name]
+
+     def get_mesh_name(self, mesh_id: int) -> str:
+         return self.mesh_names[mesh_id]
+
+
+ MeshCatalog = _MeshCatalog()
+
+
+ def register_mesh(mesh_info: MeshInfo, base_path: Optional[str]) -> None:
+     geodists, symmetry, texcoords = mesh_info.geodists, mesh_info.symmetry, mesh_info.texcoords
+     if geodists:
+         geodists = maybe_prepend_base_path(base_path, geodists)
+     if symmetry:
+         symmetry = maybe_prepend_base_path(base_path, symmetry)
+     if texcoords:
+         texcoords = maybe_prepend_base_path(base_path, texcoords)
+     MeshCatalog[mesh_info.name] = MeshInfo(
+         name=mesh_info.name,
+         data=maybe_prepend_base_path(base_path, mesh_info.data),
+         geodists=geodists,
+         symmetry=symmetry,
+         texcoords=texcoords,
+     )
+
+
+ def register_meshes(mesh_infos: Iterable[MeshInfo], base_path: Optional[str]) -> None:
+     for mesh_info in mesh_infos:
+         register_mesh(mesh_info, base_path)
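Sketch of how the catalog is consumed (importing densepose.data.meshes.builtin triggers register_meshes, after which entries resolve against the remote base path):

import densepose.data.meshes.builtin  # noqa: F401 - registers the built-in meshes
from densepose.data.meshes.catalog import MeshCatalog

info = MeshCatalog["smpl_27554"]
print(info.data)                              # https://dl.fbaipublicfiles.com/densepose/meshes/smpl_27554.pkl
print(MeshCatalog.get_mesh_id("smpl_27554"))  # integer id assigned at registration time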
densepose/data/samplers/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from .densepose_uniform import DensePoseUniformSampler
+ from .densepose_confidence_based import DensePoseConfidenceBasedSampler
+ from .densepose_cse_uniform import DensePoseCSEUniformSampler
+ from .densepose_cse_confidence_based import DensePoseCSEConfidenceBasedSampler
+ from .mask_from_densepose import MaskFromDensePoseSampler
+ from .prediction_to_gt import PredictionToGroundTruthSampler
densepose/data/samplers/densepose_base.py ADDED
@@ -0,0 +1,205 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from typing import Any, Dict, List, Tuple
+ import torch
+ from torch.nn import functional as F
+
+ from detectron2.structures import BoxMode, Instances
+
+ from densepose.converters import ToChartResultConverter
+ from densepose.converters.base import IntTupleBox, make_int_box
+ from densepose.structures import DensePoseDataRelative, DensePoseList
+
+
+ class DensePoseBaseSampler:
+     """
+     Base DensePose sampler to produce DensePose data from DensePose predictions.
+     Samples for each class are drawn according to some distribution over all pixels estimated
+     to belong to that class.
+     """
+
+     def __init__(self, count_per_class: int = 8):
+         """
+         Constructor
+
+         Args:
+             count_per_class (int): the sampler produces at most `count_per_class`
+                 samples for each category
+         """
+         self.count_per_class = count_per_class
+
+     def __call__(self, instances: Instances) -> DensePoseList:
+         """
+         Convert DensePose predictions (an instance of `DensePoseChartPredictorOutput`)
+         into DensePose annotations data (an instance of `DensePoseList`)
+         """
+         boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu()
+         boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+         dp_datas = []
+         for i in range(len(boxes_xywh_abs)):
+             annotation_i = self._sample(instances[i], make_int_box(boxes_xywh_abs[i]))
+             annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask(  # pyre-ignore[6]
+                 instances[i].pred_densepose
+             )
+             dp_datas.append(DensePoseDataRelative(annotation_i))
+         # create densepose annotations on CPU
+         dp_list = DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size)
+         return dp_list
+
+     def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
+         """
+         Sample DensePoseDataRelative from estimation results
+         """
+         labels, dp_result = self._produce_labels_and_results(instance)
+         annotation = {
+             DensePoseDataRelative.X_KEY: [],
+             DensePoseDataRelative.Y_KEY: [],
+             DensePoseDataRelative.U_KEY: [],
+             DensePoseDataRelative.V_KEY: [],
+             DensePoseDataRelative.I_KEY: [],
+         }
+         n, h, w = dp_result.shape
+         for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
+             # indices - tuple of 3 1D tensors of size k
+             # 0: index along the first dimension N
+             # 1: index along H dimension
+             # 2: index along W dimension
+             indices = torch.nonzero(labels.expand(n, h, w) == part_id, as_tuple=True)
+             # values - an array of size [n, k]
+             # n: number of channels (U, V, confidences)
+             # k: number of points labeled with part_id
+             values = dp_result[indices].view(n, -1)
+             k = values.shape[1]
+             count = min(self.count_per_class, k)
+             if count <= 0:
+                 continue
+             index_sample = self._produce_index_sample(values, count)
+             sampled_values = values[:, index_sample]
+             sampled_y = indices[1][index_sample] + 0.5
+             sampled_x = indices[2][index_sample] + 0.5
+             # prepare / normalize data
+             x = (sampled_x / w * 256.0).cpu().tolist()
+             y = (sampled_y / h * 256.0).cpu().tolist()
+             u = sampled_values[0].clamp(0, 1).cpu().tolist()
+             v = sampled_values[1].clamp(0, 1).cpu().tolist()
+             fine_segm_labels = [part_id] * count
+             # extend annotations
+             annotation[DensePoseDataRelative.X_KEY].extend(x)
+             annotation[DensePoseDataRelative.Y_KEY].extend(y)
+             annotation[DensePoseDataRelative.U_KEY].extend(u)
+             annotation[DensePoseDataRelative.V_KEY].extend(v)
+             annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels)
+         return annotation
+
+     def _produce_index_sample(self, values: torch.Tensor, count: int):
+         """
+         Abstract method to produce a sample of indices to select data
+         To be implemented in descendants
+
+         Args:
+             values (torch.Tensor): an array of size [n, k] that contains
+                 estimated values (U, V, confidences);
+                 n: number of channels (U, V, confidences)
+                 k: number of points labeled with part_id
+             count (int): number of samples to produce, should be positive and <= k
+
+         Return:
+             list(int): indices of values (along axis 1) selected as a sample
+         """
+         raise NotImplementedError
+
+     def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Method to get labels and DensePose results from an instance
+
+         Args:
+             instance (Instances): an instance of `DensePoseChartPredictorOutput`
+
+         Return:
+             labels (torch.Tensor): shape [H, W], DensePose segmentation labels
+             dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v
+         """
+         converter = ToChartResultConverter
+         chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
+         labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
+         return labels, dp_result
+
+     def _resample_mask(self, output: Any) -> torch.Tensor:
+         """
+         Convert DensePose predictor output to segmentation annotation - tensors of size
+         (256, 256) and type `int64`.
+
+         Args:
+             output: DensePose predictor output with the following attributes:
+              - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
+                segmentation scores
+              - fine_segm: tensor of size [N, C, H, W] with unnormalized fine
+                segmentation scores
+         Return:
+             Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
+             where S = DensePoseDataRelative.MASK_SIZE
+         """
+         sz = DensePoseDataRelative.MASK_SIZE
+         S = (
+             F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
+             .argmax(dim=1)
+             .long()
+         )
+         I = (
+             (
+                 F.interpolate(
+                     output.fine_segm,
+                     (sz, sz),
+                     mode="bilinear",
+                     align_corners=False,
+                 ).argmax(dim=1)
+                 * (S > 0).long()
+             )
+             .squeeze()
+             .cpu()
+         )
+         # Map fine segmentation results to coarse segmentation ground truth
+         # TODO: extract this into separate classes
+         # coarse segmentation: 1 = Torso, 2 = Right Hand, 3 = Left Hand,
+         # 4 = Left Foot, 5 = Right Foot, 6 = Upper Leg Right, 7 = Upper Leg Left,
+         # 8 = Lower Leg Right, 9 = Lower Leg Left, 10 = Upper Arm Left,
+         # 11 = Upper Arm Right, 12 = Lower Arm Left, 13 = Lower Arm Right,
+         # 14 = Head
+         # fine segmentation: 1, 2 = Torso, 3 = Right Hand, 4 = Left Hand,
+         # 5 = Left Foot, 6 = Right Foot, 7, 9 = Upper Leg Right,
+         # 8, 10 = Upper Leg Left, 11, 13 = Lower Leg Right,
+         # 12, 14 = Lower Leg Left, 15, 17 = Upper Arm Left,
+         # 16, 18 = Upper Arm Right, 19, 21 = Lower Arm Left,
+         # 20, 22 = Lower Arm Right, 23, 24 = Head
+         FINE_TO_COARSE_SEGMENTATION = {
+             1: 1,
+             2: 1,
+             3: 2,
+             4: 3,
+             5: 4,
+             6: 5,
+             7: 6,
+             8: 7,
+             9: 6,
+             10: 7,
+             11: 8,
+             12: 9,
+             13: 8,
+             14: 9,
+             15: 10,
+             16: 11,
+             17: 10,
+             18: 11,
+             19: 12,
+             20: 13,
+             21: 12,
+             22: 13,
+             23: 14,
+             24: 14,
+         }
+         mask = torch.zeros((sz, sz), dtype=torch.int64, device=torch.device("cpu"))
+         for i in range(DensePoseDataRelative.N_PART_LABELS):
+             mask[I == i + 1] = FINE_TO_COARSE_SEGMENTATION[i + 1]
+         return mask
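Sketch of the contract `_produce_index_sample` imposes on subclasses: given `values` of shape [n, k], return `count` indices along axis 1. This toy subclass just takes the first `count` indices, purely for illustration; the real subclasses below sample uniformly or by confidence:

import torch

from densepose.data.samplers.densepose_base import DensePoseBaseSampler


class FirstKSampler(DensePoseBaseSampler):  # hypothetical subclass, not part of the diff
    def _produce_index_sample(self, values: torch.Tensor, count: int):
        # values: [n, k]; any `count` column indices form a valid sample
        return list(range(count))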
densepose/data/samplers/densepose_confidence_based.py ADDED
@@ -0,0 +1,110 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ import random
+ from typing import Optional, Tuple
+ import torch
+
+ from densepose.converters import ToChartResultConverterWithConfidences
+
+ from .densepose_base import DensePoseBaseSampler
+
+
+ class DensePoseConfidenceBasedSampler(DensePoseBaseSampler):
+     """
+     Samples DensePose data from DensePose predictions.
+     Samples for each class are drawn using confidence value estimates.
+     """
+
+     def __init__(
+         self,
+         confidence_channel: str,
+         count_per_class: int = 8,
+         search_count_multiplier: Optional[float] = None,
+         search_proportion: Optional[float] = None,
+     ):
+         """
+         Constructor
+
+         Args:
+             confidence_channel (str): confidence channel to use for sampling;
+                 possible values:
+                     "sigma_2": confidences for UV values
+                     "fine_segm_confidence": confidences for fine segmentation
+                     "coarse_segm_confidence": confidences for coarse segmentation
+                 (default: "sigma_2")
+             count_per_class (int): the sampler produces at most `count_per_class`
+                 samples for each category (default: 8)
+             search_count_multiplier (float or None): if not None, the total number
+                 of the most confident estimates of a given class to consider is
+                 defined as `min(search_count_multiplier * count_per_class, N)`,
+                 where `N` is the total number of estimates of the class; cannot be
+                 specified together with `search_proportion` (default: None)
+             search_proportion (float or None): if not None, the total number of the
+                 most confident estimates of a given class to consider is
+                 defined as `min(max(search_proportion * N, count_per_class), N)`,
+                 where `N` is the total number of estimates of the class; cannot be
+                 specified together with `search_count_multiplier` (default: None)
+         """
+         super().__init__(count_per_class)
+         self.confidence_channel = confidence_channel
+         self.search_count_multiplier = search_count_multiplier
+         self.search_proportion = search_proportion
+         assert (search_count_multiplier is None) or (search_proportion is None), (
+             f"Cannot specify both search_count_multiplier (={search_count_multiplier}) "
+             f"and search_proportion (={search_proportion})"
+         )
+
+     def _produce_index_sample(self, values: torch.Tensor, count: int):
+         """
+         Produce a sample of indices to select data based on confidences
+
+         Args:
+             values (torch.Tensor): an array of size [n, k] that contains
+                 estimated values (U, V, confidences);
+                 n: number of channels (U, V, confidences)
+                 k: number of points labeled with part_id
+             count (int): number of samples to produce, should be positive and <= k
+
+         Return:
+             list(int): indices of values (along axis 1) selected as a sample
+         """
+         k = values.shape[1]
+         if k == count:
+             index_sample = list(range(k))
+         else:
+             # take the best count * search_count_multiplier pixels,
+             # sample from them uniformly
+             # (here best = smallest variance)
+             _, sorted_confidence_indices = torch.sort(values[2])
+             if self.search_count_multiplier is not None:
+                 search_count = min(int(count * self.search_count_multiplier), k)
+             elif self.search_proportion is not None:
+                 search_count = min(max(int(k * self.search_proportion), count), k)
+             else:
+                 search_count = min(count, k)
+             sample_from_top = random.sample(range(search_count), count)
+             index_sample = sorted_confidence_indices[:search_count][sample_from_top]
+         return index_sample
+
+     def _produce_labels_and_results(self, instance) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Method to get labels and DensePose results from an instance, with confidences
+
+         Args:
+             instance (Instances): an instance of `DensePoseChartPredictorOutputWithConfidences`
+
+         Return:
+             labels (torch.Tensor): shape [H, W], DensePose segmentation labels
+             dp_result (torch.Tensor): shape [3, H, W], DensePose results u and v
+                 stacked with the confidence channel
+         """
+         converter = ToChartResultConverterWithConfidences
+         chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
+         labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
+         dp_result = torch.cat(
+             (dp_result, getattr(chart_result, self.confidence_channel)[None].cpu())
+         )
+
+         return labels, dp_result
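Worked example of the search window above: with count_per_class=8 and k=100 candidate pixels, search_count_multiplier=4.0 keeps the 32 lowest-variance pixels and draws 8 of them uniformly at random, while search_proportion=0.1 would keep 10 instead:

count, k = 8, 100
print(min(int(count * 4.0), k))          # search_count_multiplier=4.0 -> 32
print(min(max(int(k * 0.1), count), k))  # search_proportion=0.1      -> 10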
densepose/data/samplers/densepose_cse_base.py ADDED
@@ -0,0 +1,141 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from typing import Any, Dict, List, Tuple
+ import torch
+ from torch.nn import functional as F
+
+ from detectron2.config import CfgNode
+ from detectron2.structures import Instances
+
+ from densepose.converters.base import IntTupleBox
+ from densepose.data.utils import get_class_to_mesh_name_mapping
+ from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
+ from densepose.structures import DensePoseDataRelative
+
+ from .densepose_base import DensePoseBaseSampler
+
+
+ class DensePoseCSEBaseSampler(DensePoseBaseSampler):
+     """
+     Base DensePose sampler to produce DensePose data from DensePose predictions.
+     Samples for each class are drawn according to some distribution over all pixels estimated
+     to belong to that class.
+     """
+
+     def __init__(
+         self,
+         cfg: CfgNode,
+         use_gt_categories: bool,
+         embedder: torch.nn.Module,
+         count_per_class: int = 8,
+     ):
+         """
+         Constructor
+
+         Args:
+             cfg (CfgNode): the config of the model
+             use_gt_categories (bool): if True, use ground truth categories
+                 (`dataset_classes`) instead of predicted classes to pick the mesh
+             embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
+             count_per_class (int): the sampler produces at most `count_per_class`
+                 samples for each category
+         """
+         super().__init__(count_per_class)
+         self.embedder = embedder
+         self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
+         self.use_gt_categories = use_gt_categories
+
+     def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
+         """
+         Sample DensePoseDataRelative from estimation results
+         """
+         if self.use_gt_categories:
+             instance_class = instance.dataset_classes.tolist()[0]
+         else:
+             instance_class = instance.pred_classes.tolist()[0]
+         mesh_name = self.class_to_mesh_name[instance_class]
+
+         annotation = {
+             DensePoseDataRelative.X_KEY: [],
+             DensePoseDataRelative.Y_KEY: [],
+             DensePoseDataRelative.VERTEX_IDS_KEY: [],
+             DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
+         }
+
+         mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
+         indices = torch.nonzero(mask, as_tuple=True)
+         selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu()
+         values = other_values[:, indices[0], indices[1]]
+         k = values.shape[1]
+
+         count = min(self.count_per_class, k)
+         if count <= 0:
+             return annotation
+
+         index_sample = self._produce_index_sample(values, count)
+         closest_vertices = squared_euclidean_distance_matrix(
+             selected_embeddings[index_sample], self.embedder(mesh_name)
+         )
+         closest_vertices = torch.argmin(closest_vertices, dim=1)
+
+         sampled_y = indices[0][index_sample] + 0.5
+         sampled_x = indices[1][index_sample] + 0.5
+         # prepare / normalize data
+         _, _, w, h = bbox_xywh
+         x = (sampled_x / w * 256.0).cpu().tolist()
+         y = (sampled_y / h * 256.0).cpu().tolist()
+         # extend annotations
+         annotation[DensePoseDataRelative.X_KEY].extend(x)
+         annotation[DensePoseDataRelative.Y_KEY].extend(y)
+         annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist())
+         return annotation
+
+     def _produce_mask_and_results(
+         self, instance: Instances, bbox_xywh: IntTupleBox
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Method to get labels and DensePose results from an instance
+
+         Args:
+             instance (Instances): an instance of `DensePoseEmbeddingPredictorOutput`
+             bbox_xywh (IntTupleBox): the corresponding bounding box
+
+         Return:
+             mask (torch.Tensor): shape [H, W], DensePose segmentation mask
+             embeddings (torch.Tensor): a tensor of shape [D, H, W],
+                 DensePose CSE Embeddings
+             other_values (torch.Tensor): a tensor of shape [0, H, W],
+                 for potential other values
+         """
+         densepose_output = instance.pred_densepose
+         S = densepose_output.coarse_segm
+         E = densepose_output.embedding
+         _, _, w, h = bbox_xywh
+         embeddings = F.interpolate(E, size=(h, w), mode="bilinear")[0]
+         coarse_segm_resized = F.interpolate(S, size=(h, w), mode="bilinear")[0]
+         mask = coarse_segm_resized.argmax(0) > 0
+         other_values = torch.empty((0, h, w), device=E.device)
+         return mask, embeddings, other_values
+
+     def _resample_mask(self, output: Any) -> torch.Tensor:
+         """
+         Convert DensePose predictor output to segmentation annotation - tensors of size
+         (256, 256) and type `int64`.
+
+         Args:
+             output: DensePose predictor output with the following attributes:
+              - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
+                segmentation scores
+         Return:
+             Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
+             where S = DensePoseDataRelative.MASK_SIZE
+         """
+         sz = DensePoseDataRelative.MASK_SIZE
+         mask = (
+             F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
+             .argmax(dim=1)
+             .long()
+             .squeeze()
+             .cpu()
+         )
+         return mask
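Sketch of the nearest-vertex step in `_sample` above: each sampled pixel embedding is assigned the mesh vertex with the smallest squared Euclidean distance (here torch.cdist squared stands in for squared_euclidean_distance_matrix; shapes are hypothetical):

import torch

pixel_embs = torch.randn(8, 16)       # [count, D] sampled pixel embeddings
vertex_embs = torch.randn(5000, 16)   # [V, D] mesh vertex embeddings from the embedder
d2 = torch.cdist(pixel_embs, vertex_embs) ** 2
closest_vertices = d2.argmin(dim=1)   # [count] vertex ids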
densepose/data/samplers/densepose_cse_confidence_based.py ADDED
@@ -0,0 +1,121 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ import random
+ from typing import Optional, Tuple
+ import torch
+ from torch.nn import functional as F
+
+ from detectron2.config import CfgNode
+ from detectron2.structures import Instances
+
+ from densepose.converters.base import IntTupleBox
+
+ from .densepose_cse_base import DensePoseCSEBaseSampler
+
+
+ class DensePoseCSEConfidenceBasedSampler(DensePoseCSEBaseSampler):
+     """
+     Samples DensePose data from DensePose predictions.
+     Samples for each class are drawn using confidence value estimates.
+     """
+
+     def __init__(
+         self,
+         cfg: CfgNode,
+         use_gt_categories: bool,
+         embedder: torch.nn.Module,
+         confidence_channel: str,
+         count_per_class: int = 8,
+         search_count_multiplier: Optional[float] = None,
+         search_proportion: Optional[float] = None,
+     ):
+         """
+         Constructor
+
+         Args:
+             cfg (CfgNode): the config of the model
+             use_gt_categories (bool): if True, use ground truth categories
+                 (`dataset_classes`) instead of predicted classes to pick the mesh
+             embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
+             confidence_channel (str): confidence channel to use for sampling;
+                 possible values:
+                     "coarse_segm_confidence": confidences for coarse segmentation
+                 (default: "coarse_segm_confidence")
+             count_per_class (int): the sampler produces at most `count_per_class`
+                 samples for each category (default: 8)
+             search_count_multiplier (float or None): if not None, the total number
+                 of the most confident estimates of a given class to consider is
+                 defined as `min(search_count_multiplier * count_per_class, N)`,
+                 where `N` is the total number of estimates of the class; cannot be
+                 specified together with `search_proportion` (default: None)
+             search_proportion (float or None): if not None, the total number of the
+                 most confident estimates of a given class to consider is
+                 defined as `min(max(search_proportion * N, count_per_class), N)`,
+                 where `N` is the total number of estimates of the class; cannot be
+                 specified together with `search_count_multiplier` (default: None)
+         """
+         super().__init__(cfg, use_gt_categories, embedder, count_per_class)
+         self.confidence_channel = confidence_channel
+         self.search_count_multiplier = search_count_multiplier
+         self.search_proportion = search_proportion
+         assert (search_count_multiplier is None) or (search_proportion is None), (
+             f"Cannot specify both search_count_multiplier (={search_count_multiplier}) "
+             f"and search_proportion (={search_proportion})"
+         )
+
+     def _produce_index_sample(self, values: torch.Tensor, count: int):
+         """
+         Produce a sample of indices to select data based on confidences
+
+         Args:
+             values (torch.Tensor): an array of size [1, k] that contains confidences;
+                 k: number of points labeled with part_id
+             count (int): number of samples to produce, should be positive and <= k
+
+         Return:
+             list(int): indices of values (along axis 1) selected as a sample
+         """
+         k = values.shape[1]
+         if k == count:
+             index_sample = list(range(k))
+         else:
+             # take the best count * search_count_multiplier pixels,
+             # sample from them uniformly
+             # (here best = highest confidence)
+             _, sorted_confidence_indices = torch.sort(values[0])
+             if self.search_count_multiplier is not None:
+                 search_count = min(int(count * self.search_count_multiplier), k)
+             elif self.search_proportion is not None:
+                 search_count = min(max(int(k * self.search_proportion), count), k)
+             else:
+                 search_count = min(count, k)
+             sample_from_top = random.sample(range(search_count), count)
+             index_sample = sorted_confidence_indices[-search_count:][sample_from_top]
+         return index_sample
+
+     def _produce_mask_and_results(
+         self, instance: Instances, bbox_xywh: IntTupleBox
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Method to get labels and DensePose results from an instance
+
+         Args:
+             instance (Instances): an instance of
+                 `DensePoseEmbeddingPredictorOutputWithConfidences`
+             bbox_xywh (IntTupleBox): the corresponding bounding box
+
+         Return:
+             mask (torch.Tensor): shape [H, W], DensePose segmentation mask
+             embeddings (torch.Tensor): a tensor of shape [D, H, W],
+                 DensePose CSE Embeddings
+             other_values (torch.Tensor): a tensor of shape [1, H, W], DensePose CSE confidence
+         """
+         _, _, w, h = bbox_xywh
+         densepose_output = instance.pred_densepose
+         mask, embeddings, _ = super()._produce_mask_and_results(instance, bbox_xywh)
+         other_values = F.interpolate(
+             getattr(densepose_output, self.confidence_channel),
+             size=(h, w),
+             mode="bilinear",
+         )[0].cpu()
+         return mask, embeddings, other_values
densepose/data/samplers/densepose_cse_uniform.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from .densepose_cse_base import DensePoseCSEBaseSampler
+ from .densepose_uniform import DensePoseUniformSampler
+
+
+ class DensePoseCSEUniformSampler(DensePoseCSEBaseSampler, DensePoseUniformSampler):
+     """
+     Uniform Sampler for CSE
+     """
+
+     pass
densepose/data/samplers/densepose_uniform.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ import random
+ import torch
+
+ from .densepose_base import DensePoseBaseSampler
+
+
+ class DensePoseUniformSampler(DensePoseBaseSampler):
+     """
+     Samples DensePose data from DensePose predictions.
+     Samples for each class are drawn uniformly over all pixels estimated
+     to belong to that class.
+     """
+
+     def __init__(self, count_per_class: int = 8):
+         """
+         Constructor
+
+         Args:
+             count_per_class (int): the sampler produces at most `count_per_class`
+                 samples for each category
+         """
+         super().__init__(count_per_class)
+
+     def _produce_index_sample(self, values: torch.Tensor, count: int):
+         """
+         Produce a uniform sample of indices to select data
+
+         Args:
+             values (torch.Tensor): an array of size [n, k] that contains
+                 estimated values (U, V, confidences);
+                 n: number of channels (U, V, confidences)
+                 k: number of points labeled with part_id
+             count (int): number of samples to produce, should be positive and <= k
+
+         Return:
+             list(int): indices of values (along axis 1) selected as a sample
+         """
+         k = values.shape[1]
+         return random.sample(range(k), count)
densepose/data/samplers/mask_from_densepose.py ADDED
@@ -0,0 +1,30 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from detectron2.structures import BitMasks, Instances
+
+ from densepose.converters import ToMaskConverter
+
+
+ class MaskFromDensePoseSampler:
+     """
+     Produce mask GT from DensePose predictions
+     This sampler simply converts DensePose predictions to BitMasks
+     that contain a bool tensor of the size of the input image
+     """
+
+     def __call__(self, instances: Instances) -> BitMasks:
+         """
+         Converts predicted data from `instances` into the GT mask data
+
+         Args:
+             instances (Instances): predicted results, expected to have `pred_densepose` field
+
+         Returns:
+             Boolean Tensor of the size of the input image that has non-zero
+             values at pixels that are estimated to belong to the detected object
+         """
+         return ToMaskConverter.convert(
+             instances.pred_densepose, instances.pred_boxes, instances.image_size
+         )
densepose/data/samplers/prediction_to_gt.py ADDED
@@ -0,0 +1,100 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from dataclasses import dataclass
+ from typing import Any, Callable, Dict, List, Optional
+
+ from detectron2.structures import Instances
+
+ ModelOutput = Dict[str, Any]
+ SampledData = Dict[str, Any]
+
+
+ @dataclass
+ class _Sampler:
+     """
+     Sampler registry entry that contains:
+      - src (str): source field to sample from (deleted after sampling)
+      - dst (Optional[str]): destination field to sample to, if not None
+      - func (Optional[Callable: Any -> Any]): function that performs sampling,
+        if None, reference copy is performed
+     """
+
+     src: str
+     dst: Optional[str]
+     func: Optional[Callable[[Any], Any]]
+
+
+ class PredictionToGroundTruthSampler:
+     """
+     Sampler implementation that converts predictions to GT using registered
+     samplers for different fields of `Instances`.
+     """
+
+     def __init__(self, dataset_name: str = ""):
+         self.dataset_name = dataset_name
+         self._samplers = {}
+         self.register_sampler("pred_boxes", "gt_boxes", None)
+         self.register_sampler("pred_classes", "gt_classes", None)
+         # delete scores
+         self.register_sampler("scores")
+
+     def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
+         """
+         Transform model output into ground truth data through sampling
+
+         Args:
+             model_output (List[ModelOutput]): model outputs
+         Returns:
+             List[SampledData]: sampled data
+         """
+         for model_output_i in model_output:
+             instances: Instances = model_output_i["instances"]
+             # transform data in each field
+             for _, sampler in self._samplers.items():
+                 if not instances.has(sampler.src) or sampler.dst is None:
+                     continue
+                 if sampler.func is None:
+                     instances.set(sampler.dst, instances.get(sampler.src))
+                 else:
+                     instances.set(sampler.dst, sampler.func(instances))
+             # delete model output data that was transformed
+             for _, sampler in self._samplers.items():
+                 if sampler.src != sampler.dst and instances.has(sampler.src):
+                     instances.remove(sampler.src)
+             model_output_i["dataset"] = self.dataset_name
+         return model_output
+
+     def register_sampler(
+         self,
+         prediction_attr: str,
+         gt_attr: Optional[str] = None,
+         func: Optional[Callable[[Any], Any]] = None,
+     ):
+         """
+         Register sampler for a field
+
+         Args:
+             prediction_attr (str): field to replace with a sampled value
+             gt_attr (Optional[str]): field to store the sampled value to, if not None
+             func (Optional[Callable: Any -> Any]): sampler function
+         """
+         self._samplers[(prediction_attr, gt_attr)] = _Sampler(
+             src=prediction_attr, dst=gt_attr, func=func
+         )
+
+     def remove_sampler(
+         self,
+         prediction_attr: str,
+         gt_attr: Optional[str] = None,
+     ):
+         """
+         Remove sampler for a field
+
+         Args:
+             prediction_attr (str): field to replace with a sampled value
+             gt_attr (Optional[str]): field to store the sampled value to, if not None
+         """
+         assert (prediction_attr, gt_attr) in self._samplers
+         del self._samplers[(prediction_attr, gt_attr)]
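Sketch of how the samplers in this folder plug together, mirroring how a self-training loader would be configured (the dataset name is hypothetical): DensePose predictions become gt_densepose and the derived masks become gt_masks, while pred_boxes/pred_classes are copied and scores are dropped by the defaults registered in the constructor.

from densepose.data.samplers import (
    DensePoseUniformSampler,
    MaskFromDensePoseSampler,
    PredictionToGroundTruthSampler,
)

sampler = PredictionToGroundTruthSampler("my_dataset")  # hypothetical dataset name
sampler.register_sampler("pred_densepose", "gt_densepose", DensePoseUniformSampler())
sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())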
densepose/data/transform/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from .image import ImageResizeTransform
densepose/data/transform/image.py ADDED
@@ -0,0 +1,41 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ import torch
+
+
+ class ImageResizeTransform:
+     """
+     Transform that resizes images loaded from a dataset
+     (BGR data in NCHW channel order, typically uint8) to a format ready to be
+     consumed by DensePose training (BGR float32 data in NCHW channel order)
+     """
+
+     def __init__(self, min_size: int = 800, max_size: int = 1333):
+         self.min_size = min_size
+         self.max_size = max_size
+
+     def __call__(self, images: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             images (torch.Tensor): tensor of size [N, 3, H, W] that contains
+                 BGR data (typically in uint8)
+         Returns:
+             images (torch.Tensor): tensor of size [N, 3, H1, W1] where
+                 H1 and W1 are chosen to respect the specified min and max sizes
+                 and preserve the original aspect ratio, the data channels
+                 follow BGR order and the data type is `torch.float32`
+         """
+         # resize with min size
+         images = images.float()
+         min_size = min(images.shape[-2:])
+         max_size = max(images.shape[-2:])
+         scale = min(self.min_size / min_size, self.max_size / max_size)
+         images = torch.nn.functional.interpolate(
+             images,
+             scale_factor=scale,
+             mode="bilinear",
+             align_corners=False,
+         )
+         return images
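Worked example of the scale rule above: for a 600x800 input with min_size=800 and max_size=1333, scale = min(800/600, 1333/800) = 4/3, so the batch is resized to roughly 800x1066 (the exact pixel sizes depend on interpolate's floating-point rounding of scale_factor):

import torch

from densepose.data.transform import ImageResizeTransform

t = ImageResizeTransform(min_size=800, max_size=1333)
images = torch.randint(0, 256, (1, 3, 600, 800), dtype=torch.uint8)
print(t(images).shape)  # approximately torch.Size([1, 3, 800, 1066]), dtype float32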
densepose/data/utils.py ADDED
@@ -0,0 +1,40 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ import os
+ from typing import Dict, Optional
+
+ from detectron2.config import CfgNode
+
+
+ def is_relative_local_path(path: str) -> bool:
+     path_str = os.fsdecode(path)
+     return ("://" not in path_str) and not os.path.isabs(path)
+
+
+ def maybe_prepend_base_path(base_path: Optional[str], path: str):
+     """
+     Prepends the provided path with a base path prefix if:
+     1) base path is not None;
+     2) path is a local path
+     """
+     if base_path is None:
+         return path
+     if is_relative_local_path(path):
+         return os.path.join(base_path, path)
+     return path
+
+
+ def get_class_to_mesh_name_mapping(cfg: CfgNode) -> Dict[int, str]:
+     return {
+         int(class_id): mesh_name
+         for class_id, mesh_name in cfg.DATASETS.CLASS_TO_MESH_NAME_MAPPING.items()
+     }
+
+
+ def get_category_to_class_mapping(dataset_cfg: CfgNode) -> Dict[str, int]:
+     return {
+         category: int(class_id)
+         for category, class_id in dataset_cfg.CATEGORY_TO_CLASS_MAPPING.items()
+     }
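Behavior sketch for the path helper above: relative local paths get the base prefix, while absolute paths and URLs pass through unchanged.

from densepose.data.utils import maybe_prepend_base_path

print(maybe_prepend_base_path("/data", "videos/a.mp4"))     # /data/videos/a.mp4
print(maybe_prepend_base_path("/data", "/abs/b.mp4"))       # /abs/b.mp4
print(maybe_prepend_base_path("/data", "https://x/c.mp4"))  # https://x/c.mp4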
densepose/data/video/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from .frame_selector import (
+     FrameSelectionStrategy,
+     RandomKFramesSelector,
+     FirstKFramesSelector,
+     LastKFramesSelector,
+     FrameTsList,
+     FrameSelector,
+ )
+
+ from .video_keyframe_dataset import (
+     VideoKeyframeDataset,
+     video_list_from_file,
+     list_keyframes,
+     read_keyframes,
+ )
densepose/data/video/frame_selector.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ import random
+ from collections.abc import Callable
+ from enum import Enum
+ from typing import Callable as TCallable
+ from typing import List
+
+ FrameTsList = List[int]
+ FrameSelector = TCallable[[FrameTsList], FrameTsList]
+
+
+ class FrameSelectionStrategy(Enum):
+     """
+     Frame selection strategy used with videos:
+      - "random_k": select k random frames
+      - "first_k": select k first frames
+      - "last_k": select k last frames
+      - "all": select all frames
+     """
+
+     # fmt: off
+     RANDOM_K = "random_k"
+     FIRST_K  = "first_k"
+     LAST_K   = "last_k"
+     ALL      = "all"
+     # fmt: on
+
+
+ class RandomKFramesSelector(Callable):  # pyre-ignore[39]
+     """
+     Selector that retains at most `k` random frames
+     """
+
+     def __init__(self, k: int):
+         self.k = k
+
+     def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
+         """
+         Select `k` random frames
+
+         Args:
+             frame_tss (List[int]): timestamps of input frames
+         Returns:
+             List[int]: timestamps of selected frames
+         """
+         return random.sample(frame_tss, min(self.k, len(frame_tss)))
+
+
+ class FirstKFramesSelector(Callable):  # pyre-ignore[39]
+     """
+     Selector that retains at most `k` first frames
+     """
+
+     def __init__(self, k: int):
+         self.k = k
+
+     def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
+         """
+         Select `k` first frames
+
+         Args:
+             frame_tss (List[int]): timestamps of input frames
+         Returns:
+             List[int]: timestamps of selected frames
+         """
+         return frame_tss[: self.k]
+
+
+ class LastKFramesSelector(Callable):  # pyre-ignore[39]
+     """
+     Selector that retains at most `k` last frames from video data
+     """
+
+     def __init__(self, k: int):
+         self.k = k
+
+     def __call__(self, frame_tss: FrameTsList) -> FrameTsList:
+         """
+         Select `k` last frames
+
+         Args:
+             frame_tss (List[int]): timestamps of input frames
+         Returns:
+             List[int]: timestamps of selected frames
+         """
+         return frame_tss[-self.k :]
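The selectors are plain callables over keyframe timestamp lists; a runnable sketch:

import random

from densepose.data.video import (
    FirstKFramesSelector,
    LastKFramesSelector,
    RandomKFramesSelector,
)

random.seed(0)
ts = [10, 20, 30, 40, 50]
print(FirstKFramesSelector(3)(ts))   # [10, 20, 30]
print(LastKFramesSelector(3)(ts))    # [30, 40, 50]
print(RandomKFramesSelector(3)(ts))  # 3 timestamps drawn without replacement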
densepose/data/video/video_keyframe_dataset.py ADDED
@@ -0,0 +1,304 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ # pyre-unsafe
5
+
6
+ import csv
7
+ import logging
8
+ import numpy as np
9
+ from typing import Any, Callable, Dict, List, Optional, Union
10
+ import av
11
+ import torch
12
+ from torch.utils.data.dataset import Dataset
13
+
14
+ from detectron2.utils.file_io import PathManager
15
+
16
+ from ..utils import maybe_prepend_base_path
17
+ from .frame_selector import FrameSelector, FrameTsList
18
+
19
+ FrameList = List[av.frame.Frame] # pyre-ignore[16]
20
+ FrameTransform = Callable[[torch.Tensor], torch.Tensor]
21
+
22
+
23
+ def list_keyframes(video_fpath: str, video_stream_idx: int = 0) -> FrameTsList:
24
+ """
25
+ Traverses all keyframes of a video file. Returns a list of keyframe
26
+ timestamps. Timestamps are counts in timebase units.
27
+
28
+ Args:
29
+ video_fpath (str): Video file path
30
+ video_stream_idx (int): Video stream index (default: 0)
31
+ Returns:
32
+ List[int]: list of keyframe timestaps (timestamp is a count in timebase
33
+ units)
34
+ """
35
+ try:
36
+ with PathManager.open(video_fpath, "rb") as io:
37
+ # pyre-fixme[16]: Module `av` has no attribute `open`.
38
+ container = av.open(io, mode="r")
39
+ stream = container.streams.video[video_stream_idx]
40
+ keyframes = []
41
+ pts = -1
42
+ # Note: even though we request forward seeks for keyframes, sometimes
43
+ # a keyframe in backwards direction is returned. We introduce tolerance
44
+ # as a max count of ignored backward seeks
45
+ tolerance_backward_seeks = 2
46
+ while True:
47
+ try:
48
+ container.seek(pts + 1, backward=False, any_frame=False, stream=stream)
49
+ except av.AVError as e:
50
+ # the exception occurs when the video length is exceeded,
51
+ # we then return whatever data we've already collected
52
+ logger = logging.getLogger(__name__)
53
+ logger.debug(
54
+ f"List keyframes: Error seeking video file {video_fpath}, "
55
+ f"video stream {video_stream_idx}, pts {pts + 1}, AV error: {e}"
56
+ )
57
+ return keyframes
58
+ except OSError as e:
59
+ logger = logging.getLogger(__name__)
60
+ logger.warning(
61
+ f"List keyframes: Error seeking video file {video_fpath}, "
62
+ f"video stream {video_stream_idx}, pts {pts + 1}, OS error: {e}"
63
+ )
64
+ return []
65
+ packet = next(container.demux(video=video_stream_idx))
66
+ if packet.pts is not None and packet.pts <= pts:
67
+ logger = logging.getLogger(__name__)
68
+ logger.warning(
69
+ f"Video file {video_fpath}, stream {video_stream_idx}: "
70
+ f"bad seek for packet {pts + 1} (got packet {packet.pts}), "
71
+ f"tolerance {tolerance_backward_seeks}."
72
+ )
73
+ tolerance_backward_seeks -= 1
74
+ if tolerance_backward_seeks == 0:
75
+ return []
76
+ pts += 1
77
+ continue
78
+ tolerance_backward_seeks = 2
79
+ pts = packet.pts
80
+ if pts is None:
81
+ return keyframes
82
+ if packet.is_keyframe:
83
+ keyframes.append(pts)
84
+ return keyframes
85
+ except OSError as e:
86
+ logger = logging.getLogger(__name__)
87
+ logger.warning(
88
+ f"List keyframes: Error opening video file container {video_fpath}, " f"OS error: {e}"
89
+ )
90
+ except RuntimeError as e:
91
+ logger = logging.getLogger(__name__)
92
+ logger.warning(
93
+ f"List keyframes: Error opening video file container {video_fpath}, "
94
+ f"Runtime error: {e}"
95
+ )
96
+ return []
97
+
98
+
99
+ def read_keyframes(
100
+     video_fpath: str, keyframes: FrameTsList, video_stream_idx: int = 0
+ ) -> FrameList:  # pyre-ignore[11]
+     """
+     Reads keyframe data from a video file.
+
+     Args:
+         video_fpath (str): Video file path
+         keyframes (List[int]): List of keyframe timestamps (as counts in
+             timebase units to be used in container seek operations)
+         video_stream_idx (int): Video stream index (default: 0)
+     Returns:
+         List[Frame]: list of frames that correspond to the specified timestamps
+     """
+     try:
+         with PathManager.open(video_fpath, "rb") as io:
+             # pyre-fixme[16]: Module `av` has no attribute `open`.
+             container = av.open(io)
+             stream = container.streams.video[video_stream_idx]
+             frames = []
+             for pts in keyframes:
+                 try:
+                     container.seek(pts, any_frame=False, stream=stream)
+                     frame = next(container.decode(video=0))
+                     frames.append(frame)
+                 except av.AVError as e:
+                     logger = logging.getLogger(__name__)
+                     logger.warning(
+                         f"Read keyframes: Error seeking video file {video_fpath}, "
+                         f"video stream {video_stream_idx}, pts {pts}, AV error: {e}"
+                     )
+                     container.close()
+                     return frames
+                 except OSError as e:
+                     logger = logging.getLogger(__name__)
+                     logger.warning(
+                         f"Read keyframes: Error seeking video file {video_fpath}, "
+                         f"video stream {video_stream_idx}, pts {pts}, OS error: {e}"
+                     )
+                     container.close()
+                     return frames
+                 except StopIteration:
+                     logger = logging.getLogger(__name__)
+                     logger.warning(
+                         f"Read keyframes: Error decoding frame from {video_fpath}, "
+                         f"video stream {video_stream_idx}, pts {pts}"
+                     )
+                     container.close()
+                     return frames
+
+             container.close()
+             return frames
+     except OSError as e:
+         logger = logging.getLogger(__name__)
+         logger.warning(
+             f"Read keyframes: Error opening video file container {video_fpath}, OS error: {e}"
+         )
+     except RuntimeError as e:
+         logger = logging.getLogger(__name__)
+         logger.warning(
+             f"Read keyframes: Error opening video file container {video_fpath}, Runtime error: {e}"
+         )
+     return []
+
+
+ def video_list_from_file(video_list_fpath: str, base_path: Optional[str] = None):
+     """
+     Create a list of paths to video files from a text file.
+
+     Args:
+         video_list_fpath (str): path to a plain text file with the list of videos
+         base_path (str): base path for entries from the video list (default: None)
+     """
+     video_list = []
+     with PathManager.open(video_list_fpath, "r") as io:
+         for line in io:
+             video_list.append(maybe_prepend_base_path(base_path, str(line.strip())))
+     return video_list
+
+
+ def read_keyframe_helper_data(fpath: str):
+     """
+     Read keyframe data from a file in CSV format: the header should contain
+     "video_id" and "keyframes" fields. Value specifications are:
+         video_id: int
+         keyframes: list(int)
+     Example of contents:
+         video_id,keyframes
+         2,"[1,11,21,31,41,51,61,71,81]"
+
+     Args:
+         fpath (str): File containing keyframe data
+
+     Return:
+         video_id_to_keyframes (dict: int -> list(int)): for a given video ID it
+             contains a list of keyframes for that video
+     """
+     video_id_to_keyframes = {}
+     try:
+         with PathManager.open(fpath, "r") as io:
+             csv_reader = csv.reader(io)
+             header = next(csv_reader)
+             video_id_idx = header.index("video_id")
+             keyframes_idx = header.index("keyframes")
+             for row in csv_reader:
+                 video_id = int(row[video_id_idx])
+                 assert (
+                     video_id not in video_id_to_keyframes
+                 ), f"Duplicate keyframes entry for video {fpath}"
+                 video_id_to_keyframes[video_id] = (
+                     [int(v) for v in row[keyframes_idx][1:-1].split(",")]
+                     if len(row[keyframes_idx]) > 2
+                     else []
+                 )
+     except Exception as e:
+         logger = logging.getLogger(__name__)
+         logger.warning(f"Error reading keyframe helper data from {fpath}: {e}")
+     return video_id_to_keyframes
+
+
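To make the helper format concrete, here is a minimal round-trip sketch matching the docstring above (the temporary file is a stand-in, not part of this module):

```python
# Round-trip sketch for the keyframe helper CSV format documented above.
import csv
import tempfile

from densepose.data.video.video_keyframe_dataset import read_keyframe_helper_data

with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["video_id", "keyframes"])
    writer.writerow([2, "[1,11,21,31,41,51,61,71,81]"])
    fpath = f.name

print(read_keyframe_helper_data(fpath))  # {2: [1, 11, 21, 31, 41, 51, 61, 71, 81]}
```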
+ class VideoKeyframeDataset(Dataset):
+     """
+     Dataset that provides keyframes for a set of videos.
+     """
+
+     _EMPTY_FRAMES = torch.empty((0, 3, 1, 1))
+
+     def __init__(
+         self,
+         video_list: List[str],
+         category_list: Union[str, List[str], None] = None,
+         frame_selector: Optional[FrameSelector] = None,
+         transform: Optional[FrameTransform] = None,
+         keyframe_helper_fpath: Optional[str] = None,
+     ):
+         """
+         Dataset constructor
+
+         Args:
+             video_list (List[str]): list of paths to video files
+             category_list (Union[str, List[str], None]): list of animal categories for each
+                 video file. If it is a string, or None, this applies to all videos
+             frame_selector (Callable: KeyFrameList -> KeyFrameList):
+                 selects keyframes to process, keyframes are given by
+                 packet timestamps in timebase counts. If None, all keyframes
+                 are selected (default: None)
+             transform (Callable: torch.Tensor -> torch.Tensor):
+                 transforms a batch of RGB images (tensors of size [B, 3, H, W]),
+                 returns a tensor of the same size. If None, no transform is
+                 applied (default: None)
+             keyframe_helper_fpath (Optional[str]): path to a CSV file with
+                 precomputed keyframes (see read_keyframe_helper_data above);
+                 if None, keyframes are listed on the fly (default: None)
+         """
+         if type(category_list) is list:
+             self.category_list = category_list
+         else:
+             self.category_list = [category_list] * len(video_list)
+         assert len(video_list) == len(
+             self.category_list
+         ), "length of video and category lists must be equal"
+         self.video_list = video_list
+         self.frame_selector = frame_selector
+         self.transform = transform
+         self.keyframe_helper_data = (
+             read_keyframe_helper_data(keyframe_helper_fpath)
+             if keyframe_helper_fpath is not None
+             else None
+         )
+
+     def __getitem__(self, idx: int) -> Dict[str, Any]:
+         """
+         Gets selected keyframes from a given video
+
+         Args:
+             idx (int): video index in the video list file
+         Returns:
+             A dictionary containing two keys:
+                 images (torch.Tensor): tensor of size [N, 3, H, W] or of size
+                     defined by the transform that contains keyframes data
+                 categories (List[str]): categories of the frames
+         """
+         categories = [self.category_list[idx]]
+         fpath = self.video_list[idx]
+         keyframes = (
+             list_keyframes(fpath)
+             if self.keyframe_helper_data is None or idx not in self.keyframe_helper_data
+             else self.keyframe_helper_data[idx]
+         )
+         transform = self.transform
+         frame_selector = self.frame_selector
+         if not keyframes:
+             return {"images": self._EMPTY_FRAMES, "categories": []}
+         if frame_selector is not None:
+             keyframes = frame_selector(keyframes)
+         frames = read_keyframes(fpath, keyframes)
+         if not frames:
+             return {"images": self._EMPTY_FRAMES, "categories": []}
+         frames = np.stack([frame.to_rgb().to_ndarray() for frame in frames])
+         frames = torch.as_tensor(frames, device=torch.device("cpu"))
+         frames = frames[..., [2, 1, 0]]  # RGB -> BGR
+         frames = frames.permute(0, 3, 1, 2).float()  # NHWC -> NCHW
+         if transform is not None:
+             frames = transform(frames)
+         return {"images": frames, "categories": categories}
+
+     def __len__(self):
+         return len(self.video_list)
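A minimal usage sketch for this dataset; the video paths are placeholders and the selector is assumed to be one of the `FrameSelector` callables exported from `densepose.data.video`:

```python
# Usage sketch (paths and selector choice are illustrative assumptions).
from densepose.data.video import FirstKFramesSelector
from densepose.data.video.video_keyframe_dataset import (
    VideoKeyframeDataset,
    video_list_from_file,
)

video_list = video_list_from_file("videos.txt", base_path="/data/videos")  # placeholders
dataset = VideoKeyframeDataset(
    video_list,
    category_list="dog",                     # one category applied to all videos
    frame_selector=FirstKFramesSelector(3),  # keep at most 3 keyframes per video
)
sample = dataset[0]
# images: float tensor [N, 3, H, W] in BGR channel order (see __getitem__ above)
print(sample["images"].shape, sample["categories"])
```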
densepose/engine/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from .trainer import Trainer
densepose/engine/trainer.py ADDED
@@ -0,0 +1,260 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+ # pyre-unsafe
+
+ import logging
+ import os
+ from collections import OrderedDict
+ from typing import List, Optional, Union
+ import torch
+ from torch import nn
+
+ from detectron2.checkpoint import DetectionCheckpointer
+ from detectron2.config import CfgNode
+ from detectron2.engine import DefaultTrainer
+ from detectron2.evaluation import (
+     DatasetEvaluator,
+     DatasetEvaluators,
+     inference_on_dataset,
+     print_csv_format,
+ )
+ from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
+ from detectron2.utils import comm
+ from detectron2.utils.events import EventWriter, get_event_storage
+
+ from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg
+ from densepose.data import (
+     DatasetMapper,
+     build_combined_loader,
+     build_detection_test_loader,
+     build_detection_train_loader,
+     build_inference_based_loaders,
+     has_inference_based_loaders,
+ )
+ from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter
+ from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, build_densepose_evaluator_storage
+ from densepose.modeling.cse import Embedder
+
+
+ class SampleCountingLoader:
+     def __init__(self, loader):
+         self.loader = loader
+
+     def __iter__(self):
+         it = iter(self.loader)
+         storage = get_event_storage()
+         while True:
+             try:
+                 batch = next(it)
+                 num_inst_per_dataset = {}
+                 for data in batch:
+                     dataset_name = data["dataset"]
+                     if dataset_name not in num_inst_per_dataset:
+                         num_inst_per_dataset[dataset_name] = 0
+                     num_inst = len(data["instances"])
+                     num_inst_per_dataset[dataset_name] += num_inst
+                 for dataset_name in num_inst_per_dataset:
+                     storage.put_scalar(f"batch/{dataset_name}", num_inst_per_dataset[dataset_name])
+                 yield batch
+             except StopIteration:
+                 break
+
+
+ class SampleCountMetricPrinter(EventWriter):
+     def __init__(self):
+         self.logger = logging.getLogger(__name__)
+
+     def write(self):
+         storage = get_event_storage()
+         batch_stats_strs = []
+         for key, buf in storage.histories().items():
+             if key.startswith("batch/"):
+                 batch_stats_strs.append(f"{key} {buf.avg(20)}")
+         self.logger.info(", ".join(batch_stats_strs))
+
+
+ class Trainer(DefaultTrainer):
+     @classmethod
+     def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]:
+         if isinstance(model, nn.parallel.DistributedDataParallel):
+             model = model.module
+         if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"):
+             return model.roi_heads.embedder
+         return None
+
+     # TODO: the only reason to copy the base class code here is to pass the embedder from
+     # the model to the evaluator; that should be refactored to avoid unnecessary copy-pasting
+     @classmethod
+     def test(
+         cls,
+         cfg: CfgNode,
+         model: nn.Module,
+         evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None,
+     ):
+         """
+         Args:
+             cfg (CfgNode):
+             model (nn.Module):
+             evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
+                 :meth:`build_evaluator`. Otherwise, must have the same length as
+                 ``cfg.DATASETS.TEST``.
+
+         Returns:
+             dict: a dict of result metrics
+         """
+         logger = logging.getLogger(__name__)
+         if isinstance(evaluators, DatasetEvaluator):
+             evaluators = [evaluators]
+         if evaluators is not None:
+             assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
+                 len(cfg.DATASETS.TEST), len(evaluators)
+             )
+
+         results = OrderedDict()
+         for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
+             data_loader = cls.build_test_loader(cfg, dataset_name)
+             # When evaluators are passed in as arguments,
+             # implicitly assume that evaluators can be created before data_loader.
+             if evaluators is not None:
+                 evaluator = evaluators[idx]
+             else:
+                 try:
+                     embedder = cls.extract_embedder_from_model(model)
+                     evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder)
+                 except NotImplementedError:
+                     logger.warning(
+                         "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
+                         "or implement its `build_evaluator` method."
+                     )
+                     results[dataset_name] = {}
+                     continue
+             if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process():
+                 results_i = inference_on_dataset(model, data_loader, evaluator)
+             else:
+                 results_i = {}
+             results[dataset_name] = results_i
+             if comm.is_main_process():
+                 assert isinstance(
+                     results_i, dict
+                 ), "Evaluator must return a dict on the main process. Got {} instead.".format(
+                     results_i
+                 )
+                 logger.info("Evaluation results for {} in csv format:".format(dataset_name))
+                 print_csv_format(results_i)
+
+         if len(results) == 1:
+             results = list(results.values())[0]
+         return results
+
+     @classmethod
+     def build_evaluator(
+         cls,
+         cfg: CfgNode,
+         dataset_name: str,
+         output_folder: Optional[str] = None,
+         embedder: Optional[Embedder] = None,
+     ) -> DatasetEvaluators:
+         if output_folder is None:
+             output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
+         evaluators = []
+         distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE
+         # Note: we currently use COCO evaluator for both COCO and LVIS datasets
+         # to have compatible metrics. LVIS bbox evaluator could also be used
+         # with an adapter to properly handle filtered / mapped categories
+         # evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
+         # if evaluator_type == "coco":
+         #     evaluators.append(COCOEvaluator(dataset_name, output_dir=output_folder))
+         # elif evaluator_type == "lvis":
+         #     evaluators.append(LVISEvaluator(dataset_name, output_dir=output_folder))
+         evaluators.append(
+             Detectron2COCOEvaluatorAdapter(
+                 dataset_name, output_dir=output_folder, distributed=distributed
+             )
+         )
+         if cfg.MODEL.DENSEPOSE_ON:
+             storage = build_densepose_evaluator_storage(cfg, output_folder)
+             evaluators.append(
+                 DensePoseCOCOEvaluator(
+                     dataset_name,
+                     distributed,
+                     output_folder,
+                     evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
+                     min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD,
+                     storage=storage,
+                     embedder=embedder,
+                     should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT,
+                     mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES,
+                 )
+             )
+         return DatasetEvaluators(evaluators)
+
+     @classmethod
+     def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
+         params = get_default_optimizer_params(
+             model,
+             base_lr=cfg.SOLVER.BASE_LR,
+             weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
+             bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
+             weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
+             overrides={
+                 "features": {
+                     "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
+                 },
+                 "embeddings": {
+                     "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
+                 },
+             },
+         )
+         optimizer = torch.optim.SGD(
+             params,
+             cfg.SOLVER.BASE_LR,
+             momentum=cfg.SOLVER.MOMENTUM,
+             nesterov=cfg.SOLVER.NESTEROV,
+             weight_decay=cfg.SOLVER.WEIGHT_DECAY,
+         )
+         # pyre-fixme[6]: For 2nd param expected `Type[Optimizer]` but got `SGD`.
+         return maybe_add_gradient_clipping(cfg, optimizer)
+
+     @classmethod
+     def build_test_loader(cls, cfg: CfgNode, dataset_name):
+         return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))
+
+     @classmethod
+     def build_train_loader(cls, cfg: CfgNode):
+         data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
+         if not has_inference_based_loaders(cfg):
+             return data_loader
+         model = cls.build_model(cfg)
+         model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
+         DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)
+         inference_based_loaders, ratios = build_inference_based_loaders(cfg, model)
+         loaders = [data_loader] + inference_based_loaders
+         ratios = [1.0] + ratios
+         combined_data_loader = build_combined_loader(cfg, loaders, ratios)
+         sample_counting_loader = SampleCountingLoader(combined_data_loader)
+         return sample_counting_loader
+
+     def build_writers(self):
+         writers = super().build_writers()
+         writers.append(SampleCountMetricPrinter())
+         return writers
+
+     @classmethod
+     def test_with_TTA(cls, cfg: CfgNode, model):
+         logger = logging.getLogger("detectron2.trainer")
+         # At the end of training, run an evaluation with TTA.
+         # Only supports some R-CNN models.
+         logger.info("Running inference with test-time augmentation ...")
+         transform_data = load_from_cfg(cfg)
+         model = DensePoseGeneralizedRCNNWithTTA(
+             cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg)
+         )
+         evaluators = [
+             cls.build_evaluator(
+                 cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
+             )
+             for name in cfg.DATASETS.TEST
+         ]
+         res = cls.test(cfg, model, evaluators)  # pyre-ignore[6]
+         res = OrderedDict({k + "_TTA": v for k, v in res.items()})
+         return res
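For context, a hedged sketch of how this `Trainer` is typically driven from a launch script, assuming the `add_densepose_config` helper from `densepose/config.py` in this repo; the config path is a placeholder:

```python
# Hypothetical launch sketch wiring the Trainer above into a standard
# detectron2 training flow (config path is a placeholder).
from detectron2.config import get_cfg

from densepose.config import add_densepose_config
from densepose.engine import Trainer


def main():
    cfg = get_cfg()
    add_densepose_config(cfg)  # register DensePose-specific config keys first
    cfg.merge_from_file("configs/densepose_rcnn_R_50_FPN_s1x.yaml")  # placeholder path
    cfg.freeze()

    trainer = Trainer(cfg)   # builds model, optimizer and the (possibly combined) loader
    trainer.resume_or_load(resume=False)
    trainer.train()          # periodic evaluation goes through Trainer.test above


if __name__ == "__main__":
    main()
```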
densepose/evaluation/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from .evaluator import DensePoseCOCOEvaluator
densepose/evaluation/d2_evaluator_adapter.py ADDED
@@ -0,0 +1,52 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ from detectron2.data.catalog import Metadata
+ from detectron2.evaluation import COCOEvaluator
+
+ from densepose.data.datasets.coco import (
+     get_contiguous_id_to_category_id_map,
+     maybe_filter_categories_cocoapi,
+ )
+
+
+ def _maybe_add_iscrowd_annotations(cocoapi) -> None:
+     for ann in cocoapi.dataset["annotations"]:
+         if "iscrowd" not in ann:
+             ann["iscrowd"] = 0
+
+
+ class Detectron2COCOEvaluatorAdapter(COCOEvaluator):
+     def __init__(
+         self,
+         dataset_name,
+         output_dir=None,
+         distributed=True,
+     ):
+         super().__init__(dataset_name, output_dir=output_dir, distributed=distributed)
+         maybe_filter_categories_cocoapi(dataset_name, self._coco_api)
+         _maybe_add_iscrowd_annotations(self._coco_api)
+         # substitute category metadata to account for categories
+         # that are mapped to the same contiguous id
+         if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
+             self._maybe_substitute_metadata()
+
+     def _maybe_substitute_metadata(self):
+         cont_id_2_cat_id = get_contiguous_id_to_category_id_map(self._metadata)
+         cat_id_2_cont_id = self._metadata.thing_dataset_id_to_contiguous_id
+         if len(cont_id_2_cat_id) == len(cat_id_2_cont_id):
+             return
+
+         cat_id_2_cont_id_injective = {}
+         for cat_id, cont_id in cat_id_2_cont_id.items():
+             if (cont_id in cont_id_2_cat_id) and (cont_id_2_cat_id[cont_id] == cat_id):
+                 cat_id_2_cont_id_injective[cat_id] = cont_id
+
+         metadata_new = Metadata(name=self._metadata.name)
+         for key, value in self._metadata.__dict__.items():
+             if key == "thing_dataset_id_to_contiguous_id":
+                 setattr(metadata_new, key, cat_id_2_cont_id_injective)
+             else:
+                 setattr(metadata_new, key, value)
+         self._metadata = metadata_new
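The `_maybe_substitute_metadata` step keeps only category ids whose mapping is one-to-one in both directions. A self-contained sketch of that filtering with made-up ids (plain dicts stand in for detectron2 `Metadata`):

```python
# Standalone illustration of the injective-filtering logic above.
# cat_id -> contiguous id, where two dataset categories collide on id 0.
cat_id_2_cont_id = {1: 0, 7: 0, 9: 1}
# the reverse map keeps one representative per contiguous id
cont_id_2_cat_id = {0: 1, 1: 9}

cat_id_2_cont_id_injective = {
    cat_id: cont_id
    for cat_id, cont_id in cat_id_2_cont_id.items()
    if cont_id_2_cat_id.get(cont_id) == cat_id
}
print(cat_id_2_cont_id_injective)  # {1: 0, 9: 1} -- category 7 is dropped
```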
densepose/evaluation/densepose_coco_evaluation.py ADDED
@@ -0,0 +1,1305 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # This is a modified version of cocoeval.py where we also have the densepose evaluation.
7
+
8
+ # pyre-unsafe
9
+
10
+ __author__ = "tsungyi"
11
+
12
+ import copy
13
+ import datetime
14
+ import logging
15
+ import numpy as np
16
+ import pickle
17
+ import time
18
+ from collections import defaultdict
19
+ from enum import Enum
20
+ from typing import Any, Dict, Tuple
21
+ import scipy.spatial.distance as ssd
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from pycocotools import mask as maskUtils
25
+ from scipy.io import loadmat
26
+ from scipy.ndimage import zoom as spzoom
27
+
28
+ from detectron2.utils.file_io import PathManager
29
+
30
+ from densepose.converters.chart_output_to_chart_result import resample_uv_tensors_to_bbox
31
+ from densepose.converters.segm_to_mask import (
32
+ resample_coarse_segm_tensor_to_bbox,
33
+ resample_fine_and_coarse_segm_tensors_to_bbox,
34
+ )
35
+ from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
36
+ from densepose.structures import DensePoseDataRelative
37
+ from densepose.structures.mesh import create_mesh
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ class DensePoseEvalMode(str, Enum):
43
+ # use both masks and geodesic distances (GPS * IOU) to compute scores
44
+ GPSM = "gpsm"
45
+ # use only geodesic distances (GPS) to compute scores
46
+ GPS = "gps"
47
+ # use only masks (IOU) to compute scores
48
+ IOU = "iou"
49
+
50
+
51
+ class DensePoseDataMode(str, Enum):
52
+ # use estimated IUV data (default mode)
53
+ IUV_DT = "iuvdt"
54
+ # use ground truth IUV data
55
+ IUV_GT = "iuvgt"
56
+ # use ground truth labels I and set UV to 0
57
+ I_GT_UV_0 = "igtuv0"
58
+ # use ground truth labels I and estimated UV coordinates
59
+ I_GT_UV_DT = "igtuvdt"
60
+ # use estimated labels I and set UV to 0
61
+ I_DT_UV_0 = "idtuv0"
62
+
63
+
64
+ class DensePoseCocoEval:
65
+ # Interface for evaluating detection on the Microsoft COCO dataset.
66
+ #
67
+ # The usage for CocoEval is as follows:
68
+ # cocoGt=..., cocoDt=... # load dataset and results
69
+ # E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object
70
+ # E.params.recThrs = ...; # set parameters as desired
71
+ # E.evaluate(); # run per image evaluation
72
+ # E.accumulate(); # accumulate per image results
73
+ # E.summarize(); # display summary metrics of results
74
+ # For example usage see evalDemo.m and http://mscoco.org/.
75
+ #
76
+ # The evaluation parameters are as follows (defaults in brackets):
77
+ # imgIds - [all] N img ids to use for evaluation
78
+ # catIds - [all] K cat ids to use for evaluation
79
+ # iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation
80
+ # recThrs - [0:.01:1] R=101 recall thresholds for evaluation
81
+ # areaRng - [...] A=4 object area ranges for evaluation
82
+ # maxDets - [1 10 100] M=3 thresholds on max detections per image
83
+ # iouType - ['segm'] set iouType to 'segm', 'bbox', 'keypoints' or 'densepose'
84
+ # iouType replaced the now DEPRECATED useSegm parameter.
85
+ # useCats - [1] if true use category labels for evaluation
86
+ # Note: if useCats=0 category labels are ignored as in proposal scoring.
87
+ # Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified.
88
+ #
89
+ # evaluate(): evaluates detections on every image and every category and
90
+ # concats the results into the "evalImgs" with fields:
91
+ # dtIds - [1xD] id for each of the D detections (dt)
92
+ # gtIds - [1xG] id for each of the G ground truths (gt)
93
+ # dtMatches - [TxD] matching gt id at each IoU or 0
94
+ # gtMatches - [TxG] matching dt id at each IoU or 0
95
+ # dtScores - [1xD] confidence of each dt
96
+ # gtIgnore - [1xG] ignore flag for each gt
97
+ # dtIgnore - [TxD] ignore flag for each dt at each IoU
98
+ #
99
+ # accumulate(): accumulates the per-image, per-category evaluation
100
+ # results in "evalImgs" into the dictionary "eval" with fields:
101
+ # params - parameters used for evaluation
102
+ # date - date evaluation was performed
103
+ # counts - [T,R,K,A,M] parameter dimensions (see above)
104
+ # precision - [TxRxKxAxM] precision for every evaluation setting
105
+ # recall - [TxKxAxM] max recall for every evaluation setting
106
+ # Note: precision and recall==-1 for settings with no gt objects.
107
+ #
108
+ # See also coco, mask, pycocoDemo, pycocoEvalDemo
109
+ #
110
+ # Microsoft COCO Toolbox. version 2.0
111
+ # Data, paper, and tutorials available at: http://mscoco.org/
112
+ # Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
113
+ # Licensed under the Simplified BSD License [see coco/license.txt]
114
+ def __init__(
115
+ self,
116
+ cocoGt=None,
117
+ cocoDt=None,
118
+ iouType: str = "densepose",
119
+ multi_storage=None,
120
+ embedder=None,
121
+ dpEvalMode: DensePoseEvalMode = DensePoseEvalMode.GPS,
122
+ dpDataMode: DensePoseDataMode = DensePoseDataMode.IUV_DT,
123
+ ):
124
+ """
125
+ Initialize CocoEval using coco APIs for gt and dt
126
+ :param cocoGt: coco object with ground truth annotations
127
+ :param cocoDt: coco object with detection results
128
+ :return: None
129
+ """
130
+ self.cocoGt = cocoGt # ground truth COCO API
131
+ self.cocoDt = cocoDt # detections COCO API
132
+ self.multi_storage = multi_storage
133
+ self.embedder = embedder
134
+ self._dpEvalMode = dpEvalMode
135
+ self._dpDataMode = dpDataMode
136
+ self.evalImgs = defaultdict(list) # per-image per-category eval results [KxAxI]
137
+ self.eval = {} # accumulated evaluation results
138
+ self._gts = defaultdict(list) # gt for evaluation
139
+ self._dts = defaultdict(list) # dt for evaluation
140
+ self.params = Params(iouType=iouType) # parameters
141
+ self._paramsEval = {} # parameters for evaluation
142
+ self.stats = [] # result summarization
143
+ self.ious = {} # ious between all gts and dts
144
+ if cocoGt is not None:
145
+ self.params.imgIds = sorted(cocoGt.getImgIds())
146
+ self.params.catIds = sorted(cocoGt.getCatIds())
147
+ self.ignoreThrBB = 0.7
148
+ self.ignoreThrUV = 0.9
149
+
150
+ def _loadGEval(self):
151
+ smpl_subdiv_fpath = PathManager.get_local_path(
152
+ "https://dl.fbaipublicfiles.com/densepose/data/SMPL_subdiv.mat"
153
+ )
154
+ pdist_transform_fpath = PathManager.get_local_path(
155
+ "https://dl.fbaipublicfiles.com/densepose/data/SMPL_SUBDIV_TRANSFORM.mat"
156
+ )
157
+ pdist_matrix_fpath = PathManager.get_local_path(
158
+ "https://dl.fbaipublicfiles.com/densepose/data/Pdist_matrix.pkl", timeout_sec=120
159
+ )
160
+ SMPL_subdiv = loadmat(smpl_subdiv_fpath)
161
+ self.PDIST_transform = loadmat(pdist_transform_fpath)
162
+ self.PDIST_transform = self.PDIST_transform["index"].squeeze()
163
+ UV = np.array([SMPL_subdiv["U_subdiv"], SMPL_subdiv["V_subdiv"]]).squeeze()
164
+ ClosestVertInds = np.arange(UV.shape[1]) + 1
165
+ self.Part_UVs = []
166
+ self.Part_ClosestVertInds = []
167
+ for i in np.arange(24):
168
+ self.Part_UVs.append(UV[:, SMPL_subdiv["Part_ID_subdiv"].squeeze() == (i + 1)])
169
+ self.Part_ClosestVertInds.append(
170
+ ClosestVertInds[SMPL_subdiv["Part_ID_subdiv"].squeeze() == (i + 1)]
171
+ )
172
+
173
+ with open(pdist_matrix_fpath, "rb") as hFile:
174
+ arrays = pickle.load(hFile, encoding="latin1")
175
+ self.Pdist_matrix = arrays["Pdist_matrix"]
176
+ self.Part_ids = np.array(SMPL_subdiv["Part_ID_subdiv"].squeeze())
177
+ # Mean geodesic distances for parts.
178
+ self.Mean_Distances = np.array([0, 0.351, 0.107, 0.126, 0.237, 0.173, 0.142, 0.128, 0.150])
179
+ # Coarse Part labels.
180
+ self.CoarseParts = np.array(
181
+ [0, 1, 1, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8]
182
+ )
183
+
184
+ def _prepare(self):
185
+ """
186
+ Prepare ._gts and ._dts for evaluation based on params
187
+ :return: None
188
+ """
189
+
190
+ def _toMask(anns, coco):
191
+ # modify ann['segmentation'] by reference
192
+ for ann in anns:
193
+ # safeguard for invalid segmentation annotation;
194
+ # annotations containing empty lists exist in the posetrack
195
+ # dataset. This is not a correct segmentation annotation
196
+ # in terms of COCO format; we need to deal with it somehow
197
+ segm = ann["segmentation"]
198
+ if type(segm) is list and len(segm) == 0:
199
+ ann["segmentation"] = None
200
+ continue
201
+ rle = coco.annToRLE(ann)
202
+ ann["segmentation"] = rle
203
+
204
+ def _getIgnoreRegion(iid, coco):
205
+ img = coco.imgs[iid]
206
+
207
+ if "ignore_regions_x" not in img.keys():
208
+ return None
209
+
210
+ if len(img["ignore_regions_x"]) == 0:
211
+ return None
212
+
213
+ rgns_merged = [
214
+ [v for xy in zip(region_x, region_y) for v in xy]
215
+ for region_x, region_y in zip(img["ignore_regions_x"], img["ignore_regions_y"])
216
+ ]
217
+ rles = maskUtils.frPyObjects(rgns_merged, img["height"], img["width"])
218
+ rle = maskUtils.merge(rles)
219
+ return maskUtils.decode(rle)
220
+
221
+ def _checkIgnore(dt, iregion):
222
+ if iregion is None:
223
+ return True
224
+
225
+ bb = np.array(dt["bbox"]).astype(int)
226
+ x1, y1, x2, y2 = bb[0], bb[1], bb[0] + bb[2], bb[1] + bb[3]
227
+ x2 = min([x2, iregion.shape[1]])
228
+ y2 = min([y2, iregion.shape[0]])
229
+
230
+ if bb[2] * bb[3] == 0:
231
+ return False
232
+
233
+ crop_iregion = iregion[y1:y2, x1:x2]
234
+
235
+ if crop_iregion.sum() == 0:
236
+ return True
237
+
238
+ if "densepose" not in dt.keys(): # filtering boxes
239
+ return crop_iregion.sum() / bb[2] / bb[3] < self.ignoreThrBB
240
+
241
+ # filtering UVs
242
+ ignoremask = np.require(crop_iregion, requirements=["F"])
243
+ mask = self._extract_mask(dt)
244
+ uvmask = np.require(np.asarray(mask > 0), dtype=np.uint8, requirements=["F"])
245
+ uvmask_ = maskUtils.encode(uvmask)
246
+ ignoremask_ = maskUtils.encode(ignoremask)
247
+ uviou = maskUtils.iou([uvmask_], [ignoremask_], [1])[0]
248
+ return uviou < self.ignoreThrUV
249
+
250
+ p = self.params
251
+
252
+ if p.useCats:
253
+ gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
254
+ dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
255
+ else:
256
+ gts = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
257
+ dts = self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
258
+
259
+ imns = self.cocoGt.loadImgs(p.imgIds)
260
+ self.size_mapping = {}
261
+ for im in imns:
262
+ self.size_mapping[im["id"]] = [im["height"], im["width"]]
263
+
264
+ # if iouType == 'uv', add point gt annotations
265
+ if p.iouType == "densepose":
266
+ self._loadGEval()
267
+
268
+ # convert ground truth to mask if iouType == 'segm'
269
+ if p.iouType == "segm":
270
+ _toMask(gts, self.cocoGt)
271
+ _toMask(dts, self.cocoDt)
272
+
273
+ # set ignore flag
274
+ for gt in gts:
275
+ gt["ignore"] = gt["ignore"] if "ignore" in gt else 0
276
+ gt["ignore"] = "iscrowd" in gt and gt["iscrowd"]
277
+ if p.iouType == "keypoints":
278
+ gt["ignore"] = (gt["num_keypoints"] == 0) or gt["ignore"]
279
+ if p.iouType == "densepose":
280
+ gt["ignore"] = ("dp_x" in gt) == 0
281
+ if p.iouType == "segm":
282
+ gt["ignore"] = gt["segmentation"] is None
283
+
284
+ self._gts = defaultdict(list) # gt for evaluation
285
+ self._dts = defaultdict(list) # dt for evaluation
286
+ self._igrgns = defaultdict(list)
287
+
288
+ for gt in gts:
289
+ iid = gt["image_id"]
290
+ if iid not in self._igrgns.keys():
291
+ self._igrgns[iid] = _getIgnoreRegion(iid, self.cocoGt)
292
+ if _checkIgnore(gt, self._igrgns[iid]):
293
+ self._gts[iid, gt["category_id"]].append(gt)
294
+ for dt in dts:
295
+ iid = dt["image_id"]
296
+ if (iid not in self._igrgns) or _checkIgnore(dt, self._igrgns[iid]):
297
+ self._dts[iid, dt["category_id"]].append(dt)
298
+
299
+ self.evalImgs = defaultdict(list) # per-image per-category evaluation results
300
+ self.eval = {} # accumulated evaluation results
301
+
302
+ def evaluate(self):
303
+ """
304
+ Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
305
+ :return: None
306
+ """
307
+ tic = time.time()
308
+ logger.info("Running per image DensePose evaluation... {}".format(self.params.iouType))
309
+ p = self.params
310
+ # add backward compatibility if useSegm is specified in params
311
+ if p.useSegm is not None:
312
+ p.iouType = "segm" if p.useSegm == 1 else "bbox"
313
+ logger.info("useSegm (deprecated) is not None. Running DensePose evaluation")
314
+ p.imgIds = list(np.unique(p.imgIds))
315
+ if p.useCats:
316
+ p.catIds = list(np.unique(p.catIds))
317
+ p.maxDets = sorted(p.maxDets)
318
+ self.params = p
319
+
320
+ self._prepare()
321
+ # loop through images, area range, max detection number
322
+ catIds = p.catIds if p.useCats else [-1]
323
+
324
+ if p.iouType in ["segm", "bbox"]:
325
+ computeIoU = self.computeIoU
326
+ elif p.iouType == "keypoints":
327
+ computeIoU = self.computeOks
328
+ elif p.iouType == "densepose":
329
+ computeIoU = self.computeOgps
330
+ if self._dpEvalMode in {DensePoseEvalMode.GPSM, DensePoseEvalMode.IOU}:
331
+ self.real_ious = {
332
+ (imgId, catId): self.computeDPIoU(imgId, catId)
333
+ for imgId in p.imgIds
334
+ for catId in catIds
335
+ }
336
+
337
+ self.ious = {
338
+ (imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds
339
+ }
340
+
341
+ evaluateImg = self.evaluateImg
342
+ maxDet = p.maxDets[-1]
343
+ self.evalImgs = [
344
+ evaluateImg(imgId, catId, areaRng, maxDet)
345
+ for catId in catIds
346
+ for areaRng in p.areaRng
347
+ for imgId in p.imgIds
348
+ ]
349
+ self._paramsEval = copy.deepcopy(self.params)
350
+ toc = time.time()
351
+ logger.info("DensePose evaluation DONE (t={:0.2f}s).".format(toc - tic))
352
+
353
+ def getDensePoseMask(self, polys):
354
+ maskGen = np.zeros([256, 256])
355
+ stop = min(len(polys) + 1, 15)
356
+ for i in range(1, stop):
357
+ if polys[i - 1]:
358
+ currentMask = maskUtils.decode(polys[i - 1])
359
+ maskGen[currentMask > 0] = i
360
+ return maskGen
361
+
362
+ def _generate_rlemask_on_image(self, mask, imgId, data):
363
+ bbox_xywh = np.array(data["bbox"])
364
+ x, y, w, h = bbox_xywh
365
+ im_h, im_w = self.size_mapping[imgId]
366
+ im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
367
+ if mask is not None:
368
+ x0 = max(int(x), 0)
369
+ x1 = min(int(x + w), im_w, int(x) + mask.shape[1])
370
+ y0 = max(int(y), 0)
371
+ y1 = min(int(y + h), im_h, int(y) + mask.shape[0])
372
+ y = int(y)
373
+ x = int(x)
374
+ im_mask[y0:y1, x0:x1] = mask[y0 - y : y1 - y, x0 - x : x1 - x]
375
+ im_mask = np.require(np.asarray(im_mask > 0), dtype=np.uint8, requirements=["F"])
376
+ rle_mask = maskUtils.encode(np.array(im_mask[:, :, np.newaxis], order="F"))[0]
377
+ return rle_mask
378
+
379
+ def computeDPIoU(self, imgId, catId):
380
+ p = self.params
381
+ if p.useCats:
382
+ gt = self._gts[imgId, catId]
383
+ dt = self._dts[imgId, catId]
384
+ else:
385
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
386
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
387
+ if len(gt) == 0 and len(dt) == 0:
388
+ return []
389
+ inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
390
+ dt = [dt[i] for i in inds]
391
+ if len(dt) > p.maxDets[-1]:
392
+ dt = dt[0 : p.maxDets[-1]]
393
+
394
+ gtmasks = []
395
+ for g in gt:
396
+ if DensePoseDataRelative.S_KEY in g:
397
+ # convert DensePose mask to a binary mask
398
+ mask = np.minimum(self.getDensePoseMask(g[DensePoseDataRelative.S_KEY]), 1.0)
399
+ _, _, w, h = g["bbox"]
400
+ scale_x = float(max(w, 1)) / mask.shape[1]
401
+ scale_y = float(max(h, 1)) / mask.shape[0]
402
+ mask = spzoom(mask, (scale_y, scale_x), order=1, prefilter=False)
403
+ mask = np.array(mask > 0.5, dtype=np.uint8)
404
+ rle_mask = self._generate_rlemask_on_image(mask, imgId, g)
405
+ elif "segmentation" in g:
406
+ segmentation = g["segmentation"]
407
+ if isinstance(segmentation, list) and segmentation:
408
+ # polygons
409
+ im_h, im_w = self.size_mapping[imgId]
410
+ rles = maskUtils.frPyObjects(segmentation, im_h, im_w)
411
+ rle_mask = maskUtils.merge(rles)
412
+ elif isinstance(segmentation, dict):
413
+ if isinstance(segmentation["counts"], list):
414
+ # uncompressed RLE
415
+ im_h, im_w = self.size_mapping[imgId]
416
+ rle_mask = maskUtils.frPyObjects(segmentation, im_h, im_w)
417
+ else:
418
+ # compressed RLE
419
+ rle_mask = segmentation
420
+ else:
421
+ rle_mask = self._generate_rlemask_on_image(None, imgId, g)
422
+ else:
423
+ rle_mask = self._generate_rlemask_on_image(None, imgId, g)
424
+ gtmasks.append(rle_mask)
425
+
426
+ dtmasks = []
427
+ for d in dt:
428
+ mask = self._extract_mask(d)
429
+ mask = np.require(np.asarray(mask > 0), dtype=np.uint8, requirements=["F"])
430
+ rle_mask = self._generate_rlemask_on_image(mask, imgId, d)
431
+ dtmasks.append(rle_mask)
432
+
433
+ # compute iou between each dt and gt region
434
+ iscrowd = [int(o.get("iscrowd", 0)) for o in gt]
435
+ iousDP = maskUtils.iou(dtmasks, gtmasks, iscrowd)
436
+ return iousDP
437
+
438
+ def computeIoU(self, imgId, catId):
439
+ p = self.params
440
+ if p.useCats:
441
+ gt = self._gts[imgId, catId]
442
+ dt = self._dts[imgId, catId]
443
+ else:
444
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
445
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
446
+ if len(gt) == 0 and len(dt) == 0:
447
+ return []
448
+ inds = np.argsort([-d["score"] for d in dt], kind="mergesort")
449
+ dt = [dt[i] for i in inds]
450
+ if len(dt) > p.maxDets[-1]:
451
+ dt = dt[0 : p.maxDets[-1]]
452
+
453
+ if p.iouType == "segm":
454
+ g = [g["segmentation"] for g in gt if g["segmentation"] is not None]
455
+ d = [d["segmentation"] for d in dt if d["segmentation"] is not None]
456
+ elif p.iouType == "bbox":
457
+ g = [g["bbox"] for g in gt]
458
+ d = [d["bbox"] for d in dt]
459
+ else:
460
+ raise Exception("unknown iouType for iou computation")
461
+
462
+ # compute iou between each dt and gt region
463
+ iscrowd = [int(o.get("iscrowd", 0)) for o in gt]
464
+ ious = maskUtils.iou(d, g, iscrowd)
465
+ return ious
466
+
467
+ def computeOks(self, imgId, catId):
468
+ p = self.params
469
+ # dimension here should be Nxm
470
+ gts = self._gts[imgId, catId]
471
+ dts = self._dts[imgId, catId]
472
+ inds = np.argsort([-d["score"] for d in dts], kind="mergesort")
473
+ dts = [dts[i] for i in inds]
474
+ if len(dts) > p.maxDets[-1]:
475
+ dts = dts[0 : p.maxDets[-1]]
476
+ # if len(gts) == 0 and len(dts) == 0:
477
+ if len(gts) == 0 or len(dts) == 0:
478
+ return []
479
+ ious = np.zeros((len(dts), len(gts)))
480
+ sigmas = (
481
+ np.array(
482
+ [
483
+ 0.26,
484
+ 0.25,
485
+ 0.25,
486
+ 0.35,
487
+ 0.35,
488
+ 0.79,
489
+ 0.79,
490
+ 0.72,
491
+ 0.72,
492
+ 0.62,
493
+ 0.62,
494
+ 1.07,
495
+ 1.07,
496
+ 0.87,
497
+ 0.87,
498
+ 0.89,
499
+ 0.89,
500
+ ]
501
+ )
502
+ / 10.0
503
+ )
504
+ vars = (sigmas * 2) ** 2
505
+ k = len(sigmas)
506
+ # compute oks between each detection and ground truth object
507
+ for j, gt in enumerate(gts):
508
+ # create bounds for ignore regions(double the gt bbox)
509
+ g = np.array(gt["keypoints"])
510
+ xg = g[0::3]
511
+ yg = g[1::3]
512
+ vg = g[2::3]
513
+ k1 = np.count_nonzero(vg > 0)
514
+ bb = gt["bbox"]
515
+ x0 = bb[0] - bb[2]
516
+ x1 = bb[0] + bb[2] * 2
517
+ y0 = bb[1] - bb[3]
518
+ y1 = bb[1] + bb[3] * 2
519
+ for i, dt in enumerate(dts):
520
+ d = np.array(dt["keypoints"])
521
+ xd = d[0::3]
522
+ yd = d[1::3]
523
+ if k1 > 0:
524
+ # measure the per-keypoint distance if keypoints visible
525
+ dx = xd - xg
526
+ dy = yd - yg
527
+ else:
528
+ # measure minimum distance to keypoints in (x0,y0) & (x1,y1)
529
+ z = np.zeros(k)
530
+ dx = np.max((z, x0 - xd), axis=0) + np.max((z, xd - x1), axis=0)
531
+ dy = np.max((z, y0 - yd), axis=0) + np.max((z, yd - y1), axis=0)
532
+ e = (dx**2 + dy**2) / vars / (gt["area"] + np.spacing(1)) / 2
533
+ if k1 > 0:
534
+ e = e[vg > 0]
535
+ ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
536
+ return ious
537
+
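The per-keypoint OKS kernel used above is exp(-d^2 / (2 * area * k^2)), with d the pixel offset and k the per-keypoint constant. A small numeric sketch with invented offsets:

```python
import numpy as np

# Per-keypoint OKS kernel, as in computeOks above
# (there vars = (sigmas * 2) ** 2 and e = (dx^2 + dy^2) / vars / area / 2).
sigmas = np.array([0.26, 0.79, 1.07]) / 10.0   # nose, shoulder, hip constants
vars_ = (sigmas * 2) ** 2
dx = np.array([2.0, 5.0, 5.0])                 # x offsets in pixels (made up)
dy = np.array([1.0, 0.0, 12.0])                # y offsets in pixels (made up)
area = 80 * 200                                # gt box area in pixels (made up)

e = (dx**2 + dy**2) / vars_ / (area + np.spacing(1)) / 2
oks_per_kpt = np.exp(-e)
print(oks_per_kpt.mean())  # detection/gt OKS is the mean over visible keypoints
```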
538
+ def _extract_mask(self, dt: Dict[str, Any]) -> np.ndarray:
539
+ if "densepose" in dt:
540
+ densepose_results_quantized = dt["densepose"]
541
+ return densepose_results_quantized.labels_uv_uint8[0].numpy()
542
+ elif "cse_mask" in dt:
543
+ return dt["cse_mask"]
544
+ elif "coarse_segm" in dt:
545
+ dy = max(int(dt["bbox"][3]), 1)
546
+ dx = max(int(dt["bbox"][2]), 1)
547
+ return (
548
+ F.interpolate(
549
+ dt["coarse_segm"].unsqueeze(0),
550
+ (dy, dx),
551
+ mode="bilinear",
552
+ align_corners=False,
553
+ )
554
+ .squeeze(0)
555
+ .argmax(0)
556
+ .numpy()
557
+ .astype(np.uint8)
558
+ )
559
+ elif "record_id" in dt:
560
+ assert (
561
+ self.multi_storage is not None
562
+ ), f"Storage record id encountered in a detection {dt}, but no storage provided!"
563
+ record = self.multi_storage.get(dt["rank"], dt["record_id"])
564
+ coarse_segm = record["coarse_segm"]
565
+ dy = max(int(dt["bbox"][3]), 1)
566
+ dx = max(int(dt["bbox"][2]), 1)
567
+ return (
568
+ F.interpolate(
569
+ coarse_segm.unsqueeze(0),
570
+ (dy, dx),
571
+ mode="bilinear",
572
+ align_corners=False,
573
+ )
574
+ .squeeze(0)
575
+ .argmax(0)
576
+ .numpy()
577
+ .astype(np.uint8)
578
+ )
579
+ else:
580
+ raise Exception(f"No mask data in the detection: {dt}")
582
+
583
+ def _extract_iuv(
584
+ self, densepose_data: np.ndarray, py: np.ndarray, px: np.ndarray, gt: Dict[str, Any]
585
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
586
+ """
587
+ Extract arrays of I, U and V values at given points as numpy arrays
588
+ given the data mode stored in self._dpDataMode
589
+ """
590
+ if self._dpDataMode == DensePoseDataMode.IUV_DT:
591
+ # estimated labels and UV (default)
592
+ ipoints = densepose_data[0, py, px]
593
+ upoints = densepose_data[1, py, px] / 255.0 # convert from uint8 by /255.
594
+ vpoints = densepose_data[2, py, px] / 255.0
595
+ elif self._dpDataMode == DensePoseDataMode.IUV_GT:
596
+ # ground truth
597
+ ipoints = np.array(gt["dp_I"])
598
+ upoints = np.array(gt["dp_U"])
599
+ vpoints = np.array(gt["dp_V"])
600
+ elif self._dpDataMode == DensePoseDataMode.I_GT_UV_0:
601
+ # ground truth labels, UV = 0
602
+ ipoints = np.array(gt["dp_I"])
603
+ upoints = upoints * 0.0
604
+ vpoints = vpoints * 0.0
605
+ elif self._dpDataMode == DensePoseDataMode.I_GT_UV_DT:
606
+ # ground truth labels, estimated UV
607
+ ipoints = np.array(gt["dp_I"])
608
+ upoints = densepose_data[1, py, px] / 255.0 # convert from uint8 by /255.
609
+ vpoints = densepose_data[2, py, px] / 255.0
610
+ elif self._dpDataMode == DensePoseDataMode.I_DT_UV_0:
611
+ # estimated labels, UV = 0
612
+ ipoints = densepose_data[0, py, px]
613
+ upoints = upoints * 0.0
614
+ vpoints = vpoints * 0.0
615
+ else:
616
+ raise ValueError(f"Unknown data mode: {self._dpDataMode}")
617
+ return ipoints, upoints, vpoints
618
+
619
+ def computeOgps_single_pair(self, dt, gt, py, px, pt_mask):
620
+ if "densepose" in dt:
621
+ ipoints, upoints, vpoints = self.extract_iuv_from_quantized(dt, gt, py, px, pt_mask)
622
+ return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints)
623
+ elif "u" in dt:
624
+ ipoints, upoints, vpoints = self.extract_iuv_from_raw(dt, gt, py, px, pt_mask)
625
+ return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints)
626
+ elif "record_id" in dt:
627
+ assert (
628
+ self.multi_storage is not None
629
+ ), f"Storage record id encountered in detection {dt}, but no storage provided!"
630
+ record = self.multi_storage.get(dt["rank"], dt["record_id"])
631
+ record["bbox"] = dt["bbox"]
632
+ if "u" in record:
633
+ ipoints, upoints, vpoints = self.extract_iuv_from_raw(record, gt, py, px, pt_mask)
634
+ return self.computeOgps_single_pair_iuv(dt, gt, ipoints, upoints, vpoints)
635
+ elif "embedding" in record:
636
+ return self.computeOgps_single_pair_cse(
637
+ dt,
638
+ gt,
639
+ py,
640
+ px,
641
+ pt_mask,
642
+ record["coarse_segm"],
643
+ record["embedding"],
644
+ record["bbox"],
645
+ )
646
+ else:
647
+ raise Exception(f"Unknown record format: {record}")
648
+ elif "embedding" in dt:
649
+ return self.computeOgps_single_pair_cse(
650
+ dt, gt, py, px, pt_mask, dt["coarse_segm"], dt["embedding"], dt["bbox"]
651
+ )
652
+ raise Exception(f"Unknown detection format: {dt}")
653
+
654
+ def extract_iuv_from_quantized(self, dt, gt, py, px, pt_mask):
655
+ densepose_results_quantized = dt["densepose"]
656
+ ipoints, upoints, vpoints = self._extract_iuv(
657
+ densepose_results_quantized.labels_uv_uint8.numpy(), py, px, gt
658
+ )
659
+ ipoints[pt_mask == -1] = 0
660
+ return ipoints, upoints, vpoints
661
+
662
+ def extract_iuv_from_raw(self, dt, gt, py, px, pt_mask):
663
+ labels_dt = resample_fine_and_coarse_segm_tensors_to_bbox(
664
+ dt["fine_segm"].unsqueeze(0),
665
+ dt["coarse_segm"].unsqueeze(0),
666
+ dt["bbox"],
667
+ )
668
+ uv = resample_uv_tensors_to_bbox(
669
+ dt["u"].unsqueeze(0), dt["v"].unsqueeze(0), labels_dt.squeeze(0), dt["bbox"]
670
+ )
671
+ labels_uv_uint8 = torch.cat((labels_dt.byte(), (uv * 255).clamp(0, 255).byte()))
672
+ ipoints, upoints, vpoints = self._extract_iuv(labels_uv_uint8.numpy(), py, px, gt)
673
+ ipoints[pt_mask == -1] = 0
674
+ return ipoints, upoints, vpoints
675
+
676
+ def computeOgps_single_pair_iuv(self, dt, gt, ipoints, upoints, vpoints):
677
+ cVertsGT, ClosestVertsGTTransformed = self.findAllClosestVertsGT(gt)
678
+ cVerts = self.findAllClosestVertsUV(upoints, vpoints, ipoints)
679
+ # Get pairwise geodesic distances between gt and estimated mesh points.
680
+ dist = self.getDistancesUV(ClosestVertsGTTransformed, cVerts)
681
+ # Compute the Ogps measure.
682
+ # Find the mean geodesic normalization distance for
683
+ # each GT point, based on which part it is on.
684
+ Current_Mean_Distances = self.Mean_Distances[
685
+ self.CoarseParts[self.Part_ids[cVertsGT[cVertsGT > 0].astype(int) - 1]]
686
+ ]
687
+ return dist, Current_Mean_Distances
688
+
689
+ def computeOgps_single_pair_cse(
690
+ self, dt, gt, py, px, pt_mask, coarse_segm, embedding, bbox_xywh_abs
691
+ ):
692
+ # 0-based mesh vertex indices
693
+ cVertsGT = torch.as_tensor(gt["dp_vertex"], dtype=torch.int64)
694
+ # label for each pixel of the bbox, [H, W] tensor of long
695
+ labels_dt = resample_coarse_segm_tensor_to_bbox(
696
+ coarse_segm.unsqueeze(0), bbox_xywh_abs
697
+ ).squeeze(0)
698
+ x, y, w, h = bbox_xywh_abs
699
+ # embedding for each pixel of the bbox, [D, H, W] tensor of float32
700
+ embedding = F.interpolate(
701
+ embedding.unsqueeze(0), (int(h), int(w)), mode="bilinear", align_corners=False
702
+ ).squeeze(0)
703
+ # valid locations py, px
704
+ py_pt = torch.from_numpy(py[pt_mask > -1])
705
+ px_pt = torch.from_numpy(px[pt_mask > -1])
706
+ cVerts = torch.ones_like(cVertsGT) * -1
707
+ cVerts[pt_mask > -1] = self.findClosestVertsCse(
708
+ embedding, py_pt, px_pt, labels_dt, gt["ref_model"]
709
+ )
710
+ # Get pairwise geodesic distances between gt and estimated mesh points.
711
+ dist = self.getDistancesCse(cVertsGT, cVerts, gt["ref_model"])
712
+ # normalize distances
713
+ if (gt["ref_model"] == "smpl_27554") and ("dp_I" in gt):
714
+ Current_Mean_Distances = self.Mean_Distances[
715
+ self.CoarseParts[np.array(gt["dp_I"], dtype=int)]
716
+ ]
717
+ else:
718
+ Current_Mean_Distances = 0.255
719
+ return dist, Current_Mean_Distances
720
+
721
+ def computeOgps(self, imgId, catId):
722
+ p = self.params
723
+ # dimension here should be Nxm
724
+ g = self._gts[imgId, catId]
725
+ d = self._dts[imgId, catId]
726
+ inds = np.argsort([-d_["score"] for d_ in d], kind="mergesort")
727
+ d = [d[i] for i in inds]
728
+ if len(d) > p.maxDets[-1]:
729
+ d = d[0 : p.maxDets[-1]]
730
+ # if len(gts) == 0 and len(dts) == 0:
731
+ if len(g) == 0 or len(d) == 0:
732
+ return []
733
+ ious = np.zeros((len(d), len(g)))
734
+ # compute ogps between each detection and ground truth object
735
+ # sigma = self.sigma #0.255 # dist = 0.3m corresponds to ogps = 0.5
736
+ # 1 # dist = 0.3m corresponds to ogps = 0.96
737
+ # 1.45 # dist = 1.7m (person height) corresponds to ogps = 0.5)
738
+ for j, gt in enumerate(g):
739
+ if not gt["ignore"]:
740
+ g_ = gt["bbox"]
741
+ for i, dt in enumerate(d):
742
+ #
743
+ dy = int(dt["bbox"][3])
744
+ dx = int(dt["bbox"][2])
745
+ dp_x = np.array(gt["dp_x"]) * g_[2] / 255.0
746
+ dp_y = np.array(gt["dp_y"]) * g_[3] / 255.0
747
+ py = (dp_y + g_[1] - dt["bbox"][1]).astype(int)
748
+ px = (dp_x + g_[0] - dt["bbox"][0]).astype(int)
749
+ #
750
+ pts = np.zeros(len(px))
751
+ pts[px >= dx] = -1
752
+ pts[py >= dy] = -1
753
+ pts[px < 0] = -1
754
+ pts[py < 0] = -1
755
+ if len(pts) < 1:
756
+ ogps = 0.0
757
+ elif np.max(pts) == -1:
758
+ ogps = 0.0
759
+ else:
760
+ px[pts == -1] = 0
761
+ py[pts == -1] = 0
762
+ dists_between_matches, dist_norm_coeffs = self.computeOgps_single_pair(
763
+ dt, gt, py, px, pts
764
+ )
765
+ # Compute gps
766
+ ogps_values = np.exp(
767
+ -(dists_between_matches**2) / (2 * (dist_norm_coeffs**2))
768
+ )
769
+ #
770
+ ogps = np.mean(ogps_values) if len(ogps_values) > 0 else 0.0
771
+ ious[i, j] = ogps
772
+
773
+ gbb = [gt["bbox"] for gt in g]
774
+ dbb = [dt["bbox"] for dt in d]
775
+
776
+ # compute iou between each dt and gt region
777
+ iscrowd = [int(o.get("iscrowd", 0)) for o in g]
778
+ ious_bb = maskUtils.iou(dbb, gbb, iscrowd)
779
+ return ious, ious_bb
780
+
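The GPS score assembled above is the mean of exp(-d^2 / (2 * kappa^2)) over annotated points, where d is the geodesic distance between matched mesh vertices and kappa the per-part normalizer (Mean_Distances indexed through CoarseParts). A numeric sketch with invented distances:

```python
import numpy as np

# GPS aggregation as in computeOgps above: per-point Gaussian kernel on
# geodesic distances, normalized by a per-part constant, then averaged.
geodesic_dists = np.array([0.05, 0.10, 0.40])  # meters on the SMPL mesh (made up)
kappa = np.array([0.107, 0.107, 0.255])        # per-part normalizers (illustrative)

gps_values = np.exp(-(geodesic_dists**2) / (2 * kappa**2))
ogps = gps_values.mean()
print(round(float(ogps), 3))
# In GPSM mode the final score is sqrt(gps * mask_iou), combining the
# surface-correspondence quality with the mask overlap.
```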
781
+ def evaluateImg(self, imgId, catId, aRng, maxDet):
782
+ """
783
+ perform evaluation for single category and image
784
+ :return: dict (single image results)
785
+ """
786
+
787
+ p = self.params
788
+ if p.useCats:
789
+ gt = self._gts[imgId, catId]
790
+ dt = self._dts[imgId, catId]
791
+ else:
792
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
793
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]
794
+ if len(gt) == 0 and len(dt) == 0:
795
+ return None
796
+
797
+ for g in gt:
798
+ # g['_ignore'] = g['ignore']
799
+ if g["ignore"] or (g["area"] < aRng[0] or g["area"] > aRng[1]):
800
+ g["_ignore"] = True
801
+ else:
802
+ g["_ignore"] = False
803
+
804
+ # sort dt highest score first, sort gt ignore last
805
+ gtind = np.argsort([g["_ignore"] for g in gt], kind="mergesort")
806
+ gt = [gt[i] for i in gtind]
807
+ dtind = np.argsort([-d["score"] for d in dt], kind="mergesort")
808
+ dt = [dt[i] for i in dtind[0:maxDet]]
809
+ iscrowd = [int(o.get("iscrowd", 0)) for o in gt]
810
+ # load computed ious
811
+ if p.iouType == "densepose":
812
+ # print('Checking the length', len(self.ious[imgId, catId]))
813
+ # if len(self.ious[imgId, catId]) == 0:
814
+ # print(self.ious[imgId, catId])
815
+ ious = (
816
+ self.ious[imgId, catId][0][:, gtind]
817
+ if len(self.ious[imgId, catId]) > 0
818
+ else self.ious[imgId, catId]
819
+ )
820
+ ioubs = (
821
+ self.ious[imgId, catId][1][:, gtind]
822
+ if len(self.ious[imgId, catId]) > 0
823
+ else self.ious[imgId, catId]
824
+ )
825
+ if self._dpEvalMode in {DensePoseEvalMode.GPSM, DensePoseEvalMode.IOU}:
826
+ iousM = (
827
+ self.real_ious[imgId, catId][:, gtind]
828
+ if len(self.real_ious[imgId, catId]) > 0
829
+ else self.real_ious[imgId, catId]
830
+ )
831
+ else:
832
+ ious = (
833
+ self.ious[imgId, catId][:, gtind]
834
+ if len(self.ious[imgId, catId]) > 0
835
+ else self.ious[imgId, catId]
836
+ )
837
+
838
+ T = len(p.iouThrs)
839
+ G = len(gt)
840
+ D = len(dt)
841
+ gtm = np.zeros((T, G))
842
+ dtm = np.zeros((T, D))
843
+ gtIg = np.array([g["_ignore"] for g in gt])
844
+ dtIg = np.zeros((T, D))
845
+ if np.all(gtIg) and p.iouType == "densepose":
846
+ dtIg = np.logical_or(dtIg, True)
847
+
848
+ if len(ious) > 0: # and not p.iouType == 'densepose':
849
+ for tind, t in enumerate(p.iouThrs):
850
+ for dind, d in enumerate(dt):
851
+ # information about best match so far (m=-1 -> unmatched)
852
+ iou = min([t, 1 - 1e-10])
853
+ m = -1
854
+ for gind, _g in enumerate(gt):
855
+ # if this gt already matched, and not a crowd, continue
856
+ if gtm[tind, gind] > 0 and not iscrowd[gind]:
857
+ continue
858
+ # if dt matched to reg gt, and on ignore gt, stop
859
+ if m > -1 and gtIg[m] == 0 and gtIg[gind] == 1:
860
+ break
861
+ if p.iouType == "densepose":
862
+ if self._dpEvalMode == DensePoseEvalMode.GPSM:
863
+ new_iou = np.sqrt(iousM[dind, gind] * ious[dind, gind])
864
+ elif self._dpEvalMode == DensePoseEvalMode.IOU:
865
+ new_iou = iousM[dind, gind]
866
+ elif self._dpEvalMode == DensePoseEvalMode.GPS:
867
+ new_iou = ious[dind, gind]
868
+ else:
869
+ new_iou = ious[dind, gind]
870
+ if new_iou < iou:
871
+ continue
872
+ if new_iou == 0.0:
873
+ continue
874
+ # if match successful and best so far, store appropriately
875
+ iou = new_iou
876
+ m = gind
877
+ # if match made store id of match for both dt and gt
878
+ if m == -1:
879
+ continue
880
+ dtIg[tind, dind] = gtIg[m]
881
+ dtm[tind, dind] = gt[m]["id"]
882
+ gtm[tind, m] = d["id"]
883
+
884
+ if p.iouType == "densepose":
885
+ if not len(ioubs) == 0:
886
+ for dind, d in enumerate(dt):
887
+ # information about best match so far (m=-1 -> unmatched)
888
+ if dtm[tind, dind] == 0:
889
+ ioub = 0.8
890
+ m = -1
891
+ for gind, _g in enumerate(gt):
892
+ # if this gt already matched, and not a crowd, continue
893
+ if gtm[tind, gind] > 0 and not iscrowd[gind]:
894
+ continue
895
+ # continue to next gt unless better match made
896
+ if ioubs[dind, gind] < ioub:
897
+ continue
898
+ # if match successful and best so far, store appropriately
899
+ ioub = ioubs[dind, gind]
900
+ m = gind
901
+ # if match made store id of match for both dt and gt
902
+ if m > -1:
903
+ dtIg[:, dind] = gtIg[m]
904
+ if gtIg[m]:
905
+ dtm[tind, dind] = gt[m]["id"]
906
+ gtm[tind, m] = d["id"]
907
+ # set unmatched detections outside of area range to ignore
908
+ a = np.array([d["area"] < aRng[0] or d["area"] > aRng[1] for d in dt]).reshape((1, len(dt)))
909
+ dtIg = np.logical_or(dtIg, np.logical_and(dtm == 0, np.repeat(a, T, 0)))
910
+ # store results for given image and category
911
+ # print('Done with the function', len(self.ious[imgId, catId]))
912
+ return {
913
+ "image_id": imgId,
914
+ "category_id": catId,
915
+ "aRng": aRng,
916
+ "maxDet": maxDet,
917
+ "dtIds": [d["id"] for d in dt],
918
+ "gtIds": [g["id"] for g in gt],
919
+ "dtMatches": dtm,
920
+ "gtMatches": gtm,
921
+ "dtScores": [d["score"] for d in dt],
922
+ "gtIgnore": gtIg,
923
+ "dtIgnore": dtIg,
924
+ }
925
+
926
+ def accumulate(self, p=None):
927
+ """
928
+ Accumulate per image evaluation results and store the result in self.eval
929
+ :param p: input params for evaluation
930
+ :return: None
931
+ """
932
+ logger.info("Accumulating evaluation results...")
933
+ tic = time.time()
934
+ if not self.evalImgs:
935
+ logger.info("Please run evaluate() first")
936
+ # allows input customized parameters
937
+ if p is None:
938
+ p = self.params
939
+ p.catIds = p.catIds if p.useCats == 1 else [-1]
940
+ T = len(p.iouThrs)
941
+ R = len(p.recThrs)
942
+ K = len(p.catIds) if p.useCats else 1
943
+ A = len(p.areaRng)
944
+ M = len(p.maxDets)
945
+ precision = -(np.ones((T, R, K, A, M))) # -1 for the precision of absent categories
946
+ recall = -(np.ones((T, K, A, M)))
947
+
948
+ # create dictionary for future indexing
949
+ logger.info("Categories: {}".format(p.catIds))
950
+ _pe = self._paramsEval
951
+ catIds = _pe.catIds if _pe.useCats else [-1]
952
+ setK = set(catIds)
953
+ setA = set(map(tuple, _pe.areaRng))
954
+ setM = set(_pe.maxDets)
955
+ setI = set(_pe.imgIds)
956
+ # get inds to evaluate
957
+ k_list = [n for n, k in enumerate(p.catIds) if k in setK]
958
+ m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
959
+ a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA]
960
+ i_list = [n for n, i in enumerate(p.imgIds) if i in setI]
961
+ I0 = len(_pe.imgIds)
962
+ A0 = len(_pe.areaRng)
963
+ # retrieve E at each category, area range, and max number of detections
964
+ for k, k0 in enumerate(k_list):
965
+ Nk = k0 * A0 * I0
966
+ for a, a0 in enumerate(a_list):
967
+ Na = a0 * I0
968
+ for m, maxDet in enumerate(m_list):
969
+ E = [self.evalImgs[Nk + Na + i] for i in i_list]
970
+ E = [e for e in E if e is not None]
971
+ if len(E) == 0:
972
+ continue
973
+ dtScores = np.concatenate([e["dtScores"][0:maxDet] for e in E])
974
+
975
+ # different sorting method generates slightly different results.
976
+ # mergesort is used to be consistent as Matlab implementation.
977
+ inds = np.argsort(-dtScores, kind="mergesort")
978
+
979
+ dtm = np.concatenate([e["dtMatches"][:, 0:maxDet] for e in E], axis=1)[:, inds]
980
+ dtIg = np.concatenate([e["dtIgnore"][:, 0:maxDet] for e in E], axis=1)[:, inds]
981
+ gtIg = np.concatenate([e["gtIgnore"] for e in E])
982
+ npig = np.count_nonzero(gtIg == 0)
983
+ if npig == 0:
984
+ continue
985
+ tps = np.logical_and(dtm, np.logical_not(dtIg))
986
+ fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg))
987
+ tp_sum = np.cumsum(tps, axis=1).astype(dtype=float)
988
+ fp_sum = np.cumsum(fps, axis=1).astype(dtype=float)
989
+ for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
990
+ tp = np.array(tp)
991
+ fp = np.array(fp)
992
+ nd = len(tp)
993
+ rc = tp / npig
994
+ pr = tp / (fp + tp + np.spacing(1))
995
+ q = np.zeros((R,))
996
+
997
+ if nd:
998
+ recall[t, k, a, m] = rc[-1]
999
+ else:
1000
+ recall[t, k, a, m] = 0
1001
+
1002
+ # numpy is slow without cython optimization for accessing elements
1003
+ # use python array gets significant speed improvement
1004
+ pr = pr.tolist()
1005
+ q = q.tolist()
1006
+
1007
+ for i in range(nd - 1, 0, -1):
1008
+ if pr[i] > pr[i - 1]:
1009
+ pr[i - 1] = pr[i]
1010
+
1011
+ inds = np.searchsorted(rc, p.recThrs, side="left")
1012
+ try:
1013
+ for ri, pi in enumerate(inds):
1014
+ q[ri] = pr[pi]
1015
+ except Exception:
1016
+ pass
1017
+ precision[t, :, k, a, m] = np.array(q)
1018
+ logger.info(
1019
+ "Final: max precision {}, min precision {}".format(np.max(precision), np.min(precision))
1020
+ )
1021
+ self.eval = {
1022
+ "params": p,
1023
+ "counts": [T, R, K, A, M],
1024
+ "date": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
1025
+ "precision": precision,
1026
+ "recall": recall,
1027
+ }
1028
+ toc = time.time()
1029
+ logger.info("DONE (t={:0.2f}s).".format(toc - tic))
1030
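The heart of `accumulate()` is the PR-curve interpolation: precision is made monotonically non-increasing from the right, then sampled at the fixed recall thresholds `p.recThrs` with `searchsorted`. The step is easy to check on a toy curve; a minimal, self-contained numpy sketch (toy values, not part of the commit):

```python
import numpy as np

# Toy cumulative precision/recall for 4 detections (sorted by score).
rc = np.array([0.25, 0.50, 0.50, 0.75])   # tp / npig
pr = np.array([1.00, 1.00, 0.66, 0.75])   # tp / (tp + fp)

# Right-to-left maximum: precision at recall r becomes the best
# precision achieved at any recall >= r (same loop as accumulate()).
pr = pr.tolist()
for i in range(len(pr) - 1, 0, -1):
    if pr[i] > pr[i - 1]:
        pr[i - 1] = pr[i]
# pr is now [1.0, 1.0, 0.75, 0.75]

# Sample at fixed recall thresholds, as with p.recThrs.
recThrs = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
q = np.zeros_like(recThrs)
inds = np.searchsorted(rc, recThrs, side="left")
for ri, pi in enumerate(inds):
    if pi < len(pr):  # indices past the last achieved recall stay 0
        q[ri] = pr[pi]
print(q)  # [1.   1.   1.   0.75 0.  ]
```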
+
+     def summarize(self):
+         """
+         Compute and display summary metrics for evaluation results.
+         Note this function can *only* be applied on the default parameter setting
+         """
+
+         def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100):
+             p = self.params
+             iStr = " {:<18} {} @[ {}={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
+             titleStr = "Average Precision" if ap == 1 else "Average Recall"
+             typeStr = "(AP)" if ap == 1 else "(AR)"
+             measure = "IoU"
+             if self.params.iouType == "keypoints":
+                 measure = "OKS"
+             elif self.params.iouType == "densepose":
+                 measure = "OGPS"
+             iouStr = (
+                 "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
+                 if iouThr is None
+                 else "{:0.2f}".format(iouThr)
+             )
+
+             aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
+             mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
+             if ap == 1:
+                 # dimension of precision: [TxRxKxAxM]
+                 s = self.eval["precision"]
+                 # IoU
+                 if iouThr is not None:
+                     t = np.where(np.abs(iouThr - p.iouThrs) < 0.001)[0]
+                     s = s[t]
+                 s = s[:, :, :, aind, mind]
+             else:
+                 # dimension of recall: [TxKxAxM]
+                 s = self.eval["recall"]
+                 if iouThr is not None:
+                     t = np.where(np.abs(iouThr - p.iouThrs) < 0.001)[0]
+                     s = s[t]
+                 s = s[:, :, aind, mind]
+             if len(s[s > -1]) == 0:
+                 mean_s = -1
+             else:
+                 mean_s = np.mean(s[s > -1])
+             logger.info(iStr.format(titleStr, typeStr, measure, iouStr, areaRng, maxDets, mean_s))
+             return mean_s
+
+         def _summarizeDets():
+             stats = np.zeros((12,))
+             stats[0] = _summarize(1)
+             stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2])
+             stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2])
+             stats[3] = _summarize(1, areaRng="small", maxDets=self.params.maxDets[2])
+             stats[4] = _summarize(1, areaRng="medium", maxDets=self.params.maxDets[2])
+             stats[5] = _summarize(1, areaRng="large", maxDets=self.params.maxDets[2])
+             stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
+             stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
+             stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
+             stats[9] = _summarize(0, areaRng="small", maxDets=self.params.maxDets[2])
+             stats[10] = _summarize(0, areaRng="medium", maxDets=self.params.maxDets[2])
+             stats[11] = _summarize(0, areaRng="large", maxDets=self.params.maxDets[2])
+             return stats
+
+         def _summarizeKps():
+             stats = np.zeros((10,))
+             stats[0] = _summarize(1, maxDets=20)
+             stats[1] = _summarize(1, maxDets=20, iouThr=0.5)
+             stats[2] = _summarize(1, maxDets=20, iouThr=0.75)
+             stats[3] = _summarize(1, maxDets=20, areaRng="medium")
+             stats[4] = _summarize(1, maxDets=20, areaRng="large")
+             stats[5] = _summarize(0, maxDets=20)
+             stats[6] = _summarize(0, maxDets=20, iouThr=0.5)
+             stats[7] = _summarize(0, maxDets=20, iouThr=0.75)
+             stats[8] = _summarize(0, maxDets=20, areaRng="medium")
+             stats[9] = _summarize(0, maxDets=20, areaRng="large")
+             return stats
+
+         def _summarizeUvs():
+             stats = [_summarize(1, maxDets=self.params.maxDets[0])]
+             min_threshold = self.params.iouThrs.min()
+             if min_threshold <= 0.201:
+                 stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.2)]
+             if min_threshold <= 0.301:
+                 stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.3)]
+             if min_threshold <= 0.401:
+                 stats += [_summarize(1, maxDets=self.params.maxDets[0], iouThr=0.4)]
+             stats += [
+                 _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.5),
+                 _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.75),
+                 _summarize(1, maxDets=self.params.maxDets[0], areaRng="medium"),
+                 _summarize(1, maxDets=self.params.maxDets[0], areaRng="large"),
+                 _summarize(0, maxDets=self.params.maxDets[0]),
+                 _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.5),
+                 _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.75),
+                 _summarize(0, maxDets=self.params.maxDets[0], areaRng="medium"),
+                 _summarize(0, maxDets=self.params.maxDets[0], areaRng="large"),
+             ]
+             return np.array(stats)
+
+         def _summarizeUvsOld():
+             stats = np.zeros((18,))
+             stats[0] = _summarize(1, maxDets=self.params.maxDets[0])
+             stats[1] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.5)
+             stats[2] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.55)
+             stats[3] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.60)
+             stats[4] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.65)
+             stats[5] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.70)
+             stats[6] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.75)
+             stats[7] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.80)
+             stats[8] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.85)
+             stats[9] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.90)
+             stats[10] = _summarize(1, maxDets=self.params.maxDets[0], iouThr=0.95)
+             stats[11] = _summarize(1, maxDets=self.params.maxDets[0], areaRng="medium")
+             stats[12] = _summarize(1, maxDets=self.params.maxDets[0], areaRng="large")
+             stats[13] = _summarize(0, maxDets=self.params.maxDets[0])
+             stats[14] = _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.5)
+             stats[15] = _summarize(0, maxDets=self.params.maxDets[0], iouThr=0.75)
+             stats[16] = _summarize(0, maxDets=self.params.maxDets[0], areaRng="medium")
+             stats[17] = _summarize(0, maxDets=self.params.maxDets[0], areaRng="large")
+             return stats
+
+         if not self.eval:
+             raise Exception("Please run accumulate() first")
+         iouType = self.params.iouType
+         if iouType in ["segm", "bbox"]:
+             summarize = _summarizeDets
+         elif iouType in ["keypoints"]:
+             summarize = _summarizeKps
+         elif iouType in ["densepose"]:
+             summarize = _summarizeUvs
+         self.stats = summarize()
+
+     def __str__(self):
+         self.summarize()
+
+     # ================ functions for dense pose ==============================
+     def findAllClosestVertsUV(self, U_points, V_points, Index_points):
+         ClosestVerts = np.ones(Index_points.shape) * -1
+         for i in np.arange(24):
+             if (i + 1) in Index_points:
+                 UVs = np.array(
+                     [U_points[Index_points == (i + 1)], V_points[Index_points == (i + 1)]]
+                 )
+                 Current_Part_UVs = self.Part_UVs[i]
+                 Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i]
+                 D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze()
+                 ClosestVerts[Index_points == (i + 1)] = Current_Part_ClosestVertInds[
+                     np.argmin(D, axis=0)
+                 ]
+         ClosestVertsTransformed = self.PDIST_transform[ClosestVerts.astype(int) - 1]
+         ClosestVertsTransformed[ClosestVerts < 0] = 0
+         return ClosestVertsTransformed
+
+     def findClosestVertsCse(self, embedding, py, px, mask, mesh_name):
+         mesh_vertex_embeddings = self.embedder(mesh_name)
+         pixel_embeddings = embedding[:, py, px].t().to(device="cuda")
+         mask_vals = mask[py, px]
+         edm = squared_euclidean_distance_matrix(pixel_embeddings, mesh_vertex_embeddings)
+         vertex_indices = edm.argmin(dim=1).cpu()
+         vertex_indices[mask_vals <= 0] = -1
+         return vertex_indices
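`findClosestVertsCse` assigns each foreground pixel to the mesh vertex whose learned embedding is nearest in Euclidean distance. The same argmin-over-distances pattern in plain torch, as a minimal sketch (toy shapes, with `torch.cdist` standing in for the project's `squared_euclidean_distance_matrix` helper):

```python
import torch

E, V, P = 16, 100, 5                       # embedding dim, mesh vertices, pixels
mesh_vertex_embeddings = torch.randn(V, E)
pixel_embeddings = torch.randn(P, E)       # one row per sampled pixel
mask_vals = torch.tensor([1, 1, 0, 1, 0])  # toy foreground mask values

# Pairwise distances (P x V); argmin over vertices gives the match.
edm = torch.cdist(pixel_embeddings, mesh_vertex_embeddings)
vertex_indices = edm.argmin(dim=1)
vertex_indices[mask_vals <= 0] = -1        # background pixels get no vertex
print(vertex_indices)
```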
+
+     def findAllClosestVertsGT(self, gt):
+         I_gt = np.array(gt["dp_I"])
+         U_gt = np.array(gt["dp_U"])
+         V_gt = np.array(gt["dp_V"])
+         ClosestVertsGT = np.ones(I_gt.shape) * -1
+         for i in np.arange(24):
+             if (i + 1) in I_gt:
+                 UVs = np.array([U_gt[I_gt == (i + 1)], V_gt[I_gt == (i + 1)]])
+                 Current_Part_UVs = self.Part_UVs[i]
+                 Current_Part_ClosestVertInds = self.Part_ClosestVertInds[i]
+                 D = ssd.cdist(Current_Part_UVs.transpose(), UVs.transpose()).squeeze()
+                 ClosestVertsGT[I_gt == (i + 1)] = Current_Part_ClosestVertInds[np.argmin(D, axis=0)]
+         ClosestVertsGTTransformed = self.PDIST_transform[ClosestVertsGT.astype(int) - 1]
+         ClosestVertsGTTransformed[ClosestVertsGT < 0] = 0
+         return ClosestVertsGT, ClosestVertsGTTransformed
+
+     def getDistancesCse(self, cVertsGT, cVerts, mesh_name):
+         geodists_vertices = torch.ones_like(cVertsGT) * float("inf")
+         selected = (cVertsGT >= 0) * (cVerts >= 0)
+         mesh = create_mesh(mesh_name, "cpu")
+         geodists_vertices[selected] = mesh.geodists[cVertsGT[selected], cVerts[selected]]
+         return geodists_vertices.numpy()
+
+     def getDistancesUV(self, cVertsGT, cVerts):
+         n = 27554
+         dists = []
+         for d in range(len(cVertsGT)):
+             if cVertsGT[d] > 0:
+                 if cVerts[d] > 0:
+                     i = cVertsGT[d] - 1
+                     j = cVerts[d] - 1
+                     if j == i:
+                         dists.append(0)
+                     elif j > i:
+                         ccc = i
+                         i = j
+                         j = ccc
+                         i = n - i - 1
+                         j = n - j - 1
+                         k = (n * (n - 1) / 2) - (n - i) * ((n - i) - 1) / 2 + j - i - 1
+                         k = (n * n - n) / 2 - k - 1
+                         dists.append(self.Pdist_matrix[int(k)][0])
+                     else:
+                         i = n - i - 1
+                         j = n - j - 1
+                         k = (n * (n - 1) / 2) - (n - i) * ((n - i) - 1) / 2 + j - i - 1
+                         k = (n * n - n) / 2 - k - 1
+                         dists.append(self.Pdist_matrix[int(k)][0])
+                 else:
+                     dists.append(np.inf)
+         return np.atleast_1d(np.array(dists).squeeze())
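The index arithmetic in `getDistancesUV` is the standard condensed upper-triangle indexing used by `scipy.spatial.distance.pdist`: for mirrored indices `i < j`, the expression `n*(n-1)/2 - (n-i)*((n-i)-1)/2 + j - i - 1` is algebraically identical to the usual `n*i - i*(i+1)/2 + (j - i - 1)`. The final `(n*n - n)/2 - k - 1` then indexes from the end, which suggests `Pdist_matrix` stores the precomputed geodesic distances in reversed condensed order (an inference from the code, not verified against the data file). The identity itself can be checked directly:

```python
import numpy as np

def condensed_index(n, i, j):
    # Standard pdist condensed index for a pair (i, j) with i < j.
    return n * i - i * (i + 1) // 2 + (j - i - 1)

rng = np.random.default_rng(0)
n = 50
for _ in range(1000):
    i, j = sorted(rng.choice(n, size=2, replace=False))
    k = (n * (n - 1) // 2) - (n - i) * ((n - i) - 1) // 2 + j - i - 1
    assert k == condensed_index(n, i, j)
print("index formulas agree")
```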
+
+
+ class Params:
+     """
+     Params for coco evaluation api
+     """
+
+     def setDetParams(self):
+         self.imgIds = []
+         self.catIds = []
+         # np.arange causes trouble: the data points it generates can be slightly
+         # larger than the true values due to floating point error, hence np.linspace
+         self.iouThrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
+         self.recThrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True)
+         self.maxDets = [1, 10, 100]
+         self.areaRng = [
+             [0**2, 1e5**2],
+             [0**2, 32**2],
+             [32**2, 96**2],
+             [96**2, 1e5**2],
+         ]
+         self.areaRngLbl = ["all", "small", "medium", "large"]
+         self.useCats = 1
+
+     def setKpParams(self):
+         self.imgIds = []
+         self.catIds = []
+         # int(np.round(...)) for the sample count, consistent with setDetParams:
+         # np.linspace requires an integer num
+         self.iouThrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
+         self.recThrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True)
+         self.maxDets = [20]
+         self.areaRng = [[0**2, 1e5**2], [32**2, 96**2], [96**2, 1e5**2]]
+         self.areaRngLbl = ["all", "medium", "large"]
+         self.useCats = 1
+
+     def setUvParams(self):
+         self.imgIds = []
+         self.catIds = []
+         self.iouThrs = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
+         self.recThrs = np.linspace(0.0, 1.00, int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True)
+         self.maxDets = [20]
+         self.areaRng = [[0**2, 1e5**2], [32**2, 96**2], [96**2, 1e5**2]]
+         self.areaRngLbl = ["all", "medium", "large"]
+         self.useCats = 1
+
+     def __init__(self, iouType="segm"):
+         if iouType == "segm" or iouType == "bbox":
+             self.setDetParams()
+         elif iouType == "keypoints":
+             self.setKpParams()
+         elif iouType == "densepose":
+             self.setUvParams()
+         else:
+             raise Exception("iouType not supported")
+         self.iouType = iouType
+         # useSegm is deprecated
+         self.useSegm = None
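The `np.arange` comment refers to a real floating point pitfall: with a fractional step, both the element values and the element count depend on rounding, which is why the thresholds are built with `np.linspace` and an explicit integer count. A quick illustration:

```python
import numpy as np

print((0.95 - 0.5) / 0.05)        # 8.999999999999998, not 9.0
print(int((0.95 - 0.5) / 0.05))   # 8: plain truncation would drop a threshold
print(int(np.round((0.95 - 0.5) / 0.05)) + 1)  # 10 IoU thresholds, as intended

iouThrs = np.linspace(0.5, 0.95, 10, endpoint=True)
print(iouThrs[0], iouThrs[-1])    # endpoints are exactly 0.5 and 0.95
```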
densepose/evaluation/evaluator.py ADDED
@@ -0,0 +1,423 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ import contextlib
+ import copy
+ import io
+ import itertools
+ import logging
+ import numpy as np
+ import os
+ from collections import OrderedDict
+ from typing import Dict, Iterable, List, Optional
+ import pycocotools.mask as mask_utils
+ import torch
+ from pycocotools.coco import COCO
+ from tabulate import tabulate
+
+ from detectron2.config import CfgNode
+ from detectron2.data import MetadataCatalog
+ from detectron2.evaluation import DatasetEvaluator
+ from detectron2.structures import BoxMode
+ from detectron2.utils.comm import gather, get_rank, is_main_process, synchronize
+ from detectron2.utils.file_io import PathManager
+ from detectron2.utils.logger import create_small_table
+
+ from densepose.converters import ToChartResultConverter, ToMaskConverter
+ from densepose.data.datasets.coco import maybe_filter_and_map_categories_cocoapi
+ from densepose.structures import (
+     DensePoseChartPredictorOutput,
+     DensePoseEmbeddingPredictorOutput,
+     quantize_densepose_chart_result,
+ )
+
+ from .densepose_coco_evaluation import DensePoseCocoEval, DensePoseEvalMode
+ from .mesh_alignment_evaluator import MeshAlignmentEvaluator
+ from .tensor_storage import (
+     SingleProcessFileTensorStorage,
+     SingleProcessRamTensorStorage,
+     SingleProcessTensorStorage,
+     SizeData,
+     storage_gather,
+ )
+
+
+ class DensePoseCOCOEvaluator(DatasetEvaluator):
+     def __init__(
+         self,
+         dataset_name,
+         distributed,
+         output_dir=None,
+         evaluator_type: str = "iuv",
+         min_iou_threshold: float = 0.5,
+         storage: Optional[SingleProcessTensorStorage] = None,
+         embedder=None,
+         should_evaluate_mesh_alignment: bool = False,
+         mesh_alignment_mesh_names: Optional[List[str]] = None,
+     ):
+         self._embedder = embedder
+         self._distributed = distributed
+         self._output_dir = output_dir
+         self._evaluator_type = evaluator_type
+         self._storage = storage
+         self._should_evaluate_mesh_alignment = should_evaluate_mesh_alignment
+
+         assert not (
+             should_evaluate_mesh_alignment and embedder is None
+         ), "Mesh alignment evaluation is activated, but no vertex embedder provided!"
+         if should_evaluate_mesh_alignment:
+             self._mesh_alignment_evaluator = MeshAlignmentEvaluator(
+                 embedder,
+                 mesh_alignment_mesh_names,
+             )
+
+         self._cpu_device = torch.device("cpu")
+         self._logger = logging.getLogger(__name__)
+
+         self._metadata = MetadataCatalog.get(dataset_name)
+         self._min_threshold = min_iou_threshold
+         json_file = PathManager.get_local_path(self._metadata.json_file)
+         with contextlib.redirect_stdout(io.StringIO()):
+             self._coco_api = COCO(json_file)
+         maybe_filter_and_map_categories_cocoapi(dataset_name, self._coco_api)
+
+     def reset(self):
+         self._predictions = []
+
+     def process(self, inputs, outputs):
+         """
+         Args:
+             inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
+                 It is a list of dict. Each dict corresponds to an image and
+                 contains keys like "height", "width", "file_name", "image_id".
+             outputs: the outputs of a COCO model. It is a list of dicts with key
+                 "instances" that contains :class:`Instances`.
+                 The :class:`Instances` object needs to have `densepose` field.
+         """
+         for input, output in zip(inputs, outputs):
+             instances = output["instances"].to(self._cpu_device)
+             if not instances.has("pred_densepose"):
+                 continue
+             prediction_list = prediction_to_dict(
+                 instances,
+                 input["image_id"],
+                 self._embedder,
+                 self._metadata.class_to_mesh_name,
+                 self._storage is not None,
+             )
+             if self._storage is not None:
+                 for prediction_dict in prediction_list:
+                     dict_to_store = {}
+                     for field_name in self._storage.data_schema:
+                         dict_to_store[field_name] = prediction_dict[field_name]
+                     record_id = self._storage.put(dict_to_store)
+                     prediction_dict["record_id"] = record_id
+                     prediction_dict["rank"] = get_rank()
+                     for field_name in self._storage.data_schema:
+                         del prediction_dict[field_name]
+             self._predictions.extend(prediction_list)
+
+     def evaluate(self, img_ids=None):
+         if self._distributed:
+             synchronize()
+             predictions = gather(self._predictions)
+             predictions = list(itertools.chain(*predictions))
+         else:
+             predictions = self._predictions
+
+         multi_storage = storage_gather(self._storage) if self._storage is not None else None
+
+         if not is_main_process():
+             return
+         return copy.deepcopy(self._eval_predictions(predictions, multi_storage, img_ids))
+
+     def _eval_predictions(self, predictions, multi_storage=None, img_ids=None):
+         """
+         Evaluate predictions on densepose.
+         Return results with the metrics of the tasks.
+         """
+         self._logger.info("Preparing results for COCO format ...")
+
+         if self._output_dir:
+             PathManager.mkdirs(self._output_dir)
+             file_path = os.path.join(self._output_dir, "coco_densepose_predictions.pth")
+             with PathManager.open(file_path, "wb") as f:
+                 torch.save(predictions, f)
+
+         self._logger.info("Evaluating predictions ...")
+         res = OrderedDict()
+         results_gps, results_gpsm, results_segm = _evaluate_predictions_on_coco(
+             self._coco_api,
+             predictions,
+             multi_storage,
+             self._embedder,
+             class_names=self._metadata.get("thing_classes"),
+             min_threshold=self._min_threshold,
+             img_ids=img_ids,
+         )
+         res["densepose_gps"] = results_gps
+         res["densepose_gpsm"] = results_gpsm
+         res["densepose_segm"] = results_segm
+         if self._should_evaluate_mesh_alignment:
+             res["densepose_mesh_alignment"] = self._evaluate_mesh_alignment()
+         return res
+
+     def _evaluate_mesh_alignment(self):
+         self._logger.info("Mesh alignment evaluation ...")
+         mean_ge, mean_gps, per_mesh_metrics = self._mesh_alignment_evaluator.evaluate()
+         results = {
+             "GE": mean_ge * 100,
+             "GPS": mean_gps * 100,
+         }
+         mesh_names = set()
+         for metric_name in per_mesh_metrics:
+             for mesh_name, value in per_mesh_metrics[metric_name].items():
+                 results[f"{metric_name}-{mesh_name}"] = value * 100
+                 mesh_names.add(mesh_name)
+         self._print_mesh_alignment_results(results, mesh_names)
+         return results
+
+     def _print_mesh_alignment_results(self, results: Dict[str, float], mesh_names: Iterable[str]):
+         self._logger.info("Evaluation results for densepose, mesh alignment:")
+         self._logger.info(f'| {"Mesh":13s} | {"GErr":7s} | {"GPS":7s} |')
+         self._logger.info("| :-----------: | :-----: | :-----: |")
+         for mesh_name in mesh_names:
+             ge_key = f"GE-{mesh_name}"
+             ge_str = f"{results[ge_key]:.4f}" if ge_key in results else " "
+             gps_key = f"GPS-{mesh_name}"
+             gps_str = f"{results[gps_key]:.4f}" if gps_key in results else " "
+             self._logger.info(f"| {mesh_name:13s} | {ge_str:7s} | {gps_str:7s} |")
+         self._logger.info("| :-------------------------------: |")
+         ge_key = "GE"
+         ge_str = f"{results[ge_key]:.4f}" if ge_key in results else " "
+         gps_key = "GPS"
+         gps_str = f"{results[gps_key]:.4f}" if gps_key in results else " "
+         self._logger.info(f'| {"MEAN":13s} | {ge_str:7s} | {gps_str:7s} |')
+
+
+ def prediction_to_dict(instances, img_id, embedder, class_to_mesh_name, use_storage):
+     """
+     Args:
+         instances (Instances): the output of the model
+         img_id (str): the image id in COCO
+
+     Returns:
+         list[dict]: the results in densepose evaluation format
+     """
+     scores = instances.scores.tolist()
+     classes = instances.pred_classes.tolist()
+     raw_boxes_xywh = BoxMode.convert(
+         instances.pred_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
+     )
+
+     if isinstance(instances.pred_densepose, DensePoseEmbeddingPredictorOutput):
+         results_densepose = densepose_cse_predictions_to_dict(
+             instances, embedder, class_to_mesh_name, use_storage
+         )
+     elif isinstance(instances.pred_densepose, DensePoseChartPredictorOutput):
+         if not use_storage:
+             results_densepose = densepose_chart_predictions_to_dict(instances)
+         else:
+             results_densepose = densepose_chart_predictions_to_storage_dict(instances)
+
+     results = []
+     for k in range(len(instances)):
+         result = {
+             "image_id": img_id,
+             "category_id": classes[k],
+             "bbox": raw_boxes_xywh[k].tolist(),
+             "score": scores[k],
+         }
+         results.append({**result, **results_densepose[k]})
+     return results
+
+
+ def densepose_chart_predictions_to_dict(instances):
+     segmentations = ToMaskConverter.convert(
+         instances.pred_densepose, instances.pred_boxes, instances.image_size
+     )
+
+     results = []
+     for k in range(len(instances)):
+         densepose_results_quantized = quantize_densepose_chart_result(
+             ToChartResultConverter.convert(instances.pred_densepose[k], instances.pred_boxes[k])
+         )
+         densepose_results_quantized.labels_uv_uint8 = (
+             densepose_results_quantized.labels_uv_uint8.cpu()
+         )
+         segmentation = segmentations.tensor[k]
+         segmentation_encoded = mask_utils.encode(
+             np.require(segmentation.numpy(), dtype=np.uint8, requirements=["F"])
+         )
+         segmentation_encoded["counts"] = segmentation_encoded["counts"].decode("utf-8")
+         result = {
+             "densepose": densepose_results_quantized,
+             "segmentation": segmentation_encoded,
+         }
+         results.append(result)
+     return results
+
+
+ def densepose_chart_predictions_to_storage_dict(instances):
+     results = []
+     for k in range(len(instances)):
+         densepose_predictor_output = instances.pred_densepose[k]
+         result = {
+             "coarse_segm": densepose_predictor_output.coarse_segm.squeeze(0).cpu(),
+             "fine_segm": densepose_predictor_output.fine_segm.squeeze(0).cpu(),
+             "u": densepose_predictor_output.u.squeeze(0).cpu(),
+             "v": densepose_predictor_output.v.squeeze(0).cpu(),
+         }
+         results.append(result)
+     return results
+
+
+ def densepose_cse_predictions_to_dict(instances, embedder, class_to_mesh_name, use_storage):
+     results = []
+     for k in range(len(instances)):
+         cse = instances.pred_densepose[k]
+         results.append(
+             {
+                 "coarse_segm": cse.coarse_segm[0].cpu(),
+                 "embedding": cse.embedding[0].cpu(),
+             }
+         )
+     return results
+
+
+ def _evaluate_predictions_on_coco(
+     coco_gt,
+     coco_results,
+     multi_storage=None,
+     embedder=None,
+     class_names=None,
+     min_threshold: float = 0.5,
+     img_ids=None,
+ ):
+     logger = logging.getLogger(__name__)
+
+     densepose_metrics = _get_densepose_metrics(min_threshold)
+     if len(coco_results) == 0:  # cocoapi does not handle empty results very well
+         logger.warn("No predictions from the model! Set scores to -1")
+         results_gps = {metric: -1 for metric in densepose_metrics}
+         results_gpsm = {metric: -1 for metric in densepose_metrics}
+         results_segm = {metric: -1 for metric in densepose_metrics}
+         return results_gps, results_gpsm, results_segm
+
+     coco_dt = coco_gt.loadRes(coco_results)
+
+     results = []
+     for eval_mode_name in ["GPS", "GPSM", "IOU"]:
+         eval_mode = getattr(DensePoseEvalMode, eval_mode_name)
+         coco_eval = DensePoseCocoEval(
+             coco_gt, coco_dt, "densepose", multi_storage, embedder, dpEvalMode=eval_mode
+         )
+         result = _derive_results_from_coco_eval(
+             coco_eval, eval_mode_name, densepose_metrics, class_names, min_threshold, img_ids
+         )
+         results.append(result)
+     return results
+
+
+ def _get_densepose_metrics(min_threshold: float = 0.5):
+     metrics = ["AP"]
+     if min_threshold <= 0.201:
+         metrics += ["AP20"]
+     if min_threshold <= 0.301:
+         metrics += ["AP30"]
+     if min_threshold <= 0.401:
+         metrics += ["AP40"]
+     metrics.extend(["AP50", "AP75", "APm", "APl", "AR", "AR50", "AR75", "ARm", "ARl"])
+     return metrics
+
+
+ def _derive_results_from_coco_eval(
+     coco_eval, eval_mode_name, metrics, class_names, min_threshold: float, img_ids
+ ):
+     if img_ids is not None:
+         coco_eval.params.imgIds = img_ids
+     coco_eval.params.iouThrs = np.linspace(
+         min_threshold, 0.95, int(np.round((0.95 - min_threshold) / 0.05)) + 1, endpoint=True
+     )
+     coco_eval.evaluate()
+     coco_eval.accumulate()
+     coco_eval.summarize()
+     results = {metric: float(coco_eval.stats[idx] * 100) for idx, metric in enumerate(metrics)}
+     logger = logging.getLogger(__name__)
+     logger.info(
+         f"Evaluation results for densepose, {eval_mode_name} metric: \n"
+         + create_small_table(results)
+     )
+     if class_names is None or len(class_names) <= 1:
+         return results
+
+     # Compute per-category AP, the same way as it is done in D2
+     # (see detectron2/evaluation/coco_evaluation.py):
+     precisions = coco_eval.eval["precision"]
+     # precision has dims (iou, recall, cls, area range, max dets)
+     assert len(class_names) == precisions.shape[2]
+
+     results_per_category = []
+     for idx, name in enumerate(class_names):
+         # area range index 0: all area ranges
+         # max dets index -1: typically 100 per image
+         precision = precisions[:, :, idx, 0, -1]
+         precision = precision[precision > -1]
+         ap = np.mean(precision) if precision.size else float("nan")
+         results_per_category.append((f"{name}", float(ap * 100)))
+
+     # tabulate it
+     n_cols = min(6, len(results_per_category) * 2)
+     results_flatten = list(itertools.chain(*results_per_category))
+     results_2d = itertools.zip_longest(*[results_flatten[i::n_cols] for i in range(n_cols)])
+     table = tabulate(
+         results_2d,
+         tablefmt="pipe",
+         floatfmt=".3f",
+         headers=["category", "AP"] * (n_cols // 2),
+         numalign="left",
+     )
+     logger.info(f"Per-category {eval_mode_name} AP: \n" + table)
+
+     results.update({"AP-" + name: ap for name, ap in results_per_category})
+     return results
+
+
+ def build_densepose_evaluator_storage(cfg: CfgNode, output_folder: str):
+     storage_spec = cfg.DENSEPOSE_EVALUATION.STORAGE
+     if storage_spec == "none":
+         return None
+     evaluator_type = cfg.DENSEPOSE_EVALUATION.TYPE
+     # common output tensor sizes
+     hout = cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE
+     wout = cfg.MODEL.ROI_DENSEPOSE_HEAD.HEATMAP_SIZE
+     n_csc = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_COARSE_SEGM_CHANNELS
+     # specific output tensors
+     if evaluator_type == "iuv":
+         n_fsc = cfg.MODEL.ROI_DENSEPOSE_HEAD.NUM_PATCHES + 1
+         schema = {
+             "coarse_segm": SizeData(dtype="float32", shape=(n_csc, hout, wout)),
+             "fine_segm": SizeData(dtype="float32", shape=(n_fsc, hout, wout)),
+             "u": SizeData(dtype="float32", shape=(n_fsc, hout, wout)),
+             "v": SizeData(dtype="float32", shape=(n_fsc, hout, wout)),
+         }
+     elif evaluator_type == "cse":
+         embed_size = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
+         schema = {
+             "coarse_segm": SizeData(dtype="float32", shape=(n_csc, hout, wout)),
+             "embedding": SizeData(dtype="float32", shape=(embed_size, hout, wout)),
+         }
+     else:
+         raise ValueError(f"Unknown evaluator type: {evaluator_type}")
+     # storage types
+     if storage_spec == "ram":
+         storage = SingleProcessRamTensorStorage(schema, io.BytesIO())
+     elif storage_spec == "file":
+         fpath = os.path.join(output_folder, f"DensePoseEvaluatorStorage.{get_rank()}.bin")
+         PathManager.mkdirs(output_folder)
+         storage = SingleProcessFileTensorStorage(schema, fpath, "wb")
+     else:
+         raise ValueError(f"Unknown storage specification: {storage_spec}")
+     return storage
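For orientation, this is roughly how the pieces above fit together in a detectron2-style evaluation loop. A minimal sketch only: `cfg`, `model` and `data_loader` are placeholders, and in this commit the actual wiring lives in densepose/engine/trainer.py:

```python
# Hypothetical wiring of the evaluator; cfg/model/data_loader are placeholders.
output_folder = "./output/inference"
storage = build_densepose_evaluator_storage(cfg, output_folder)
evaluator = DensePoseCOCOEvaluator(
    "densepose_coco_2014_minival",  # a registered dataset name
    distributed=True,
    output_dir=output_folder,
    evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
    storage=storage,
)

evaluator.reset()
for inputs in data_loader:          # standard detectron2 inference loop
    outputs = model(inputs)
    evaluator.process(inputs, outputs)
results = evaluator.evaluate()      # OrderedDict with densepose_gps/gpsm/segm
```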
densepose/evaluation/mesh_alignment_evaluator.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+ # pyre-unsafe
+
+ import json
+ import logging
+ from typing import List, Optional
+ import torch
+ from torch import nn
+
+ from detectron2.utils.file_io import PathManager
+
+ from densepose.structures.mesh import create_mesh
+
+
+ class MeshAlignmentEvaluator:
+     """
+     Class for evaluation of 3D mesh alignment based on the learned vertex embeddings
+     """
+
+     def __init__(self, embedder: nn.Module, mesh_names: Optional[List[str]]):
+         self.embedder = embedder
+         # use the provided mesh names if not None and not an empty list
+         self.mesh_names = mesh_names if mesh_names else embedder.mesh_names
+         self.logger = logging.getLogger(__name__)
+         with PathManager.open(
+             "https://dl.fbaipublicfiles.com/densepose/data/cse/mesh_keyvertices_v0.json", "r"
+         ) as f:
+             self.mesh_keyvertices = json.load(f)
+
+     def evaluate(self):
+         ge_per_mesh = {}
+         gps_per_mesh = {}
+         for mesh_name_1 in self.mesh_names:
+             avg_errors = []
+             avg_gps = []
+             embeddings_1 = self.embedder(mesh_name_1)
+             keyvertices_1 = self.mesh_keyvertices[mesh_name_1]
+             keyvertex_names_1 = list(keyvertices_1.keys())
+             keyvertex_indices_1 = [keyvertices_1[name] for name in keyvertex_names_1]
+             for mesh_name_2 in self.mesh_names:
+                 if mesh_name_1 == mesh_name_2:
+                     continue
+                 embeddings_2 = self.embedder(mesh_name_2)
+                 keyvertices_2 = self.mesh_keyvertices[mesh_name_2]
+                 sim_matrix_12 = embeddings_1[keyvertex_indices_1].mm(embeddings_2.T)
+                 vertices_2_matching_keyvertices_1 = sim_matrix_12.argmax(axis=1)
+                 mesh_2 = create_mesh(mesh_name_2, embeddings_2.device)
+                 geodists = mesh_2.geodists[
+                     vertices_2_matching_keyvertices_1,
+                     [keyvertices_2[name] for name in keyvertex_names_1],
+                 ]
+                 Current_Mean_Distances = 0.255
+                 gps = (-(geodists**2) / (2 * (Current_Mean_Distances**2))).exp()
+                 avg_errors.append(geodists.mean().item())
+                 avg_gps.append(gps.mean().item())
+
+             ge_mean = torch.as_tensor(avg_errors).mean().item()
+             gps_mean = torch.as_tensor(avg_gps).mean().item()
+             ge_per_mesh[mesh_name_1] = ge_mean
+             gps_per_mesh[mesh_name_1] = gps_mean
+         ge_mean_global = torch.as_tensor(list(ge_per_mesh.values())).mean().item()
+         gps_mean_global = torch.as_tensor(list(gps_per_mesh.values())).mean().item()
+         per_mesh_metrics = {
+             "GE": ge_per_mesh,
+             "GPS": gps_per_mesh,
+         }
+         return ge_mean_global, gps_mean_global, per_mesh_metrics
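The `gps` expression implements the Geodesic Point Similarity used throughout DensePose evaluation, GPS = exp(-d^2 / (2 * kappa^2)) with the fixed normalization kappa = 0.255 (the `Current_Mean_Distances` constant). A quick sanity check of the scale:

```python
import torch

kappa = 0.255
geodists = torch.tensor([0.0, 0.1, 0.255, 0.5, 1.0])  # geodesic errors on the mesh
gps = (-(geodists**2) / (2 * kappa**2)).exp()
print(gps)  # approx [1.000, 0.926, 0.607, 0.146, 0.000]: a perfect match scores 1
```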
densepose/evaluation/tensor_storage.py ADDED
@@ -0,0 +1,241 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+
+ # pyre-unsafe
+
+ import io
+ import numpy as np
+ import os
+ from dataclasses import dataclass
+ from functools import reduce
+ from operator import mul
+ from typing import BinaryIO, Dict, Optional, Tuple
+ import torch
+
+ from detectron2.utils.comm import gather, get_rank
+ from detectron2.utils.file_io import PathManager
+
+
+ @dataclass
+ class SizeData:
+     dtype: str
+     shape: Tuple[int, ...]  # variable-length shape; Tuple[int] would mean a 1-tuple
+
+
+ def _calculate_record_field_size_b(data_schema: Dict[str, SizeData], field_name: str) -> int:
+     schema = data_schema[field_name]
+     element_size_b = np.dtype(schema.dtype).itemsize
+     record_field_size_b = reduce(mul, schema.shape) * element_size_b
+     return record_field_size_b
+
+
+ def _calculate_record_size_b(data_schema: Dict[str, SizeData]) -> int:
+     record_size_b = 0
+     for field_name in data_schema:
+         record_field_size_b = _calculate_record_field_size_b(data_schema, field_name)
+         record_size_b += record_field_size_b
+     return record_size_b
+
+
+ def _calculate_record_field_sizes_b(data_schema: Dict[str, SizeData]) -> Dict[str, int]:
+     field_sizes_b = {}
+     for field_name in data_schema:
+         field_sizes_b[field_name] = _calculate_record_field_size_b(data_schema, field_name)
+     return field_sizes_b
46
+ class SingleProcessTensorStorage:
47
+ """
48
+ Compact tensor storage to keep tensor data of predefined size and type.
49
+ """
50
+
51
+ def __init__(self, data_schema: Dict[str, SizeData], storage_impl: BinaryIO):
52
+ """
53
+ Construct tensor storage based on information on data shape and size.
54
+ Internally uses numpy to interpret the type specification.
55
+ The storage must support operations `seek(offset, whence=os.SEEK_SET)` and
56
+ `read(size)` to be able to perform the `get` operation.
57
+ The storage must support operation `write(bytes)` to be able to perform
58
+ the `put` operation.
59
+
60
+ Args:
61
+ data_schema (dict: str -> SizeData): dictionary which maps tensor name
62
+ to its size data (shape and data type), e.g.
63
+ ```
64
+ {
65
+ "coarse_segm": SizeData(dtype="float32", shape=(112, 112)),
66
+ "embedding": SizeData(dtype="float32", shape=(16, 112, 112)),
67
+ }
68
+ ```
69
+ storage_impl (BinaryIO): io instance that handles file-like seek, read
70
+ and write operations, e.g. a file handle or a memory buffer like io.BytesIO
71
+ """
72
+ self.data_schema = data_schema
73
+ self.record_size_b = _calculate_record_size_b(data_schema)
74
+ self.record_field_sizes_b = _calculate_record_field_sizes_b(data_schema)
75
+ self.storage_impl = storage_impl
76
+ self.next_record_id = 0
77
+
78
+ def get(self, record_id: int) -> Dict[str, torch.Tensor]:
79
+ """
80
+ Load tensors from the storage by record ID
81
+
82
+ Args:
83
+ record_id (int): Record ID, for which to load the data
84
+
85
+ Return:
86
+ dict: str -> tensor: tensor name mapped to tensor data, recorded under the provided ID
87
+ """
88
+ self.storage_impl.seek(record_id * self.record_size_b, os.SEEK_SET)
89
+ data_bytes = self.storage_impl.read(self.record_size_b)
90
+ assert len(data_bytes) == self.record_size_b, (
91
+ f"Expected data size {self.record_size_b} B could not be read: "
92
+ f"got {len(data_bytes)} B"
93
+ )
94
+ record = {}
95
+ cur_idx = 0
96
+ # it's important to read and write in the same order
97
+ for field_name in sorted(self.data_schema):
98
+ schema = self.data_schema[field_name]
99
+ field_size_b = self.record_field_sizes_b[field_name]
100
+ chunk = data_bytes[cur_idx : cur_idx + field_size_b]
101
+ data_np = np.frombuffer(
102
+ chunk, dtype=schema.dtype, count=reduce(mul, schema.shape)
103
+ ).reshape(schema.shape)
104
+ record[field_name] = torch.from_numpy(data_np)
105
+ cur_idx += field_size_b
106
+ return record
107
+
108
+ def put(self, data: Dict[str, torch.Tensor]) -> int:
109
+ """
110
+ Store tensors in the storage
111
+
112
+ Args:
113
+ data (dict: str -> tensor): data to store, a dictionary which maps
114
+ tensor names into tensors; tensor shapes must match those specified
115
+ in data schema.
116
+ Return:
117
+ int: record ID, under which the data is stored
118
+ """
119
+ # it's important to read and write in the same order
120
+ for field_name in sorted(self.data_schema):
121
+ assert (
122
+ field_name in data
123
+ ), f"Field '{field_name}' not present in data: data keys are {data.keys()}"
124
+ value = data[field_name]
125
+ assert value.shape == self.data_schema[field_name].shape, (
126
+ f"Mismatched tensor shapes for field '{field_name}': "
127
+ f"expected {self.data_schema[field_name].shape}, got {value.shape}"
128
+ )
129
+ data_bytes = value.cpu().numpy().tobytes()
130
+ assert len(data_bytes) == self.record_field_sizes_b[field_name], (
131
+ f"Expected field {field_name} to be of size "
132
+ f"{self.record_field_sizes_b[field_name]} B, got {len(data_bytes)} B"
133
+ )
134
+ self.storage_impl.write(data_bytes)
135
+ record_id = self.next_record_id
136
+ self.next_record_id += 1
137
+ return record_id
138
+
139
+
140
+ class SingleProcessFileTensorStorage(SingleProcessTensorStorage):
141
+ """
142
+ Implementation of a single process tensor storage which stores data in a file
143
+ """
144
+
145
+ def __init__(self, data_schema: Dict[str, SizeData], fpath: str, mode: str):
146
+ self.fpath = fpath
147
+ assert "b" in mode, f"Tensor storage should be opened in binary mode, got '{mode}'"
148
+ if "w" in mode:
149
+ # pyre-fixme[6]: For 2nd argument expected `Union[typing_extensions.Liter...
150
+ file_h = PathManager.open(fpath, mode)
151
+ elif "r" in mode:
152
+ local_fpath = PathManager.get_local_path(fpath)
153
+ file_h = open(local_fpath, mode)
154
+ else:
155
+ raise ValueError(f"Unsupported file mode {mode}, supported modes: rb, wb")
156
+ super().__init__(data_schema, file_h) # pyre-ignore[6]
157
+
158
+
159
+ class SingleProcessRamTensorStorage(SingleProcessTensorStorage):
160
+ """
161
+ Implementation of a single process tensor storage which stores data in RAM
162
+ """
163
+
164
+ def __init__(self, data_schema: Dict[str, SizeData], buf: io.BytesIO):
165
+ super().__init__(data_schema, buf)
166
+
167
+
168
+ class MultiProcessTensorStorage:
169
+ """
170
+ Representation of a set of tensor storages created by individual processes,
171
+ allows to access those storages from a single owner process. The storages
172
+ should either be shared or broadcasted to the owner process.
173
+ The processes are identified by their rank, data is uniquely defined by
174
+ the rank of the process and the record ID.
175
+ """
176
+
177
+ def __init__(self, rank_to_storage: Dict[int, SingleProcessTensorStorage]):
178
+ self.rank_to_storage = rank_to_storage
179
+
180
+ def get(self, rank: int, record_id: int) -> Dict[str, torch.Tensor]:
181
+ storage = self.rank_to_storage[rank]
182
+ return storage.get(record_id)
183
+
184
+ def put(self, rank: int, data: Dict[str, torch.Tensor]) -> int:
185
+ storage = self.rank_to_storage[rank]
186
+ return storage.put(data)
187
+
188
+
189
+ class MultiProcessFileTensorStorage(MultiProcessTensorStorage):
190
+ def __init__(self, data_schema: Dict[str, SizeData], rank_to_fpath: Dict[int, str], mode: str):
191
+ rank_to_storage = {
192
+ rank: SingleProcessFileTensorStorage(data_schema, fpath, mode)
193
+ for rank, fpath in rank_to_fpath.items()
194
+ }
195
+ super().__init__(rank_to_storage) # pyre-ignore[6]
196
+
197
+
198
+ class MultiProcessRamTensorStorage(MultiProcessTensorStorage):
199
+ def __init__(self, data_schema: Dict[str, SizeData], rank_to_buffer: Dict[int, io.BytesIO]):
200
+ rank_to_storage = {
201
+ rank: SingleProcessRamTensorStorage(data_schema, buf)
202
+ for rank, buf in rank_to_buffer.items()
203
+ }
204
+ super().__init__(rank_to_storage) # pyre-ignore[6]
205
+
206
+
207
+ def _ram_storage_gather(
208
+ storage: SingleProcessRamTensorStorage, dst_rank: int = 0
209
+ ) -> Optional[MultiProcessRamTensorStorage]:
210
+ storage.storage_impl.seek(0, os.SEEK_SET)
211
+ # TODO: overhead, pickling a bytes object, can just pass bytes in a tensor directly
212
+ # see detectron2/utils.comm.py
213
+ data_list = gather(storage.storage_impl.read(), dst=dst_rank)
214
+ if get_rank() != dst_rank:
215
+ return None
216
+ rank_to_buffer = {i: io.BytesIO(data_list[i]) for i in range(len(data_list))}
217
+ multiprocess_storage = MultiProcessRamTensorStorage(storage.data_schema, rank_to_buffer)
218
+ return multiprocess_storage
219
+
220
+
221
+ def _file_storage_gather(
222
+ storage: SingleProcessFileTensorStorage,
223
+ dst_rank: int = 0,
224
+ mode: str = "rb",
225
+ ) -> Optional[MultiProcessFileTensorStorage]:
226
+ storage.storage_impl.close()
227
+ fpath_list = gather(storage.fpath, dst=dst_rank)
228
+ if get_rank() != dst_rank:
229
+ return None
230
+ rank_to_fpath = {i: fpath_list[i] for i in range(len(fpath_list))}
231
+ return MultiProcessFileTensorStorage(storage.data_schema, rank_to_fpath, mode)
232
+
233
+
234
+ def storage_gather(
235
+ storage: SingleProcessTensorStorage, dst_rank: int = 0
236
+ ) -> Optional[MultiProcessTensorStorage]:
237
+ if isinstance(storage, SingleProcessRamTensorStorage):
238
+ return _ram_storage_gather(storage, dst_rank)
239
+ elif isinstance(storage, SingleProcessFileTensorStorage):
240
+ return _file_storage_gather(storage, dst_rank)
241
+ raise Exception(f"Unsupported storage for gather operation: {storage}")
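A round trip through the RAM-backed storage shows the contract: tensors must match the schema exactly, `put` returns a sequential record ID, and `get` reads the record back by byte offset. A minimal sketch with toy shapes:

```python
import io
import torch

schema = {
    "coarse_segm": SizeData(dtype="float32", shape=(2, 4, 4)),
    "embedding": SizeData(dtype="float32", shape=(3, 4, 4)),
}
storage = SingleProcessRamTensorStorage(schema, io.BytesIO())

record = {
    "coarse_segm": torch.rand(2, 4, 4),
    "embedding": torch.rand(3, 4, 4),
}
record_id = storage.put(record)  # 0 for the first record
restored = storage.get(record_id)
assert torch.equal(restored["embedding"], record["embedding"])
```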