Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +28 -0
- CatVTON/.gitignore +14 -0
- CatVTON/LICENSE +107 -0
- CatVTON/README.md +11 -0
- CatVTON/__pycache__/utils.cpython-311.pyc +0 -0
- CatVTON/__pycache__/utils.cpython-39.pyc +0 -0
- CatVTON/app.py +20 -0
- CatVTON/app_flux.py +305 -0
- CatVTON/app_p2p.py +567 -0
- CatVTON/densepose/__init__.py +22 -0
- CatVTON/densepose/__pycache__/__init__.cpython-311.pyc +0 -0
- CatVTON/densepose/__pycache__/config.cpython-311.pyc +0 -0
- CatVTON/densepose/config.py +277 -0
- CatVTON/densepose/converters/__init__.py +17 -0
- CatVTON/densepose/converters/__pycache__/__init__.cpython-311.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/base.cpython-311.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/builtin.cpython-311.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/chart_output_hflip.cpython-311.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/chart_output_to_chart_result.cpython-311.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/hflip.cpython-311.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/segm_to_mask.cpython-311.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/to_chart_result.cpython-311.pyc +0 -0
- CatVTON/densepose/converters/__pycache__/to_mask.cpython-311.pyc +0 -0
- CatVTON/densepose/converters/base.py +95 -0
- CatVTON/densepose/converters/builtin.py +33 -0
- CatVTON/densepose/converters/chart_output_hflip.py +73 -0
- CatVTON/densepose/converters/chart_output_to_chart_result.py +190 -0
- CatVTON/densepose/converters/hflip.py +36 -0
- CatVTON/densepose/converters/segm_to_mask.py +152 -0
- CatVTON/densepose/converters/to_chart_result.py +72 -0
- CatVTON/densepose/converters/to_mask.py +51 -0
- CatVTON/densepose/data/__init__.py +27 -0
- CatVTON/densepose/data/__pycache__/__init__.cpython-311.pyc +0 -0
- CatVTON/densepose/data/__pycache__/build.cpython-311.pyc +0 -0
- CatVTON/densepose/data/__pycache__/combined_loader.cpython-311.pyc +0 -0
- CatVTON/densepose/data/__pycache__/dataset_mapper.cpython-311.pyc +0 -0
- CatVTON/densepose/data/__pycache__/image_list_dataset.cpython-311.pyc +0 -0
- CatVTON/densepose/data/__pycache__/inference_based_loader.cpython-311.pyc +0 -0
- CatVTON/densepose/data/__pycache__/utils.cpython-311.pyc +0 -0
- CatVTON/densepose/data/build.py +738 -0
- CatVTON/densepose/data/combined_loader.py +46 -0
- CatVTON/densepose/data/dataset_mapper.py +170 -0
- CatVTON/densepose/data/datasets/__init__.py +7 -0
- CatVTON/densepose/data/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
- CatVTON/densepose/data/datasets/__pycache__/builtin.cpython-311.pyc +0 -0
- CatVTON/densepose/data/datasets/__pycache__/chimpnsee.cpython-311.pyc +0 -0
- CatVTON/densepose/data/datasets/__pycache__/coco.cpython-311.pyc +0 -0
- CatVTON/densepose/data/datasets/__pycache__/dataset_type.cpython-311.pyc +0 -0
- CatVTON/densepose/data/datasets/__pycache__/lvis.cpython-311.pyc +0 -0
- CatVTON/densepose/data/datasets/builtin.py +18 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,31 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
CatVTON/detectron2/_C.cpython-39-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
CatVTON/detectron2/data/datasets/__pycache__/lvis_v0_5_categories.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
CatVTON/detectron2/data/datasets/__pycache__/lvis_v1_categories.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
CatVTON/garment.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
CatVTON/person.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
CatVTON/resource/demo/example/condition/overall/21744571_51588794_1000.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
CatVTON/resource/demo/example/condition/overall/23962182_54027982_1000.jpg filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
CatVTON/resource/demo/example/condition/person/baumu30483223c3_1719437121402_2-0._QL90_UX564_V12524t6_.jpg filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
CatVTON/resource/demo/example/condition/person/mison407622250d_1719258948458_2-0._QL90_UX564_V12524t6_.jpg filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
CatVTON/resource/demo/example/condition/person/mothr22044226e8_1718142523286_2-0._QL90_UX564_V12524t6_.jpg filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
CatVTON/resource/demo/example/condition/upper/21514384_52353349_1000.jpg filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
CatVTON/resource/demo/example/condition/upper/22790049_53294275_1000.jpg filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
CatVTON/resource/demo/example/condition/upper/24083449_54173465_2048.jpg filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
CatVTON/resource/demo/example/person/men/Simon_1.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
CatVTON/resource/demo/example/person/men/Yifeng_0.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
CatVTON/resource/demo/example/person/men/model_5.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
CatVTON/resource/demo/example/person/men/model_7.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
CatVTON/resource/demo/example/person/women/049713_0.jpg filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
CatVTON/resource/demo/example/person/women/1-model_3.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
CatVTON/resource/demo/example/person/women/2-model_4.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
CatVTON/resource/demo/example/person/women/model_8.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
CatVTON/resource/img/architecture.jpg filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
CatVTON/resource/img/comfyui-1.png filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
CatVTON/resource/img/comfyui.png filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
CatVTON/resource/img/efficency.jpg filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
CatVTON/resource/img/structure.jpg filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
CatVTON/resource/img/teaser.jpg filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
CatVTON/result.jpg filter=lfs diff=lfs merge=lfs -text
|
CatVTON/.gitignore
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
model/__pycache__
|
| 3 |
+
model/DensePose/__pycache__
|
| 4 |
+
model/SCHP/__pycache__
|
| 5 |
+
model/SCHP/*/__pycache__
|
| 6 |
+
resource/demo/output
|
| 7 |
+
resource/demo/example/.DS_Store
|
| 8 |
+
Models
|
| 9 |
+
Datasets
|
| 10 |
+
densepose_
|
| 11 |
+
.vscode
|
| 12 |
+
playground.py
|
| 13 |
+
output
|
| 14 |
+
model/cloth_masker_segformer.py
|
CatVTON/LICENSE
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
| 2 |
+
|
| 3 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
|
| 4 |
+
|
| 5 |
+
Using Creative Commons Public Licenses
|
| 6 |
+
|
| 7 |
+
Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
|
| 8 |
+
|
| 9 |
+
Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. More considerations for licensors : wiki.creativecommons.org/Considerations_for_licensors
|
| 10 |
+
|
| 11 |
+
Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More considerations for the public : wiki.creativecommons.org/Considerations_for_licensees
|
| 12 |
+
|
| 13 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
| 14 |
+
|
| 15 |
+
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
|
| 16 |
+
|
| 17 |
+
Section 1 – Definitions.
|
| 18 |
+
|
| 19 |
+
a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
|
| 20 |
+
b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
|
| 21 |
+
c. BY-NC-SA Compatible License means a license listed at creativecommons.org/compatiblelicenses, approved by Creative Commons as essentially the equivalent of this Public License.
|
| 22 |
+
d. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
|
| 23 |
+
e. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
|
| 24 |
+
f. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
|
| 25 |
+
g. License Elements means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
|
| 26 |
+
h. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
|
| 27 |
+
i. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
|
| 28 |
+
j. Licensor means the individual(s) or entity(ies) granting rights under this Public License.
|
| 29 |
+
k. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
|
| 30 |
+
l. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
|
| 31 |
+
m. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
|
| 32 |
+
n. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
|
| 33 |
+
Section 2 – Scope.
|
| 34 |
+
|
| 35 |
+
a. License grant.
|
| 36 |
+
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
|
| 37 |
+
A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
|
| 38 |
+
B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
|
| 39 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
|
| 40 |
+
3. Term. The term of this Public License is specified in Section 6(a).
|
| 41 |
+
4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
|
| 42 |
+
5. Downstream recipients.
|
| 43 |
+
A. Offer from the Licensor – Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
|
| 44 |
+
B. Additional offer from the Licensor – Adapted Material. Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter's License You apply.
|
| 45 |
+
C. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
|
| 46 |
+
6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
|
| 47 |
+
b. Other rights.
|
| 48 |
+
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
|
| 49 |
+
2. Patent and trademark rights are not licensed under this Public License.
|
| 50 |
+
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
|
| 51 |
+
Section 3 – License Conditions.
|
| 52 |
+
|
| 53 |
+
Your exercise of the Licensed Rights is expressly made subject to the following conditions.
|
| 54 |
+
|
| 55 |
+
a. Attribution.
|
| 56 |
+
1. If You Share the Licensed Material (including in modified form), You must:
|
| 57 |
+
A. retain the following if it is supplied by the Licensor with the Licensed Material:
|
| 58 |
+
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
|
| 59 |
+
ii. a copyright notice;
|
| 60 |
+
iii. a notice that refers to this Public License;
|
| 61 |
+
iv. a notice that refers to the disclaimer of warranties;
|
| 62 |
+
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
|
| 63 |
+
|
| 64 |
+
B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
|
| 65 |
+
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
|
| 66 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
|
| 67 |
+
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
|
| 68 |
+
b. ShareAlike.In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
|
| 69 |
+
1. The Adapter's License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
|
| 70 |
+
2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
|
| 71 |
+
3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
|
| 72 |
+
Section 4 – Sui Generis Database Rights.
|
| 73 |
+
|
| 74 |
+
Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
|
| 75 |
+
|
| 76 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
|
| 77 |
+
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
|
| 78 |
+
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
|
| 79 |
+
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
|
| 80 |
+
Section 5 – Disclaimer of Warranties and Limitation of Liability.
|
| 81 |
+
|
| 82 |
+
a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.
|
| 83 |
+
b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.
|
| 84 |
+
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
|
| 85 |
+
Section 6 – Term and Termination.
|
| 86 |
+
|
| 87 |
+
a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
|
| 88 |
+
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
|
| 89 |
+
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
|
| 90 |
+
2. upon express reinstatement by the Licensor.
|
| 91 |
+
For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
|
| 92 |
+
|
| 93 |
+
c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
|
| 94 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
|
| 95 |
+
Section 7 – Other Terms and Conditions.
|
| 96 |
+
|
| 97 |
+
a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
|
| 98 |
+
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
|
| 99 |
+
Section 8 – Interpretation.
|
| 100 |
+
|
| 101 |
+
a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
|
| 102 |
+
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
|
| 103 |
+
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
|
| 104 |
+
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
|
| 105 |
+
Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the "Licensor." The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
|
| 106 |
+
|
| 107 |
+
Creative Commons may be contacted at creativecommons.org.
|
CatVTON/README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CatVTON Virtual Try-On (Hugging Face Space)
|
| 2 |
+
|
| 3 |
+
This Space hosts a Gradio interface for the CatVTON model.
|
| 4 |
+
Upload a person image and a garment image to generate a virtual try-on result.
|
| 5 |
+
|
| 6 |
+
## Run locally
|
| 7 |
+
```bash
|
| 8 |
+
pip install -r requirements.txt
|
| 9 |
+
python app.py
|
| 10 |
+
```
|
| 11 |
+
Then open http://127.0.0.1:7860 in your browser.
|
CatVTON/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (22.4 kB). View file
|
|
|
CatVTON/__pycache__/utils.cpython-39.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
CatVTON/app.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from inference import get_pipeline
|
| 4 |
+
|
| 5 |
+
def tryon(person_img, cloth_img):
    """Run the CatVTON pipeline on a person/garment pair.

    Returns the generated try-on image plus a short metadata string
    (seed and total wall-clock time reported by the pipeline).
    """
    vton = get_pipeline()
    # 25 denoising steps — the demo's fixed quality/speed trade-off.
    output_img, meta = vton.run_inference(person_img, cloth_img, num_steps=25)
    summary = f"Seed: {meta['seed']} | Time: {meta['timings']['total_ms']} ms"
    return output_img, summary
|
| 9 |
+
|
| 10 |
+
# Gradio UI: two PIL image inputs (person, garment) mapped through `tryon`
# to a result image and a metadata textbox.
demo = gr.Interface(
    fn=tryon,
    inputs=[
        gr.Image(type="pil", label="Person"),
        gr.Image(type="pil", label="Garment"),
    ],
    outputs=[
        gr.Image(type="pil", label="Result"),
        gr.Textbox(label="Metadata"),
    ],
    title="CatVTON Virtual Try-On",
    description="Upload a person image and a garment image to visualize the CatVTON try-on result.",
    allow_flagging="never",
)

if __name__ == "__main__":
    # 0.0.0.0 so the Space/container runtime can expose port 7860 externally.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
CatVTON/app_flux.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 9 |
+
from huggingface_hub import snapshot_download
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from model.cloth_masker import AutoMasker, vis_mask
|
| 13 |
+
from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
|
| 14 |
+
from utils import resize_and_crop, resize_and_padding
|
| 15 |
+
|
| 16 |
+
def parse_args():
    """Parse the command-line options for the FLUX try-on demo app.

    Returns:
        argparse.Namespace holding model/checkpoint paths, the output
        directory, precision flags and the working resolution.
    """
    cli = argparse.ArgumentParser(description="FLUX Try-On Demo")
    cli.add_argument(
        "--base_model_path",
        type=str,
        # default="black-forest-labs/FLUX.1-Fill-dev",
        default="Models/FLUX.1-Fill-dev",
        help="The path to the base model to use for evaluation.",
    )
    cli.add_argument(
        "--resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help="The Path to the checkpoint of trained tryon model.",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written.",
    )
    cli.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help="Whether to use mixed precision.",
    )
    cli.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help="Whether or not to allow TF32 on Ampere GPUs.",
    )
    cli.add_argument("--width", type=int, default=768, help="The width of the input image.")
    cli.add_argument("--height", type=int, default=1024, help="The height of the input image.")
    return cli.parse_args()
|
| 63 |
+
|
| 64 |
+
def image_grid(imgs, rows, cols):
    """Tile ``imgs`` into one RGB image laid out as ``rows`` x ``cols``.

    Cell (r, c) receives ``imgs[r * cols + c]``; every image is assumed to
    share the size of the first one.
    """
    assert len(imgs) == rows * cols
    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def submit_function_flux(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type
):
    """Gradio callback: run one FLUX try-on pass for a person/garment pair.

    Args:
        person_image: gr.ImageEditor payload — dict with a "background"
            image path and drawn "layers" (layer 0 holds the hand-drawn mask).
        cloth_image: file path of the garment (condition) image.
        cloth_type: "upper" / "lower" / "overall"; only used when no mask
            was drawn, to drive the automatic masker.
        num_inference_steps: diffusion steps for the FLUX pipeline.
        guidance_scale: classifier-free guidance strength.
        seed: RNG seed; -1 means no fixed generator (non-deterministic).
        show_type: "result only", "input & result" or
            "input & mask & result" — selects what composite is returned.

    Returns:
        A PIL.Image — either the bare result or a side-by-side composite of
        the inputs and the result.

    NOTE(review): relies on module-level globals `args`, `pipeline_flux`,
    `automasker` and `mask_processor` created at import time.
    """
    # Process image editor input
    person_image, mask = person_image["background"], person_image["layers"][0]
    mask = Image.open(mask).convert("L")
    # A single unique gray level means the drawing layer is blank: no manual mask.
    if len(np.unique(np.array(mask))) == 1:
        mask = None
    else:
        # Binarize the hand-drawn strokes to a hard 0/255 mask.
        mask = np.array(mask)
        mask[mask > 0] = 255
        mask = Image.fromarray(mask)

    # Set random seed (-1 keeps the default, non-deterministic generator)
    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    # Process input images
    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")

    # Adjust image sizes: center-crop the person, letterbox-pad the garment.
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Process mask: a hand-drawn mask wins; otherwise auto-generate from cloth_type.
    if mask is not None:
        mask = resize_and_crop(mask, (args.width, args.height))
    else:
        mask = automasker(
            person_image,
            cloth_type
        )['mask']
    # Soften the mask edge so the inpainting boundary blends smoothly.
    mask = mask_processor.blur(mask, blur_factor=9)

    # Inference
    result_image = pipeline_flux(
        image=person_image,
        condition_image=cloth_image,
        mask_image=mask,
        height=args.height,
        width=args.width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    ).images[0]

    # Post-processing: overlay the mask on the person for visualization.
    masked_person = vis_mask(person_image, mask)

    # Return result based on show type
    if show_type == "result only":
        return result_image
    else:
        width, height = person_image.size
        if show_type == "input & result":
            condition_width = width // 2
            conditions = image_grid([person_image, cloth_image], 2, 1)
        else:
            condition_width = width // 3
            conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)

        # Place the shrunken condition column left of the result with a 5px gutter.
        conditions = conditions.resize((condition_width, height), Image.NEAREST)
        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
        new_result_image.paste(conditions, (0, 0))
        new_result_image.paste(result_image, (condition_width + 5, 0))
        return new_result_image
|
| 148 |
+
|
| 149 |
+
def person_example_fn(image_path):
    """Forward a clicked example path into the person-image editor component."""
    return image_path
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def app_gradio():
    """Build and launch the Gradio UI for the FLUX try-on demo.

    Layout: a left column with the person editor, garment image and
    advanced sliders; a right column with the result image and clickable
    example galleries. Blocks here, then launches with a public share link.

    NOTE(review): reads example images from resource/demo/example at build
    time — the directories must exist or os.listdir raises.
    """
    with gr.Blocks(title="CatVTON with FLUX.1-Fill-dev") as demo:
        gr.Markdown("# CatVTON with FLUX.1-Fill-dev")
        with gr.Row():
            with gr.Column(scale=1, min_width=350):
                with gr.Row():
                    # Hidden path holder: examples write here, and a .change
                    # handler copies the path into the editor below.
                    image_path_flux = gr.Image(
                        type="filepath",
                        interactive=True,
                        visible=False,
                    )
                    person_image_flux = gr.ImageEditor(
                        interactive=True, label="Person Image", type="filepath"
                    )

                with gr.Row():
                    with gr.Column(scale=1, min_width=230):
                        cloth_image_flux = gr.Image(
                            interactive=True, label="Condition Image", type="filepath"
                        )
                    with gr.Column(scale=1, min_width=120):
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
                        )
                        cloth_type = gr.Radio(
                            label="Try-On Cloth Type",
                            choices=["upper", "lower", "overall"],
                            value="upper",
                        )

                submit_flux = gr.Button("Submit")
                gr.Markdown(
                    '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                )

                with gr.Accordion("Advanced Options", open=False):
                    num_inference_steps_flux = gr.Slider(
                        label="Inference Step", minimum=10, maximum=100, step=5, value=50
                    )
                    # Guidence Scale
                    guidance_scale_flux = gr.Slider(
                        label="CFG Strenth", minimum=0.0, maximum=50, step=0.5, value=30
                    )
                    # Random Seed
                    seed_flux = gr.Slider(
                        label="Seed", minimum=-1, maximum=10000, step=1, value=42
                    )
                    show_type = gr.Radio(
                        label="Show Type",
                        choices=["result only", "input & result", "input & mask & result"],
                        value="input & mask & result",
                    )

            with gr.Column(scale=2, min_width=500):
                result_image_flux = gr.Image(interactive=False, label="Result")
                with gr.Row():
                    # Photo Examples
                    root_path = "resource/demo/example"
                    with gr.Column():
                        gr.Examples(
                            examples=[
                                os.path.join(root_path, "person", "men", _)
                                for _ in os.listdir(os.path.join(root_path, "person", "men"))
                            ],
                            examples_per_page=4,
                            inputs=image_path_flux,
                            label="Person Examples ①",
                        )
                        gr.Examples(
                            examples=[
                                os.path.join(root_path, "person", "women", _)
                                for _ in os.listdir(os.path.join(root_path, "person", "women"))
                            ],
                            examples_per_page=4,
                            inputs=image_path_flux,
                            label="Person Examples ②",
                        )
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
                        )
                    with gr.Column():
                        gr.Examples(
                            examples=[
                                os.path.join(root_path, "condition", "upper", _)
                                for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                            ],
                            examples_per_page=4,
                            inputs=cloth_image_flux,
                            label="Condition Upper Examples",
                        )
                        gr.Examples(
                            examples=[
                                os.path.join(root_path, "condition", "overall", _)
                                for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                            ],
                            examples_per_page=4,
                            inputs=cloth_image_flux,
                            label="Condition Overall Examples",
                        )
                        condition_person_exm = gr.Examples(
                            examples=[
                                os.path.join(root_path, "condition", "person", _)
                                for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                            ],
                            examples_per_page=4,
                            inputs=cloth_image_flux,
                            label="Condition Reference Person Examples",
                        )
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                        )

        # Clicking a person example fills the hidden path, which we mirror
        # into the (example-incompatible) ImageEditor component.
        image_path_flux.change(
            person_example_fn, inputs=image_path_flux, outputs=person_image_flux
        )

        submit_flux.click(
            submit_function_flux,
            [person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
            result_image_flux,
        )

    demo.queue().launch(share=True, show_error=True)
|
| 278 |
+
|
| 279 |
+
# Parse arguments (module level: runs on import)
args = parse_args()

# Load models: fetch the CatVTON checkpoint repo, then attach its FLUX try-on
# LoRA to the FLUX.1-Fill base pipeline and move everything to GPU in bf16.
repo_path = snapshot_download(repo_id=args.resume_path)
pipeline_flux = FluxTryOnPipeline.from_pretrained(args.base_model_path)
pipeline_flux.load_lora_weights(
    os.path.join(repo_path, "flux-lora"),
    weight_name='pytorch_lora_weights.safetensors'
)
pipeline_flux.to("cuda", torch.bfloat16)

# Initialize AutoMasker: grayscale/binary mask post-processor plus the
# DensePose + SCHP based automatic garment masker.
mask_processor = VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True
)
automasker = AutoMasker(
    densepose_ckpt=os.path.join(repo_path, "DensePose"),
    schp_ckpt=os.path.join(repo_path, "SCHP"),
    device='cuda'
)

if __name__ == "__main__":
    app_gradio()
|
CatVTON/app_p2p.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 9 |
+
from huggingface_hub import snapshot_download
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from model.cloth_masker import AutoMasker, vis_mask
|
| 13 |
+
from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
|
| 14 |
+
from utils import init_weight_dtype, resize_and_crop, resize_and_padding
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def parse_args():
    """Parse command-line arguments for the CatVTON p2p / mask-based demo.

    Returns:
        argparse.Namespace with base-model and checkpoint paths, output
        directory, input resolution, precision flags and local_rank.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--p2p_base_model_path",
        type=str,
        default="timbrooks/instruct-pix2pix",
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--ip_base_model_path",
        type=str,
        default="booksforcharlie/stable-diffusion-inpainting",
        help=(
            "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
        ),
    )
    parser.add_argument(
        "--p2p_resume_path",
        type=str,
        default="zhengchong/CatVTON-MaskFree",
        help=(
            "The Path to the checkpoint of trained tryon model."
        ),
    )
    parser.add_argument(
        "--ip_resume_path",
        type=str,
        default="zhengchong/CatVTON",
        help=(
            "The Path to the checkpoint of trained tryon model."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="resource/demo/output",
        help="The output directory where the model predictions will be written.",
    )

    parser.add_argument(
        "--width",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--repaint",
        action="store_true",
        help="Whether to repaint the result image with the original background."
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        default=True,
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    # BUGFIX: the LOCAL_RANK override below reads `args.local_rank`, but no
    # --local_rank option was ever declared, so any run with LOCAL_RANK set
    # crashed with AttributeError. Declare it with the conventional default.
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local process rank for distributed launchers (overridden by the LOCAL_RANK env var).",
    )

    args = parser.parse_args()
    # Distributed launchers communicate the rank via LOCAL_RANK; let the
    # environment win over the command-line value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
|
| 108 |
+
|
| 109 |
+
def image_grid(imgs, rows, cols):
    """Arrange ``imgs`` on a single RGB canvas, ``rows`` by ``cols``.

    Images are pasted in row-major order and must all share the size of
    the first image.
    """
    assert len(imgs) == rows * cols

    tile_w, tile_h = imgs[0].size
    sheet = Image.new("RGB", size=(cols * tile_w, rows * tile_h))

    for position, picture in enumerate(imgs):
        r, c = divmod(position, cols)
        sheet.paste(picture, box=(c * tile_w, r * tile_h))
    return sheet
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Parse arguments once at import time; shared by the Gradio callbacks below.
args = parse_args()

# Mask-free (pix2pix-style) try-on pipeline.
# BUGFIX: the original downloaded `args.ip_resume_path` here, which left
# `--p2p_resume_path` unused and pointed the mask-free pipeline at the wrong
# checkpoint repo; use the dedicated p2p checkpoint path instead.
p2p_repo_path = snapshot_download(repo_id=args.p2p_resume_path)
pipeline_p2p = CatVTONPix2PixPipeline(
    base_ckpt=args.p2p_base_model_path,
    attn_ckpt=p2p_repo_path,
    attn_ckpt_version="mix-48k-1024",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32,
    device='cuda'
)

# Mask-based inpainting try-on pipeline (standard CatVTON checkpoint).
repo_path = snapshot_download(repo_id=args.ip_resume_path)
pipeline = CatVTONPipeline(
    base_ckpt=args.ip_base_model_path,
    attn_ckpt=repo_path,
    attn_ckpt_version="mix",
    weight_dtype=init_weight_dtype(args.mixed_precision),
    use_tf32=args.allow_tf32,
    device='cuda'
)

# AutoMasker: binary grayscale mask post-processor plus the DensePose + SCHP
# based automatic garment masker (checkpoints live in the CatVTON repo).
mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
automasker = AutoMasker(
    densepose_ckpt=os.path.join(repo_path, "DensePose"),
    schp_ckpt=os.path.join(repo_path, "SCHP"),
    device='cuda',
)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def submit_function_p2p(
    person_image,
    cloth_image,
    num_inference_steps,
    guidance_scale,
    seed):
    """Gradio callback for the mask-free (pix2pix) try-on tab.

    Args:
        person_image: gr.ImageEditor payload; only the "background" image
            path is used (no mask in the mask-free flow).
        cloth_image: file path of the garment image.
        num_inference_steps: diffusion steps for the p2p pipeline.
        guidance_scale: classifier-free guidance strength.
        seed: RNG seed; -1 means no fixed generator.

    Returns:
        The try-on result as a PIL.Image. A 1x3 strip
        (person | cloth | result) is also archived under args.output_dir.

    NOTE(review): relies on module-level globals `args` and `pipeline_p2p`.
    """
    person_image = person_image["background"]

    # Archive results under output_dir/YYYYMMDD/HHMMSS.png
    tmp_folder = args.output_dir
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
    # exist_ok avoids the check-then-create race of the original exists()+makedirs pair.
    os.makedirs(os.path.join(tmp_folder, date_str[:8]), exist_ok=True)

    # -1 keeps the default, non-deterministic generator.
    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Inference
    try:
        result_image = pipeline_p2p(
            image=person_image,
            condition_image=cloth_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        )[0]
    except Exception as e:
        # Surface the failure in the UI; chain the cause for server-side logs.
        raise gr.Error(
            "An error occurred. Please try again later: {}".format(e)
        ) from e

    # Post-process: archive inputs and result side by side, return the result.
    save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
    save_result_image.save(result_save_path)
    return result_image
|
| 193 |
+
|
| 194 |
+
def submit_function(
    person_image,
    cloth_image,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type
):
    """Gradio callback for the mask-based try-on tab.

    Args:
        person_image: gr.ImageEditor payload — dict with a "background"
            image path and drawn "layers" (layer 0 holds the hand-drawn mask).
        cloth_image: file path of the garment image.
        cloth_type: "upper" / "lower" / "overall"; only used when no mask
            was drawn, to drive the automatic masker.
        num_inference_steps: diffusion steps for the inpainting pipeline.
        guidance_scale: classifier-free guidance strength.
        seed: RNG seed; -1 means no fixed generator.
        show_type: "result only", "input & result" or
            "input & mask & result" — selects what composite is returned.

    Returns:
        A PIL.Image — the bare result or a composite with the inputs. A 1x4
        archive strip is always saved under args.output_dir.

    NOTE(review): relies on module-level globals `args`, `pipeline`,
    `automasker` and `mask_processor`.
    """
    person_image, mask = person_image["background"], person_image["layers"][0]
    mask = Image.open(mask).convert("L")
    # A single gray level means the drawing layer is blank -> no manual mask.
    if len(np.unique(np.array(mask))) == 1:
        mask = None
    else:
        # Binarize the hand-drawn strokes to a hard 0/255 mask.
        mask = np.array(mask)
        mask[mask > 0] = 255
        mask = Image.fromarray(mask)

    # Archive results under output_dir/YYYYMMDD/HHMMSS.png
    tmp_folder = args.output_dir
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
    # exist_ok avoids the check-then-create race of the original exists()+makedirs pair.
    os.makedirs(os.path.join(tmp_folder, date_str[:8]), exist_ok=True)

    # -1 keeps the default, non-deterministic generator.
    generator = None
    if seed != -1:
        generator = torch.Generator(device='cuda').manual_seed(seed)

    person_image = Image.open(person_image).convert("RGB")
    cloth_image = Image.open(cloth_image).convert("RGB")
    person_image = resize_and_crop(person_image, (args.width, args.height))
    cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

    # Process mask: a hand-drawn mask wins; otherwise generate from cloth_type.
    if mask is not None:
        mask = resize_and_crop(mask, (args.width, args.height))
    else:
        mask = automasker(
            person_image,
            cloth_type
        )['mask']
    # Soften the mask edge so the inpainting boundary blends smoothly.
    mask = mask_processor.blur(mask, blur_factor=9)

    # Inference
    result_image = pipeline(
        image=person_image,
        condition_image=cloth_image,
        mask=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator
    )[0]

    # Post-process: archive a 1x4 strip, then build the requested display.
    masked_person = vis_mask(person_image, mask)
    save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
    save_result_image.save(result_save_path)
    if show_type == "result only":
        return result_image
    else:
        width, height = person_image.size
        if show_type == "input & result":
            condition_width = width // 2
            conditions = image_grid([person_image, cloth_image], 2, 1)
        else:
            condition_width = width // 3
            conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
        # Place the shrunken condition column left of the result with a 5px gutter.
        conditions = conditions.resize((condition_width, height), Image.NEAREST)
        new_result_image = Image.new("RGB", (width + condition_width + 5, height))
        new_result_image.paste(conditions, (0, 0))
        new_result_image.paste(result_image, (condition_width + 5, 0))
        return new_result_image
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def person_example_fn(image_path):
    """Relay a clicked example's file path to the person-image editor."""
    return image_path
|
| 276 |
+
|
| 277 |
+
HEADER = """
|
| 278 |
+
<h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
|
| 279 |
+
<div style="display: flex; justify-content: center; align-items: center;">
|
| 280 |
+
<a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
|
| 281 |
+
<img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
|
| 282 |
+
</a>
|
| 283 |
+
<a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
|
| 284 |
+
<img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
|
| 285 |
+
</a>
|
| 286 |
+
<a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
|
| 287 |
+
<img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
|
| 288 |
+
</a>
|
| 289 |
+
<a href="http://120.76.142.206:8888" style="margin: 0 2px;">
|
| 290 |
+
<img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
|
| 291 |
+
</a>
|
| 292 |
+
<a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
|
| 293 |
+
<img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
|
| 294 |
+
</a>
|
| 295 |
+
<a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
|
| 296 |
+
<img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
|
| 297 |
+
</a>
|
| 298 |
+
<a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
|
| 299 |
+
<img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
|
| 300 |
+
</a>
|
| 301 |
+
</div>
|
| 302 |
+
<br>
|
| 303 |
+
· This demo and our weights are only for <span>Non-commercial Use</span>. <br>
|
| 304 |
+
· You can try CatVTON in our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a> or our <a href="http://120.76.142.206:8888">online demo</a> (run on 3090). <br>
|
| 305 |
+
· Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
|
| 306 |
+
· SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
def app_gradio():
    """Build and launch the CatVTON Gradio demo.

    Two tabs are exposed:
      * "Mask-based Virtual Try-On" -- the mask is either drawn by the user in
        the image editor or generated automatically from the selected cloth
        type; wired to ``submit_function``.
      * "Mask-Free Virtual Try-On"  -- no mask input at all; wired to
        ``submit_function_p2p``.

    Fixes vs. previous revision: "CFG Strenth" label typo corrected to
    "CFG Strength" in both tabs, commented-out dead code removed, and unused
    local bindings for the example galleries dropped.
    """
    with gr.Blocks(title="CatVTON") as demo:
        gr.Markdown(HEADER)
        with gr.Tab("Mask-based Virtual Try-On"):
            with gr.Row():
                with gr.Column(scale=1, min_width=350):
                    with gr.Row():
                        # Hidden file-path holder: the example galleries write
                        # here, and the .change() callback below copies the
                        # selected file into the person image editor.
                        image_path = gr.Image(
                            type="filepath",
                            interactive=True,
                            visible=False,
                        )
                        person_image = gr.ImageEditor(
                            interactive=True, label="Person Image", type="filepath"
                        )

                    with gr.Row():
                        with gr.Column(scale=1, min_width=230):
                            cloth_image = gr.Image(
                                interactive=True, label="Condition Image", type="filepath"
                            )
                        with gr.Column(scale=1, min_width=120):
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
                            )
                            cloth_type = gr.Radio(
                                label="Try-On Cloth Type",
                                choices=["upper", "lower", "overall"],
                                value="upper",
                            )

                    submit = gr.Button("Submit")
                    gr.Markdown(
                        '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                    )

                    gr.Markdown(
                        '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
                    )
                    with gr.Accordion("Advanced Options", open=False):
                        num_inference_steps = gr.Slider(
                            label="Inference Step", minimum=10, maximum=100, step=5, value=50
                        )
                        # Guidance scale (classifier-free guidance strength)
                        guidance_scale = gr.Slider(
                            label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                        )
                        # Random seed; -1 presumably lets the backend pick one
                        seed = gr.Slider(
                            label="Seed", minimum=-1, maximum=10000, step=1, value=42
                        )
                        show_type = gr.Radio(
                            label="Show Type",
                            choices=["result only", "input & result", "input & mask & result"],
                            value="input & mask & result",
                        )

                with gr.Column(scale=2, min_width=500):
                    result_image = gr.Image(interactive=False, label="Result")
                    with gr.Row():
                        # Example galleries; clicking an example fills the
                        # hidden image_path / the condition image input.
                        root_path = "resource/demo/example"
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "men", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "men"))
                                ],
                                examples_per_page=4,
                                inputs=image_path,
                                label="Person Examples ①",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "women", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "women"))
                                ],
                                examples_per_page=4,
                                inputs=image_path,
                                label="Person Examples ②",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
                            )
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "upper", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image,
                                label="Condition Upper Examples",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "overall", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image,
                                label="Condition Overall Examples",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "person", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image,
                                label="Condition Reference Person Examples",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                            )

            # Copy an example selection into the editable person image.
            image_path.change(
                person_example_fn, inputs=image_path, outputs=person_image
            )

            submit.click(
                submit_function,
                [
                    person_image,
                    cloth_image,
                    cloth_type,
                    num_inference_steps,
                    guidance_scale,
                    seed,
                    show_type,
                ],
                result_image,
            )

        with gr.Tab("Mask-Free Virtual Try-On"):
            with gr.Row():
                with gr.Column(scale=1, min_width=350):
                    with gr.Row():
                        # Same hidden-path + editor pattern as the first tab.
                        image_path_p2p = gr.Image(
                            type="filepath",
                            interactive=True,
                            visible=False,
                        )
                        person_image_p2p = gr.ImageEditor(
                            interactive=True, label="Person Image", type="filepath"
                        )

                    with gr.Row():
                        with gr.Column(scale=1, min_width=230):
                            cloth_image_p2p = gr.Image(
                                interactive=True, label="Condition Image", type="filepath"
                            )

                    submit_p2p = gr.Button("Submit")
                    gr.Markdown(
                        '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                    )

                    gr.Markdown(
                        '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
                    )
                    with gr.Accordion("Advanced Options", open=False):
                        num_inference_steps_p2p = gr.Slider(
                            label="Inference Step", minimum=10, maximum=100, step=5, value=50
                        )
                        # Guidance scale (classifier-free guidance strength)
                        guidance_scale_p2p = gr.Slider(
                            label="CFG Strength", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                        )
                        # Random seed; -1 presumably lets the backend pick one
                        seed_p2p = gr.Slider(
                            label="Seed", minimum=-1, maximum=10000, step=1, value=42
                        )

                with gr.Column(scale=2, min_width=500):
                    result_image_p2p = gr.Image(interactive=False, label="Result")
                    with gr.Row():
                        # Example galleries for the mask-free tab.
                        root_path = "resource/demo/example"
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "men", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "men"))
                                ],
                                examples_per_page=4,
                                inputs=image_path_p2p,
                                label="Person Examples ①",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "person", "women", _)
                                    for _ in os.listdir(os.path.join(root_path, "person", "women"))
                                ],
                                examples_per_page=4,
                                inputs=image_path_p2p,
                                label="Person Examples ②",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
                            )
                        with gr.Column():
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "upper", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_p2p,
                                label="Condition Upper Examples",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "overall", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_p2p,
                                label="Condition Overall Examples",
                            )
                            gr.Examples(
                                examples=[
                                    os.path.join(root_path, "condition", "person", _)
                                    for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                                ],
                                examples_per_page=4,
                                inputs=cloth_image_p2p,
                                label="Condition Reference Person Examples",
                            )
                            gr.Markdown(
                                '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                            )

            image_path_p2p.change(
                person_example_fn, inputs=image_path_p2p, outputs=person_image_p2p
            )

            submit_p2p.click(
                submit_function_p2p,
                [
                    person_image_p2p,
                    cloth_image_p2p,
                    num_inference_steps_p2p,
                    guidance_scale_p2p,
                    seed_p2p,
                ],
                result_image_p2p,
            )

        demo.queue().launch(share=True, show_error=True)
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
# Launch the demo only when this file is executed directly (not on import).
if __name__ == "__main__":
    app_gradio()
|
CatVTON/densepose/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
from .data.datasets import builtin # just to register data
|
| 5 |
+
from .converters import builtin as builtin_converters # register converters
|
| 6 |
+
from .config import (
|
| 7 |
+
add_densepose_config,
|
| 8 |
+
add_densepose_head_config,
|
| 9 |
+
add_hrnet_config,
|
| 10 |
+
add_dataset_category_config,
|
| 11 |
+
add_bootstrap_config,
|
| 12 |
+
load_bootstrap_config,
|
| 13 |
+
)
|
| 14 |
+
from .structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
|
| 15 |
+
from .evaluation import DensePoseCOCOEvaluator
|
| 16 |
+
from .modeling.roi_heads import DensePoseROIHeads
|
| 17 |
+
from .modeling.test_time_augmentation import (
|
| 18 |
+
DensePoseGeneralizedRCNNWithTTA,
|
| 19 |
+
DensePoseDatasetMapperTTA,
|
| 20 |
+
)
|
| 21 |
+
from .utils.transform import load_from_cfg
|
| 22 |
+
from .modeling.hrfpn import build_hrfpn_backbone
|
CatVTON/densepose/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.17 kB). View file
|
|
|
CatVTON/densepose/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
CatVTON/densepose/config.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding = utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
# pyre-ignore-all-errors
|
| 4 |
+
|
| 5 |
+
from detectron2.config import CfgNode as CN
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def add_dataset_category_config(cfg: CN) -> None:
    """
    Register additional category-related dataset options on *cfg*:
    category whitelisting, category mapping, and the mapping from
    detection class to mesh name.
    """
    cfg.DATASETS.CATEGORY_MAPS = CN(new_allowed=True)
    cfg.DATASETS.WHITELISTED_CATEGORIES = CN(new_allowed=True)
    # mapping from detection class to mesh name
    cfg.DATASETS.CLASS_TO_MESH_NAME_MAPPING = CN(new_allowed=True)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def add_evaluation_config(cfg: CN) -> None:
    """
    Add DensePose evaluation options under ``cfg.DENSEPOSE_EVALUATION``.
    """
    evaluation = CN()
    # evaluator type: "iou" for models producing iou data,
    # "cse" for models producing cse data
    evaluation.TYPE = "iou"
    # storage for DensePose results:
    #   "none" - everything kept in the predictions dictionary
    #            (memory intensive; historically the default)
    #   "ram"  - per-process RAM storage, reduced to a single-process
    #            storage on later stages; less memory intensive
    #   "file" - per-process file-based storage, least memory intensive,
    #            but may bottleneck on file system accesses
    evaluation.STORAGE = "none"
    # minimum IOU threshold: the lower it is, the more matches are
    # produced (and the higher the AP score)
    evaluation.MIN_IOU_THRESHOLD = 0.5
    # non-distributed inference is slower but can avoid RAM OOM
    evaluation.DISTRIBUTED_INFERENCE = True
    # evaluate mesh alignment based on vertex embeddings (CSE context only)
    evaluation.EVALUATE_MESH_ALIGNMENT = False
    # meshes to compute mesh alignment for
    evaluation.MESH_ALIGNMENT_MESH_NAMES = []
    cfg.DENSEPOSE_EVALUATION = evaluation
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def add_bootstrap_config(cfg: CN) -> None:
    """
    Add top-level options for bootstrap datasets and the bootstrap model.
    """
    cfg.BOOTSTRAP_DATASETS = []
    model = CN()
    model.WEIGHTS = ""
    model.DEVICE = "cuda"
    cfg.BOOTSTRAP_MODEL = model
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_bootstrap_dataset_config() -> CN:
    """
    Build and return the default CfgNode describing a single bootstrap
    dataset entry (dataset name, mixing ratio, image loader, inference
    batching, data sampler and filter).
    """
    node = CN()
    node.DATASET = ""
    # ratio used to mix data loaders
    node.RATIO = 0.1
    # image loader
    node.IMAGE_LOADER = CN(new_allowed=True)
    node.IMAGE_LOADER.TYPE = ""
    node.IMAGE_LOADER.BATCH_SIZE = 4
    node.IMAGE_LOADER.NUM_WORKERS = 4
    node.IMAGE_LOADER.CATEGORIES = []
    node.IMAGE_LOADER.MAX_COUNT_PER_CATEGORY = 1_000_000
    node.IMAGE_LOADER.CATEGORY_TO_CLASS_MAPPING = CN(new_allowed=True)
    # inference batching
    node.INFERENCE = CN()
    # batch size for model inputs
    node.INFERENCE.INPUT_BATCH_SIZE = 4
    # batch size to group model outputs
    node.INFERENCE.OUTPUT_BATCH_SIZE = 2
    # sampled data
    node.DATA_SAMPLER = CN(new_allowed=True)
    node.DATA_SAMPLER.TYPE = ""
    node.DATA_SAMPLER.USE_GROUND_TRUTH_CATEGORIES = False
    # filter
    node.FILTER = CN(new_allowed=True)
    node.FILTER.TYPE = ""
    return node
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def load_bootstrap_config(cfg: CN) -> None:
    """
    Normalize ``cfg.BOOTSTRAP_DATASETS``.

    Bootstrap datasets are given as a list of plain ``dict`` objects that are
    not automatically converted into CfgNode; this rewrites each entry as a
    CfgNode that complies with the bootstrap dataset specification.
    """
    if not cfg.BOOTSTRAP_DATASETS:
        return

    def _to_cfgnode(entry) -> CN:
        # Start from the spec defaults, then overlay the user-provided entry.
        node = get_bootstrap_dataset_config().clone()
        node.merge_from_other_cfg(CN(entry))
        return node

    cfg.BOOTSTRAP_DATASETS = [_to_cfgnode(entry) for entry in cfg.BOOTSTRAP_DATASETS]
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def add_densepose_head_cse_config(cfg: CN) -> None:
    """
    Add configuration options for Continuous Surface Embeddings (CSE)
    under ``cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE``.
    """
    cse = CN()
    # dimensionality D of the embedding space
    cse.EMBED_SIZE = 16
    # embedder specifications for various mesh IDs
    cse.EMBEDDERS = CN(new_allowed=True)
    # normalization coefficient for embedding distances
    cse.EMBEDDING_DIST_GAUSS_SIGMA = 0.01
    # normalization coefficient for geodesic distances
    cse.GEODESIC_DIST_GAUSS_SIGMA = 0.01
    # embedding loss weight
    cse.EMBED_LOSS_WEIGHT = 0.6
    # embedding loss name; currently supported:
    #   "EmbeddingLoss"     - cross-entropy on vertex labels
    #   "SoftEmbeddingLoss" - cross-entropy on vertex labels combined with
    #                         a Gaussian penalty on inter-vertex distance
    cse.EMBED_LOSS_NAME = "EmbeddingLoss"
    # optimizer hyperparameters
    cse.FEATURES_LR_FACTOR = 1.0
    cse.EMBEDDING_LR_FACTOR = 1.0
    # shape-to-shape cycle consistency loss
    s2s = CN({"ENABLED": False})
    s2s.WEIGHT = 0.025
    # norm type used for loss computation
    s2s.NORM_P = 2
    # normalization term for embedding similarity matrices
    s2s.TEMPERATURE = 0.05
    # maximum number of vertices included in the loss:
    # <= 0 means all vertices; > 0 means a random subset of that size
    s2s.MAX_NUM_VERTICES = 4936
    cse.SHAPE_TO_SHAPE_CYCLE_LOSS = s2s
    # pixel-to-shape cycle consistency loss
    p2s = CN({"ENABLED": False})
    p2s.WEIGHT = 0.0001
    # norm type used for loss computation
    p2s.NORM_P = 2
    # map images to all meshes and back (False: only gt meshes from the batch)
    p2s.USE_ALL_MESHES_NOT_GT_ONLY = False
    # randomly sample at most this many pixels per instance (<= 0 means all)
    p2s.NUM_PIXELS_TO_SAMPLE = 100
    # normalization factor for pixel-to-pixel distances
    # (higher value = smoother distribution)
    p2s.PIXEL_SIGMA = 5.0
    p2s.TEMPERATURE_PIXEL_TO_VERTEX = 0.05
    p2s.TEMPERATURE_VERTEX_TO_PIXEL = 0.05
    cse.PIX_TO_SHAPE_CYCLE_LOSS = p2s
    cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE = cse
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def add_densepose_head_config(cfg: CN) -> None:
    """
    Add config for the DensePose head (``cfg.MODEL.ROI_DENSEPOSE_HEAD``),
    plus the related input/TTA rotation options and CSE sub-config.
    """
    cfg.MODEL.DENSEPOSE_ON = True

    head = CN()
    head.NAME = ""
    head.NUM_STACKED_CONVS = 8
    # number of parts used for point labels
    head.NUM_PATCHES = 24
    head.DECONV_KERNEL = 4
    head.CONV_HEAD_DIM = 512
    head.CONV_HEAD_KERNEL = 3
    head.UP_SCALE = 2
    head.HEATMAP_SIZE = 112
    head.POOLER_TYPE = "ROIAlignV2"
    head.POOLER_RESOLUTION = 28
    head.POOLER_SAMPLING_RATIO = 2
    head.NUM_COARSE_SEGM_CHANNELS = 2  # 15 or 2
    # overlap threshold for an RoI to be considered foreground
    # (foreground if >= FG_IOU_THRESHOLD)
    head.FG_IOU_THRESHOLD = 0.7
    # loss weight for annotation masks (14 parts)
    head.INDEX_WEIGHTS = 5.0
    # loss weight for surface parts (24 parts)
    head.PART_WEIGHTS = 1.0
    # loss weight for UV regression
    head.POINT_REGRESSION_WEIGHTS = 0.01
    # coarse segmentation trained using instance segmentation task data
    head.COARSE_SEGM_TRAINED_BY_MASKS = False
    # decoder options
    head.DECODER_ON = True
    head.DECODER_NUM_CLASSES = 256
    head.DECODER_CONV_DIMS = 256
    head.DECODER_NORM = ""
    head.DECODER_COMMON_STRIDE = 4
    # DeepLab head options
    head.DEEPLAB = CN()
    head.DEEPLAB.NORM = "GN"
    head.DEEPLAB.NONLOCAL_ON = 0
    # predictor class name, registered in DENSEPOSE_PREDICTOR_REGISTRY, e.g.:
    #   "DensePoseChartPredictor"                  - segmentation + UV for charts
    #   "DensePoseChartWithConfidencePredictor"    - adds confidences (default)
    #   "DensePoseEmbeddingWithConfidencePredictor"- segmentation + embeddings
    #                                                + confidences for CSE
    head.PREDICTOR_NAME = "DensePoseChartWithConfidencePredictor"
    # loss class name, registered in DENSEPOSE_LOSS_REGISTRY, e.g.:
    #   "DensePoseChartLoss"               - segmentation + UV
    #   "DensePoseChartWithConfidenceLoss" - adds confidences (default)
    head.LOSS_NAME = "DensePoseChartWithConfidenceLoss"
    # learn UV confidences (variances) along with the actual values
    head.UV_CONFIDENCE = CN({"ENABLED": False})
    # UV confidence lower bound
    head.UV_CONFIDENCE.EPSILON = 0.01
    # learn segmentation confidences (variances) along with the actual values
    head.SEGM_CONFIDENCE = CN({"ENABLED": False})
    # segmentation confidence lower bound
    head.SEGM_CONFIDENCE.EPSILON = 0.01
    # statistical model type for confidence learning:
    #   "iid_iso"     - i.i.d. residuals with isotropic covariance
    #   "indep_aniso" - independent residuals with anisotropic covariances
    head.UV_CONFIDENCE.TYPE = "iid_iso"
    cfg.MODEL.ROI_DENSEPOSE_HEAD = head

    # list of angles for rotation in data augmentation during training
    cfg.INPUT.ROTATION_ANGLES = [0]
    cfg.TEST.AUG.ROTATION_ANGLES = ()  # rotation TTA

    add_densepose_head_cse_config(cfg)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def add_hrnet_config(cfg: CN) -> None:
    """
    Add config for the HRNet backbone (defaults match HigherHRNet w32).
    """
    hrnet = CN()
    hrnet.STEM_INPLANES = 64

    hrnet.STAGE2 = CN()
    hrnet.STAGE2.NUM_MODULES = 1
    hrnet.STAGE2.NUM_BRANCHES = 2
    hrnet.STAGE2.BLOCK = "BASIC"
    hrnet.STAGE2.NUM_BLOCKS = [4, 4]
    hrnet.STAGE2.NUM_CHANNELS = [32, 64]
    hrnet.STAGE2.FUSE_METHOD = "SUM"

    hrnet.STAGE3 = CN()
    hrnet.STAGE3.NUM_MODULES = 4
    hrnet.STAGE3.NUM_BRANCHES = 3
    hrnet.STAGE3.BLOCK = "BASIC"
    hrnet.STAGE3.NUM_BLOCKS = [4, 4, 4]
    hrnet.STAGE3.NUM_CHANNELS = [32, 64, 128]
    hrnet.STAGE3.FUSE_METHOD = "SUM"

    hrnet.STAGE4 = CN()
    hrnet.STAGE4.NUM_MODULES = 3
    hrnet.STAGE4.NUM_BRANCHES = 4
    hrnet.STAGE4.BLOCK = "BASIC"
    hrnet.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
    hrnet.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
    hrnet.STAGE4.FUSE_METHOD = "SUM"

    hrnet.HRFPN = CN()
    hrnet.HRFPN.OUT_CHANNELS = 256

    cfg.MODEL.HRNET = hrnet
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def add_densepose_config(cfg: CN) -> None:
    """
    Add all DensePose-related options to *cfg*: head, HRNet backbone,
    bootstrap datasets, dataset categories, and evaluation settings.
    """
    add_densepose_head_config(cfg)
    add_hrnet_config(cfg)
    add_bootstrap_config(cfg)
    add_dataset_category_config(cfg)
    add_evaluation_config(cfg)
|
CatVTON/densepose/converters/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .hflip import HFlipConverter
|
| 6 |
+
from .to_mask import ToMaskConverter
|
| 7 |
+
from .to_chart_result import ToChartResultConverter, ToChartResultConverterWithConfidences
|
| 8 |
+
from .segm_to_mask import (
|
| 9 |
+
predictor_output_with_fine_and_coarse_segm_to_mask,
|
| 10 |
+
predictor_output_with_coarse_segm_to_mask,
|
| 11 |
+
resample_fine_and_coarse_segm_to_bbox,
|
| 12 |
+
)
|
| 13 |
+
from .chart_output_to_chart_result import (
|
| 14 |
+
densepose_chart_predictor_output_to_result,
|
| 15 |
+
densepose_chart_predictor_output_to_result_with_confidences,
|
| 16 |
+
)
|
| 17 |
+
from .chart_output_hflip import densepose_chart_predictor_output_hflip
|
CatVTON/densepose/converters/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (958 Bytes). View file
|
|
|
CatVTON/densepose/converters/__pycache__/base.cpython-311.pyc
ADDED
|
Binary file (4.92 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/builtin.cpython-311.pyc
ADDED
|
Binary file (1.17 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/chart_output_hflip.cpython-311.pyc
ADDED
|
Binary file (3.56 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/chart_output_to_chart_result.cpython-311.pyc
ADDED
|
Binary file (9.34 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/hflip.cpython-311.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/segm_to_mask.cpython-311.pyc
ADDED
|
Binary file (7.95 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/to_chart_result.cpython-311.pyc
ADDED
|
Binary file (3.4 kB). View file
|
|
|
CatVTON/densepose/converters/__pycache__/to_mask.cpython-311.pyc
ADDED
|
Binary file (2.14 kB). View file
|
|
|
CatVTON/densepose/converters/base.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any, Tuple, Type
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BaseConverter:
    """
    Common machinery for type-driven converters.

    A converter maps instances of registered source types to one destination
    type. A registration made for a type also applies to every subclass of
    that type (lookup walks the MRO bases recursively).
    """

    @classmethod
    def register(cls, from_type: Type, converter: Any = None):
        """
        Register a converter for the given source type.

        Usable either as a plain call (both arguments supplied) or as a
        decorator (``converter`` omitted).

        Args:
            from_type (type): type to register the converter for; the
                registration also covers all subclasses of this type
            converter (callable): converter to register; when None, the
                returned wrapper acts as a registering decorator
        """
        if converter is not None:
            cls._do_register(from_type, converter)

        def decorator(fn: Any) -> Any:
            cls._do_register(from_type, fn)
            return fn

        return decorator

    @classmethod
    def _do_register(cls, from_type: Type, converter: Any):
        # Store on the subclass-provided registry dict.
        cls.registry[from_type] = converter  # pyre-ignore[16]

    @classmethod
    def _lookup_converter(cls, from_type: Type) -> Any:
        """
        Recursively look up a converter for the given type.

        When a converter is found on a base class, it is cached on
        ``from_type`` itself to make subsequent lookups O(1).

        Args:
            from_type: type for which to find a converter
        Return:
            The registered callable, or None when no suitable entry exists
        """
        if from_type in cls.registry:  # pyre-ignore[16]
            return cls.registry[from_type]
        for parent in from_type.__bases__:
            found = cls._lookup_converter(parent)
            if found is not None:
                # memoize the inherited converter on the concrete type
                cls._do_register(from_type, found)
                return found
        return None

    @classmethod
    def convert(cls, instance: Any, *args, **kwargs):
        """
        Convert an instance to the destination type via a registered
        converter. Base classes are searched, so derived types need no
        explicit registration.

        Args:
            instance: source instance to convert to the destination type
        Return:
            An instance of the destination type obtained from the source
        Raises:
            KeyError: if no suitable converter was found
        """
        src_type = type(instance)
        converter = cls._lookup_converter(src_type)
        if converter is None:
            if cls.dst_type is None:  # pyre-ignore[16]
                output_type_str = "itself"
            else:
                output_type_str = cls.dst_type
            raise KeyError(f"Could not find converter from {src_type} to {output_type_str}")
        return converter(instance, *args, **kwargs)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
IntTupleBox = Tuple[int, int, int, int]


def make_int_box(box: torch.Tensor) -> IntTupleBox:
    """
    Convert a 4-element box tensor to a tuple of 4 Python ints.

    Args:
        box (torch.Tensor): tensor with exactly 4 elements; values are
            truncated toward zero via ``long()``
    Return:
        Tuple of 4 ints (the box coordinates in the caller's box mode)
    Raises:
        ValueError: if the tensor does not contain exactly 4 elements
    """
    # Unpack directly instead of filling a temporary list element by element.
    c0, c1, c2, c3 = box.long().tolist()
    return c0, c1, c2, c3
|
CatVTON/densepose/converters/builtin.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from ..structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
|
| 6 |
+
from . import (
|
| 7 |
+
HFlipConverter,
|
| 8 |
+
ToChartResultConverter,
|
| 9 |
+
ToChartResultConverterWithConfidences,
|
| 10 |
+
ToMaskConverter,
|
| 11 |
+
densepose_chart_predictor_output_hflip,
|
| 12 |
+
densepose_chart_predictor_output_to_result,
|
| 13 |
+
densepose_chart_predictor_output_to_result_with_confidences,
|
| 14 |
+
predictor_output_with_coarse_segm_to_mask,
|
| 15 |
+
predictor_output_with_fine_and_coarse_segm_to_mask,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
# Wire up the default conversion strategies for the built-in predictor
# output types. Importing this module has the side effect of populating
# the converter registries.

# Chart-based outputs are masked using both fine and coarse segmentation;
# embedding-based outputs only provide the coarse segmentation.
ToMaskConverter.register(
    DensePoseChartPredictorOutput, predictor_output_with_fine_and_coarse_segm_to_mask
)
ToMaskConverter.register(
    DensePoseEmbeddingPredictorOutput, predictor_output_with_coarse_segm_to_mask
)

# Chart (IUV) results can only be produced from chart-based outputs.
ToChartResultConverter.register(
    DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result
)

ToChartResultConverterWithConfidences.register(
    DensePoseChartPredictorOutput, densepose_chart_predictor_output_to_result_with_confidences
)

# Horizontal-flip conversion is only registered for chart-based outputs.
HFlipConverter.register(DensePoseChartPredictorOutput, densepose_chart_predictor_output_hflip)
|
CatVTON/densepose/converters/chart_output_hflip.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
from dataclasses import fields
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from densepose.structures import DensePoseChartPredictorOutput, DensePoseTransformData
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def densepose_chart_predictor_output_hflip(
    densepose_predictor_output: DensePoseChartPredictorOutput,
    transform_data: DensePoseTransformData,
) -> DensePoseChartPredictorOutput:
    """
    Apply a horizontal flip to a chart predictor output.

    Flips every tensor field along its last (width) axis, then remaps the
    chart/segmentation semantics using the symmetry tables in
    `transform_data`, and returns a new instance of the same output type.

    NOTE(review): the input predictor output is mutated in place before the
    new instance is constructed — callers should not rely on the argument
    being unchanged after this call.
    """
    if len(densepose_predictor_output) > 0:

        PredictorOutput = type(densepose_predictor_output)
        output_dict = {}

        # Spatially flip every tensor field along the width dimension (dim 3).
        for field in fields(densepose_predictor_output):
            field_value = getattr(densepose_predictor_output, field.name)
            # flip tensors
            if isinstance(field_value, torch.Tensor):
                setattr(densepose_predictor_output, field.name, torch.flip(field_value, [3]))

        # Remap U/V values and fine-segmentation channels to their
        # left/right-symmetric counterparts.
        densepose_predictor_output = _flip_iuv_semantics_tensor(
            densepose_predictor_output, transform_data
        )
        # Remap coarse-segmentation channels the same way.
        densepose_predictor_output = _flip_segm_semantics_tensor(
            densepose_predictor_output, transform_data
        )

        # Rebuild a fresh predictor output from the transformed fields.
        for field in fields(densepose_predictor_output):
            output_dict[field.name] = getattr(densepose_predictor_output, field.name)

        return PredictorOutput(**output_dict)
    else:
        # Empty output: nothing to flip.
        return densepose_predictor_output
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _flip_iuv_semantics_tensor(
    densepose_predictor_output: DensePoseChartPredictorOutput,
    dp_transform_data: DensePoseTransformData,
) -> DensePoseChartPredictorOutput:
    """
    Remap U/V values and fine-segmentation channels for a horizontal flip.

    U and V values (channels 1..C-1; channel 0 is skipped — presumably
    background, TODO confirm) are quantized into 256 bins and replaced via
    the per-part ``U_transforms`` / ``V_transforms`` lookup tables; then the
    fine_segm/u/v channels are permuted by ``point_label_symmetries``.
    Mutates and returns the input object.
    """
    point_label_symmetries = dp_transform_data.point_label_symmetries
    uv_symmetries = dp_transform_data.uv_symmetries

    N, C, H, W = densepose_predictor_output.u.shape
    # Quantize predicted U/V in [0, 1] to integer bins 0..255 for table lookup.
    u_loc = (densepose_predictor_output.u[:, 1:, :, :].clamp(0, 1) * 255).long()
    v_loc = (densepose_predictor_output.v[:, 1:, :, :].clamp(0, 1) * 255).long()
    # Part-index tensor broadcast to the same shape as u_loc/v_loc, so the
    # lookup is indexed as [part, v_bin, u_bin] per pixel.
    Iindex = torch.arange(C - 1, device=densepose_predictor_output.u.device)[
        None, :, None, None
    ].expand(N, C - 1, H, W)
    densepose_predictor_output.u[:, 1:, :, :] = uv_symmetries["U_transforms"][Iindex, v_loc, u_loc]
    densepose_predictor_output.v[:, 1:, :, :] = uv_symmetries["V_transforms"][Iindex, v_loc, u_loc]

    # Permute channels so each part maps to its left/right-symmetric part.
    # NOTE(review): goes through the instance __dict__, which assumes the
    # output is a plain (non-slotted) dataclass instance.
    for el in ["fine_segm", "u", "v"]:
        densepose_predictor_output.__dict__[el] = densepose_predictor_output.__dict__[el][
            :, point_label_symmetries, :, :
        ]
    return densepose_predictor_output
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _flip_segm_semantics_tensor(
    densepose_predictor_output: DensePoseChartPredictorOutput, dp_transform_data
):
    """
    Permute coarse-segmentation channels for a horizontal flip.

    The permutation (``dp_transform_data.mask_label_symmetries``) is only
    applied when there are more than 2 coarse channels; a 2-channel
    (foreground/background) segmentation is flip-invariant channel-wise.
    Mutates and returns the input object.
    """
    coarse = densepose_predictor_output.coarse_segm
    if coarse.shape[1] > 2:
        symmetries = dp_transform_data.mask_label_symmetries
        densepose_predictor_output.coarse_segm = coarse[:, symmetries, :, :]
    return densepose_predictor_output
|
CatVTON/densepose/converters/chart_output_to_chart_result.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Dict
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from detectron2.structures.boxes import Boxes, BoxMode
|
| 10 |
+
|
| 11 |
+
from ..structures import (
|
| 12 |
+
DensePoseChartPredictorOutput,
|
| 13 |
+
DensePoseChartResult,
|
| 14 |
+
DensePoseChartResultWithConfidences,
|
| 15 |
+
)
|
| 16 |
+
from . import resample_fine_and_coarse_segm_to_bbox
|
| 17 |
+
from .base import IntTupleBox, make_int_box
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def resample_uv_tensors_to_bbox(
    u: torch.Tensor,
    v: torch.Tensor,
    labels: torch.Tensor,
    box_xywh_abs: IntTupleBox,
) -> torch.Tensor:
    """
    Rescale per-part U and V estimates to a bounding box and pick, for each
    pixel, the value of the part given by its label.

    Args:
        u (tensor [1, C, H, W] of float): U coordinates
        v (tensor [1, C, H, W] of float): V coordinates
        labels (tensor [H, W] of long): per-pixel part labels for the box
        box_xywh_abs (tuple of 4 int): absolute (x, y, w, h) of the box
    Return:
        tensor [2, H, W] of float with selected U (channel 0) and
        V (channel 1) values
    """
    _, _, w, h = box_xywh_abs
    # guard against degenerate boxes
    w = max(int(w), 1)
    h = max(int(h), 1)
    u_resampled = F.interpolate(u, (h, w), mode="bilinear", align_corners=False)
    v_resampled = F.interpolate(v, (h, w), mode="bilinear", align_corners=False)
    uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device)
    # part 0 is left at zero; gather U/V channel-by-channel for parts 1..C-1
    for part_id in range(1, u_resampled.size(1)):
        part_mask = labels == part_id
        uv[0][part_mask] = u_resampled[0, part_id][part_mask]
        uv[1][part_mask] = v_resampled[0, part_id][part_mask]
    return uv
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def resample_uv_to_bbox(
    predictor_output: DensePoseChartPredictorOutput,
    labels: torch.Tensor,
    box_xywh_abs: IntTupleBox,
) -> torch.Tensor:
    """
    Resample the U/V estimates of a chart predictor output to a bounding box.

    Thin wrapper around `resample_uv_tensors_to_bbox` that extracts the
    ``u`` and ``v`` tensors from the predictor output.

    Args:
        predictor_output (DensePoseChartPredictorOutput): output to resample
        labels (tensor [H, W] of long): per-pixel part labels for the box
        box_xywh_abs (tuple of 4 int): absolute (x, y, w, h) of the box
    Return:
        tensor [2, H, W] of float with the resampled U and V coordinates
    """
    return resample_uv_tensors_to_bbox(
        predictor_output.u, predictor_output.v, labels, box_xywh_abs
    )
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def densepose_chart_predictor_output_to_result(
    predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
) -> DensePoseChartResult:
    """
    Convert a single chart predictor output into a DensePoseChartResult.

    Args:
        predictor_output (DensePoseChartPredictorOutput): predictor output;
            must hold exactly one instance
        boxes (Boxes): the matching bounding box; must hold exactly one box
    Return:
        DensePoseChartResult with per-pixel labels and U/V coordinates
    """
    assert len(predictor_output) == 1 and len(boxes) == 1, (
        f"Predictor output to result conversion can operate only single outputs"
        f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
    )

    # Convert the (single) box to integer XYWH coordinates.
    box_xyxy = boxes.tensor.clone()
    box_xywh_all = BoxMode.convert(box_xyxy, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
    box_xywh = make_int_box(box_xywh_all[0])

    part_labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
    uv = resample_uv_to_bbox(predictor_output, part_labels, box_xywh)
    return DensePoseChartResult(labels=part_labels, uv=uv)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def resample_confidences_to_bbox(
    predictor_output: DensePoseChartPredictorOutput,
    labels: torch.Tensor,
    box_xywh_abs: IntTupleBox,
) -> Dict[str, torch.Tensor]:
    """
    Resamples confidences for the given bounding box

    Args:
        predictor_output (DensePoseChartPredictorOutput): DensePose predictor
            output to be resampled
        labels (tensor [H, W] of long): labels obtained by resampling segmentation
            outputs for the given bounding box
        box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
    Return:
        Resampled confidences - a dict of [H, W] tensors of float
    """
    x, y, w, h = box_xywh_abs
    w = max(int(w), 1)
    h = max(int(h), 1)

    confidence_names = [
        "sigma_1",
        "sigma_2",
        "kappa_u",
        "kappa_v",
        "fine_segm_confidence",
        "coarse_segm_confidence",
    ]
    # every known confidence key is present in the result (None when absent)
    confidence_results = {key: None for key in confidence_names}
    n_part_channels = predictor_output.u.size(1)

    for key in confidence_names:
        confidence = getattr(predictor_output, key)
        if confidence is None:
            continue
        resampled_confidence = F.interpolate(
            confidence,
            (h, w),
            mode="bilinear",
            align_corners=False,
        )
        # The "is this confidence part-based?" check is loop-invariant per
        # key — decide once instead of re-testing inside the part loop
        # (and avoid allocating a base tensor that would be discarded).
        if resampled_confidence.size(1) == n_part_channels:
            # part-based confidence: gather values channel-by-channel
            # according to the part labels (part 0 stays zero)
            result = torch.zeros(
                [h, w], dtype=torch.float32, device=predictor_output.u.device
            )
            for part_id in range(1, n_part_channels):
                result[labels == part_id] = resampled_confidence[0, part_id][labels == part_id]
        else:
            # confidence is not part-based, fill the data with the first channel
            # (targeted for segmentation confidences that have only 1 channel)
            result = resampled_confidence[0, 0]

        confidence_results[key] = result

    return confidence_results  # pyre-ignore[7]
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def densepose_chart_predictor_output_to_result_with_confidences(
    predictor_output: DensePoseChartPredictorOutput, boxes: Boxes
) -> DensePoseChartResultWithConfidences:
    """
    Convert a single chart predictor output (with confidences) into a
    DensePoseChartResultWithConfidences.

    Args:
        predictor_output (DensePoseChartPredictorOutput): predictor output
            with confidences; must hold exactly one instance
        boxes (Boxes): the matching bounding box; must hold exactly one box
    Return:
        DensePoseChartResultWithConfidences with labels, U/V coordinates
        and resampled confidence maps
    """
    assert len(predictor_output) == 1 and len(boxes) == 1, (
        f"Predictor output to result conversion can operate only single outputs"
        f", got {len(predictor_output)} predictor outputs and {len(boxes)} boxes"
    )

    # Convert the (single) box to integer XYWH coordinates.
    box_xyxy = boxes.tensor.clone()
    box_xywh_all = BoxMode.convert(box_xyxy, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
    box_xywh = make_int_box(box_xywh_all[0])

    part_labels = resample_fine_and_coarse_segm_to_bbox(predictor_output, box_xywh).squeeze(0)
    uv = resample_uv_to_bbox(predictor_output, part_labels, box_xywh)
    confidences = resample_confidences_to_bbox(predictor_output, part_labels, box_xywh)
    return DensePoseChartResultWithConfidences(labels=part_labels, uv=uv, **confidences)
|
CatVTON/densepose/converters/hflip.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from .base import BaseConverter
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HFlipConverter(BaseConverter):
    """
    Applies a horizontal flip to DensePose predictor outputs.
    Each predictor output type registers its own flipping strategy.
    """

    registry = {}
    dst_type = None

    @classmethod
    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
    # inconsistently.
    def convert(cls, predictor_outputs: Any, transform_data: Any, *args, **kwargs):
        """
        Horizontally flip the given predictor outputs using the converter
        registered for their type. Lookup walks base classes, so derived
        types need no explicit registration.

        Args:
            predictor_outputs: DensePose predictor output to flip
            transform_data: auxiliary data needed to perform the flip
        Return:
            An instance of the same type as predictor_outputs
        """
        return super().convert(predictor_outputs, transform_data, *args, **kwargs)
|
CatVTON/densepose/converters/segm_to_mask.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from detectron2.structures import BitMasks, Boxes, BoxMode
|
| 10 |
+
|
| 11 |
+
from .base import IntTupleBox, make_int_box
|
| 12 |
+
from .to_mask import ImageSizeType
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def resample_coarse_segm_tensor_to_bbox(coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox):
    """
    Rescale a coarse segmentation tensor to a bounding box and derive
    per-pixel labels by taking the argmax over channels.

    Args:
        coarse_segm: float tensor of shape [1, K, Hout, Wout]
        box_xywh_abs (tuple of 4 int): box as upper-left corner, width, height
    Return:
        long tensor of size [1, H, W] with a label per box pixel
    """
    _, _, w, h = box_xywh_abs
    # guard against degenerate boxes
    w = max(int(w), 1)
    h = max(int(h), 1)
    resampled = F.interpolate(coarse_segm, (h, w), mode="bilinear", align_corners=False)
    return resampled.argmax(dim=1)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def resample_fine_and_coarse_segm_tensors_to_bbox(
    fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
):
    """
    Rescale fine and coarse segmentation tensors to a bounding box and
    combine them into per-pixel labels.

    Fine labels are kept only where the coarse segmentation predicts
    foreground (coarse argmax > 0); elsewhere the label is 0.

    Args:
        fine_segm: float tensor of shape [1, C, Hout, Wout]
        coarse_segm: float tensor of shape [1, K, Hout, Wout]
        box_xywh_abs (tuple of 4 int): box as upper-left corner, width, height
    Return:
        long tensor of size [1, H, W] with a label per box pixel
    """
    _, _, w, h = box_xywh_abs
    # guard against degenerate boxes
    w = max(int(w), 1)
    h = max(int(h), 1)
    # foreground mask from the coarse segmentation
    coarse_labels = F.interpolate(
        coarse_segm, (h, w), mode="bilinear", align_corners=False
    ).argmax(dim=1)
    # fine labels, masked to the coarse foreground
    fine_labels = F.interpolate(
        fine_segm, (h, w), mode="bilinear", align_corners=False
    ).argmax(dim=1)
    return fine_labels * (coarse_labels > 0).long()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def resample_fine_and_coarse_segm_to_bbox(predictor_output: Any, box_xywh_abs: IntTupleBox):
    """
    Derive per-pixel labels for a bounding box from a predictor output.

    Thin wrapper around `resample_fine_and_coarse_segm_tensors_to_bbox`
    that extracts the ``fine_segm`` and ``coarse_segm`` tensors.

    Args:
        predictor_output: DensePose predictor output holding segmentation
            results to resample
        box_xywh_abs (tuple of 4 int): box as upper-left corner, width, height
    Return:
        long tensor of size [1, H, W] with a label per box pixel
    """
    return resample_fine_and_coarse_segm_tensors_to_bbox(
        predictor_output.fine_segm, predictor_output.coarse_segm, box_xywh_abs
    )
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def predictor_output_with_coarse_segm_to_mask(
    predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
) -> BitMasks:
    """
    Build image-sized instance masks from coarse segmentation only.

    Assumes the predictor output exposes:
     - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
       unnormalized scores for N instances

    Args:
        predictor_output: DensePose predictor output to convert to masks
        boxes (Boxes): bounding boxes matching the predictor outputs
        image_size_hw (tuple [int, int]): image height Himg and width Wimg
    Return:
        BitMasks with a bool tensor of size [N, Himg, Wimg], one
        image-sized mask per instance
    """
    H, W = image_size_hw
    boxes_xywh_abs = BoxMode.convert(
        boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
    )
    n_instances = len(boxes_xywh_abs)
    masks = torch.zeros((n_instances, H, W), dtype=torch.bool, device=boxes.tensor.device)
    for i in range(n_instances):
        box_xywh = make_int_box(boxes_xywh_abs[i])
        # labels within the box, pasted into the image-sized mask
        box_mask = resample_coarse_segm_tensor_to_bbox(predictor_output[i].coarse_segm, box_xywh)
        x, y, w, h = box_xywh
        masks[i, y : y + h, x : x + w] = box_mask
    return BitMasks(masks)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def predictor_output_with_fine_and_coarse_segm_to_mask(
    predictor_output: Any, boxes: Boxes, image_size_hw: ImageSizeType
) -> BitMasks:
    """
    Build image-sized instance masks from combined coarse and fine
    segmentation.

    Assumes the predictor output exposes:
     - coarse_segm (tensor of size [N, D, H, W]): coarse segmentation
       unnormalized scores for N instances
     - fine_segm (tensor of size [N, C, H, W]): fine segmentation
       unnormalized scores for N instances

    Args:
        predictor_output: DensePose predictor output to convert to masks
        boxes (Boxes): bounding boxes matching the predictor outputs
        image_size_hw (tuple [int, int]): image height Himg and width Wimg
    Return:
        BitMasks with a bool tensor of size [N, Himg, Wimg], one
        image-sized mask per instance
    """
    H, W = image_size_hw
    boxes_xywh_abs = BoxMode.convert(
        boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
    )
    n_instances = len(boxes_xywh_abs)
    masks = torch.zeros((n_instances, H, W), dtype=torch.bool, device=boxes.tensor.device)
    for i in range(n_instances):
        box_xywh = make_int_box(boxes_xywh_abs[i])
        # non-zero combined labels mark the foreground inside the box
        labels_i = resample_fine_and_coarse_segm_to_bbox(predictor_output[i], box_xywh)
        x, y, w, h = box_xywh
        masks[i, y : y + h, x : x + w] = labels_i > 0
    return BitMasks(masks)
|
CatVTON/densepose/converters/to_chart_result.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
from detectron2.structures import Boxes
|
| 8 |
+
|
| 9 |
+
from ..structures import DensePoseChartResult, DensePoseChartResultWithConfidences
|
| 10 |
+
from .base import BaseConverter
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ToChartResultConverter(BaseConverter):
    """
    Converts DensePose predictor outputs to chart-based DensePose results.
    Each predictor output type registers its own conversion strategy.
    """

    registry = {}
    dst_type = DensePoseChartResult

    @classmethod
    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
    # inconsistently.
    def convert(cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs) -> DensePoseChartResult:
        """
        Convert the given predictor outputs to a DensePose result via the
        converter registered for their type. Lookup walks base classes, so
        derived types need no explicit registration.

        Args:
            predictor_outputs: DensePose predictor output to convert
            boxes (Boxes): bounding boxes matching the predictor outputs
        Return:
            A DensePose result instance
        Raises:
            KeyError: if no suitable converter was found
        """
        return super().convert(predictor_outputs, boxes, *args, **kwargs)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ToChartResultConverterWithConfidences(BaseConverter):
    """
    Converts DensePose predictor outputs to chart-based DensePose results
    with confidences. Each predictor output type registers its own
    conversion strategy.
    """

    registry = {}
    dst_type = DensePoseChartResultWithConfidences

    @classmethod
    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
    # inconsistently.
    def convert(
        cls, predictor_outputs: Any, boxes: Boxes, *args, **kwargs
    ) -> DensePoseChartResultWithConfidences:
        """
        Convert the given predictor outputs (with confidences) to a
        DensePose result via the converter registered for their type.
        Lookup walks base classes, so derived types need no explicit
        registration.

        Args:
            predictor_outputs: DensePose predictor output with confidences
            boxes (Boxes): bounding boxes matching the predictor outputs
        Return:
            A DensePose result instance with confidences
        Raises:
            KeyError: if no suitable converter was found
        """
        return super().convert(predictor_outputs, boxes, *args, **kwargs)
|
CatVTON/densepose/converters/to_mask.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from typing import Any, Tuple
|
| 6 |
+
|
| 7 |
+
from detectron2.structures import BitMasks, Boxes
|
| 8 |
+
|
| 9 |
+
from .base import BaseConverter
|
| 10 |
+
|
| 11 |
+
ImageSizeType = Tuple[int, int]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ToMaskConverter(BaseConverter):
    """
    Converts DensePose predictor outputs to bit-mask format (`BitMasks`).
    Each predictor output type registers its own conversion strategy.
    """

    registry = {}
    dst_type = BitMasks

    @classmethod
    # pyre-fixme[14]: `convert` overrides method defined in `BaseConverter`
    # inconsistently.
    def convert(
        cls,
        densepose_predictor_outputs: Any,
        boxes: Boxes,
        image_size_hw: ImageSizeType,
        *args,
        **kwargs
    ) -> BitMasks:
        """
        Convert the given predictor outputs to BitMasks via the converter
        registered for their type. Lookup walks base classes, so derived
        types need no explicit registration.

        Args:
            densepose_predictor_outputs: DensePose predictor output to
                convert to BitMasks
            boxes (Boxes): bounding boxes matching the predictor outputs
            image_size_hw (tuple [int, int]): image height and width
        Return:
            An instance of `BitMasks`
        Raises:
            KeyError: if no suitable converter was found
        """
        return super().convert(
            densepose_predictor_outputs, boxes, image_size_hw, *args, **kwargs
        )
|
CatVTON/densepose/data/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from .meshes import builtin
|
| 6 |
+
from .build import (
|
| 7 |
+
build_detection_test_loader,
|
| 8 |
+
build_detection_train_loader,
|
| 9 |
+
build_combined_loader,
|
| 10 |
+
build_frame_selector,
|
| 11 |
+
build_inference_based_loaders,
|
| 12 |
+
has_inference_based_loaders,
|
| 13 |
+
BootstrapDatasetFactoryCatalog,
|
| 14 |
+
)
|
| 15 |
+
from .combined_loader import CombinedDataLoader
|
| 16 |
+
from .dataset_mapper import DatasetMapper
|
| 17 |
+
from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
|
| 18 |
+
from .image_list_dataset import ImageListDataset
|
| 19 |
+
from .utils import is_relative_local_path, maybe_prepend_base_path
|
| 20 |
+
|
| 21 |
+
# ensure the builtin datasets are registered
|
| 22 |
+
from . import datasets
|
| 23 |
+
|
| 24 |
+
# ensure the bootstrap datasets builders are registered
|
| 25 |
+
from . import build
|
| 26 |
+
|
| 27 |
+
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
CatVTON/densepose/data/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.44 kB). View file
|
|
|
CatVTON/densepose/data/__pycache__/build.cpython-311.pyc
ADDED
|
Binary file (39.7 kB). View file
|
|
|
CatVTON/densepose/data/__pycache__/combined_loader.cpython-311.pyc
ADDED
|
Binary file (3.02 kB). View file
|
|
|
CatVTON/densepose/data/__pycache__/dataset_mapper.cpython-311.pyc
ADDED
|
Binary file (9.5 kB). View file
|
|
|
CatVTON/densepose/data/__pycache__/image_list_dataset.cpython-311.pyc
ADDED
|
Binary file (4.03 kB). View file
|
|
|
CatVTON/densepose/data/__pycache__/inference_based_loader.cpython-311.pyc
ADDED
|
Binary file (9.08 kB). View file
|
|
|
CatVTON/densepose/data/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (2.41 kB). View file
|
|
|
CatVTON/densepose/data/build.py
ADDED
|
@@ -0,0 +1,738 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import itertools
|
| 6 |
+
import logging
|
| 7 |
+
import numpy as np
|
| 8 |
+
from collections import UserDict, defaultdict
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Sequence, Tuple
|
| 11 |
+
import torch
|
| 12 |
+
from torch.utils.data.dataset import Dataset
|
| 13 |
+
|
| 14 |
+
from detectron2.config import CfgNode
|
| 15 |
+
from detectron2.data.build import build_detection_test_loader as d2_build_detection_test_loader
|
| 16 |
+
from detectron2.data.build import build_detection_train_loader as d2_build_detection_train_loader
|
| 17 |
+
from detectron2.data.build import (
|
| 18 |
+
load_proposals_into_dataset,
|
| 19 |
+
print_instances_class_histogram,
|
| 20 |
+
trivial_batch_collator,
|
| 21 |
+
worker_init_reset_seed,
|
| 22 |
+
)
|
| 23 |
+
from detectron2.data.catalog import DatasetCatalog, Metadata, MetadataCatalog
|
| 24 |
+
from detectron2.data.samplers import TrainingSampler
|
| 25 |
+
from detectron2.utils.comm import get_world_size
|
| 26 |
+
|
| 27 |
+
from densepose.config import get_bootstrap_dataset_config
|
| 28 |
+
from densepose.modeling import build_densepose_embedder
|
| 29 |
+
|
| 30 |
+
from .combined_loader import CombinedDataLoader, Loader
|
| 31 |
+
from .dataset_mapper import DatasetMapper
|
| 32 |
+
from .datasets.coco import DENSEPOSE_CSE_KEYS_WITHOUT_MASK, DENSEPOSE_IUV_KEYS_WITHOUT_MASK
|
| 33 |
+
from .datasets.dataset_type import DatasetType
|
| 34 |
+
from .inference_based_loader import InferenceBasedLoader, ScoreBasedFilter
|
| 35 |
+
from .samplers import (
|
| 36 |
+
DensePoseConfidenceBasedSampler,
|
| 37 |
+
DensePoseCSEConfidenceBasedSampler,
|
| 38 |
+
DensePoseCSEUniformSampler,
|
| 39 |
+
DensePoseUniformSampler,
|
| 40 |
+
MaskFromDensePoseSampler,
|
| 41 |
+
PredictionToGroundTruthSampler,
|
| 42 |
+
)
|
| 43 |
+
from .transform import ImageResizeTransform
|
| 44 |
+
from .utils import get_category_to_class_mapping, get_class_to_mesh_name_mapping
|
| 45 |
+
from .video import (
|
| 46 |
+
FirstKFramesSelector,
|
| 47 |
+
FrameSelectionStrategy,
|
| 48 |
+
LastKFramesSelector,
|
| 49 |
+
RandomKFramesSelector,
|
| 50 |
+
VideoKeyframeDataset,
|
| 51 |
+
video_list_from_file,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
__all__ = ["build_detection_train_loader", "build_detection_test_loader"]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
Instance = Dict[str, Any]
|
| 58 |
+
InstancePredicate = Callable[[Instance], bool]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _compute_num_images_per_worker(cfg: CfgNode) -> int:
|
| 62 |
+
num_workers = get_world_size()
|
| 63 |
+
images_per_batch = cfg.SOLVER.IMS_PER_BATCH
|
| 64 |
+
assert (
|
| 65 |
+
images_per_batch % num_workers == 0
|
| 66 |
+
), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
|
| 67 |
+
images_per_batch, num_workers
|
| 68 |
+
)
|
| 69 |
+
assert (
|
| 70 |
+
images_per_batch >= num_workers
|
| 71 |
+
), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
|
| 72 |
+
images_per_batch, num_workers
|
| 73 |
+
)
|
| 74 |
+
images_per_worker = images_per_batch // num_workers
|
| 75 |
+
return images_per_worker
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _map_category_id_to_contiguous_id(dataset_name: str, dataset_dicts: Iterable[Instance]) -> None:
|
| 79 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 80 |
+
for dataset_dict in dataset_dicts:
|
| 81 |
+
for ann in dataset_dict["annotations"]:
|
| 82 |
+
ann["category_id"] = meta.thing_dataset_id_to_contiguous_id[ann["category_id"]]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@dataclass
|
| 86 |
+
class _DatasetCategory:
|
| 87 |
+
"""
|
| 88 |
+
Class representing category data in a dataset:
|
| 89 |
+
- id: category ID, as specified in the dataset annotations file
|
| 90 |
+
- name: category name, as specified in the dataset annotations file
|
| 91 |
+
- mapped_id: category ID after applying category maps (DATASETS.CATEGORY_MAPS config option)
|
| 92 |
+
- mapped_name: category name after applying category maps
|
| 93 |
+
- dataset_name: dataset in which the category is defined
|
| 94 |
+
|
| 95 |
+
For example, when training models in a class-agnostic manner, one could take LVIS 1.0
|
| 96 |
+
dataset and map the animal categories to the same category as human data from COCO:
|
| 97 |
+
id = 225
|
| 98 |
+
name = "cat"
|
| 99 |
+
mapped_id = 1
|
| 100 |
+
mapped_name = "person"
|
| 101 |
+
dataset_name = "lvis_v1_animals_dp_train"
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
id: int
|
| 105 |
+
name: str
|
| 106 |
+
mapped_id: int
|
| 107 |
+
mapped_name: str
|
| 108 |
+
dataset_name: str
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
_MergedCategoriesT = Dict[int, List[_DatasetCategory]]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _add_category_id_to_contiguous_id_maps_to_metadata(
|
| 115 |
+
merged_categories: _MergedCategoriesT,
|
| 116 |
+
) -> None:
|
| 117 |
+
merged_categories_per_dataset = {}
|
| 118 |
+
for contiguous_cat_id, cat_id in enumerate(sorted(merged_categories.keys())):
|
| 119 |
+
for cat in merged_categories[cat_id]:
|
| 120 |
+
if cat.dataset_name not in merged_categories_per_dataset:
|
| 121 |
+
merged_categories_per_dataset[cat.dataset_name] = defaultdict(list)
|
| 122 |
+
merged_categories_per_dataset[cat.dataset_name][cat_id].append(
|
| 123 |
+
(
|
| 124 |
+
contiguous_cat_id,
|
| 125 |
+
cat,
|
| 126 |
+
)
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
logger = logging.getLogger(__name__)
|
| 130 |
+
for dataset_name, merged_categories in merged_categories_per_dataset.items():
|
| 131 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 132 |
+
if not hasattr(meta, "thing_classes"):
|
| 133 |
+
meta.thing_classes = []
|
| 134 |
+
meta.thing_dataset_id_to_contiguous_id = {}
|
| 135 |
+
meta.thing_dataset_id_to_merged_id = {}
|
| 136 |
+
else:
|
| 137 |
+
meta.thing_classes.clear()
|
| 138 |
+
meta.thing_dataset_id_to_contiguous_id.clear()
|
| 139 |
+
meta.thing_dataset_id_to_merged_id.clear()
|
| 140 |
+
logger.info(f"Dataset {dataset_name}: category ID to contiguous ID mapping:")
|
| 141 |
+
for _cat_id, categories in sorted(merged_categories.items()):
|
| 142 |
+
added_to_thing_classes = False
|
| 143 |
+
for contiguous_cat_id, cat in categories:
|
| 144 |
+
if not added_to_thing_classes:
|
| 145 |
+
meta.thing_classes.append(cat.mapped_name)
|
| 146 |
+
added_to_thing_classes = True
|
| 147 |
+
meta.thing_dataset_id_to_contiguous_id[cat.id] = contiguous_cat_id
|
| 148 |
+
meta.thing_dataset_id_to_merged_id[cat.id] = cat.mapped_id
|
| 149 |
+
logger.info(f"{cat.id} ({cat.name}) -> {contiguous_cat_id}")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _maybe_create_general_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
| 153 |
+
def has_annotations(instance: Instance) -> bool:
|
| 154 |
+
return "annotations" in instance
|
| 155 |
+
|
| 156 |
+
def has_only_crowd_anotations(instance: Instance) -> bool:
|
| 157 |
+
for ann in instance["annotations"]:
|
| 158 |
+
if ann.get("is_crowd", 0) == 0:
|
| 159 |
+
return False
|
| 160 |
+
return True
|
| 161 |
+
|
| 162 |
+
def general_keep_instance_predicate(instance: Instance) -> bool:
|
| 163 |
+
return has_annotations(instance) and not has_only_crowd_anotations(instance)
|
| 164 |
+
|
| 165 |
+
if not cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS:
|
| 166 |
+
return None
|
| 167 |
+
return general_keep_instance_predicate
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _maybe_create_keypoints_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
| 171 |
+
|
| 172 |
+
min_num_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
|
| 173 |
+
|
| 174 |
+
def has_sufficient_num_keypoints(instance: Instance) -> bool:
|
| 175 |
+
num_kpts = sum(
|
| 176 |
+
(np.array(ann["keypoints"][2::3]) > 0).sum()
|
| 177 |
+
for ann in instance["annotations"]
|
| 178 |
+
if "keypoints" in ann
|
| 179 |
+
)
|
| 180 |
+
return num_kpts >= min_num_keypoints
|
| 181 |
+
|
| 182 |
+
if cfg.MODEL.KEYPOINT_ON and (min_num_keypoints > 0):
|
| 183 |
+
return has_sufficient_num_keypoints
|
| 184 |
+
return None
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _maybe_create_mask_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
| 188 |
+
if not cfg.MODEL.MASK_ON:
|
| 189 |
+
return None
|
| 190 |
+
|
| 191 |
+
def has_mask_annotations(instance: Instance) -> bool:
|
| 192 |
+
return any("segmentation" in ann for ann in instance["annotations"])
|
| 193 |
+
|
| 194 |
+
return has_mask_annotations
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def _maybe_create_densepose_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
| 198 |
+
if not cfg.MODEL.DENSEPOSE_ON:
|
| 199 |
+
return None
|
| 200 |
+
|
| 201 |
+
use_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS
|
| 202 |
+
|
| 203 |
+
def has_densepose_annotations(instance: Instance) -> bool:
|
| 204 |
+
for ann in instance["annotations"]:
|
| 205 |
+
if all(key in ann for key in DENSEPOSE_IUV_KEYS_WITHOUT_MASK) or all(
|
| 206 |
+
key in ann for key in DENSEPOSE_CSE_KEYS_WITHOUT_MASK
|
| 207 |
+
):
|
| 208 |
+
return True
|
| 209 |
+
if use_masks and "segmentation" in ann:
|
| 210 |
+
return True
|
| 211 |
+
return False
|
| 212 |
+
|
| 213 |
+
return has_densepose_annotations
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _maybe_create_specific_keep_instance_predicate(cfg: CfgNode) -> Optional[InstancePredicate]:
|
| 217 |
+
specific_predicate_creators = [
|
| 218 |
+
_maybe_create_keypoints_keep_instance_predicate,
|
| 219 |
+
_maybe_create_mask_keep_instance_predicate,
|
| 220 |
+
_maybe_create_densepose_keep_instance_predicate,
|
| 221 |
+
]
|
| 222 |
+
predicates = [creator(cfg) for creator in specific_predicate_creators]
|
| 223 |
+
predicates = [p for p in predicates if p is not None]
|
| 224 |
+
if not predicates:
|
| 225 |
+
return None
|
| 226 |
+
|
| 227 |
+
def combined_predicate(instance: Instance) -> bool:
|
| 228 |
+
return any(p(instance) for p in predicates)
|
| 229 |
+
|
| 230 |
+
return combined_predicate
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _get_train_keep_instance_predicate(cfg: CfgNode):
|
| 234 |
+
general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
|
| 235 |
+
combined_specific_keep_predicate = _maybe_create_specific_keep_instance_predicate(cfg)
|
| 236 |
+
|
| 237 |
+
def combined_general_specific_keep_predicate(instance: Instance) -> bool:
|
| 238 |
+
return general_keep_predicate(instance) and combined_specific_keep_predicate(instance)
|
| 239 |
+
|
| 240 |
+
if (general_keep_predicate is None) and (combined_specific_keep_predicate is None):
|
| 241 |
+
return None
|
| 242 |
+
if general_keep_predicate is None:
|
| 243 |
+
return combined_specific_keep_predicate
|
| 244 |
+
if combined_specific_keep_predicate is None:
|
| 245 |
+
return general_keep_predicate
|
| 246 |
+
return combined_general_specific_keep_predicate
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def _get_test_keep_instance_predicate(cfg: CfgNode):
|
| 250 |
+
general_keep_predicate = _maybe_create_general_keep_instance_predicate(cfg)
|
| 251 |
+
return general_keep_predicate
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def _maybe_filter_and_map_categories(
|
| 255 |
+
dataset_name: str, dataset_dicts: List[Instance]
|
| 256 |
+
) -> List[Instance]:
|
| 257 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 258 |
+
category_id_map = meta.thing_dataset_id_to_contiguous_id
|
| 259 |
+
filtered_dataset_dicts = []
|
| 260 |
+
for dataset_dict in dataset_dicts:
|
| 261 |
+
anns = []
|
| 262 |
+
for ann in dataset_dict["annotations"]:
|
| 263 |
+
cat_id = ann["category_id"]
|
| 264 |
+
if cat_id not in category_id_map:
|
| 265 |
+
continue
|
| 266 |
+
ann["category_id"] = category_id_map[cat_id]
|
| 267 |
+
anns.append(ann)
|
| 268 |
+
dataset_dict["annotations"] = anns
|
| 269 |
+
filtered_dataset_dicts.append(dataset_dict)
|
| 270 |
+
return filtered_dataset_dicts
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def _add_category_whitelists_to_metadata(cfg: CfgNode) -> None:
|
| 274 |
+
for dataset_name, whitelisted_cat_ids in cfg.DATASETS.WHITELISTED_CATEGORIES.items():
|
| 275 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 276 |
+
meta.whitelisted_categories = whitelisted_cat_ids
|
| 277 |
+
logger = logging.getLogger(__name__)
|
| 278 |
+
logger.info(
|
| 279 |
+
"Whitelisted categories for dataset {}: {}".format(
|
| 280 |
+
dataset_name, meta.whitelisted_categories
|
| 281 |
+
)
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def _add_category_maps_to_metadata(cfg: CfgNode) -> None:
|
| 286 |
+
for dataset_name, category_map in cfg.DATASETS.CATEGORY_MAPS.items():
|
| 287 |
+
category_map = {
|
| 288 |
+
int(cat_id_src): int(cat_id_dst) for cat_id_src, cat_id_dst in category_map.items()
|
| 289 |
+
}
|
| 290 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 291 |
+
meta.category_map = category_map
|
| 292 |
+
logger = logging.getLogger(__name__)
|
| 293 |
+
logger.info("Category maps for dataset {}: {}".format(dataset_name, meta.category_map))
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _add_category_info_to_bootstrapping_metadata(dataset_name: str, dataset_cfg: CfgNode) -> None:
|
| 297 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 298 |
+
meta.category_to_class_mapping = get_category_to_class_mapping(dataset_cfg)
|
| 299 |
+
meta.categories = dataset_cfg.CATEGORIES
|
| 300 |
+
meta.max_count_per_category = dataset_cfg.MAX_COUNT_PER_CATEGORY
|
| 301 |
+
logger = logging.getLogger(__name__)
|
| 302 |
+
logger.info(
|
| 303 |
+
"Category to class mapping for dataset {}: {}".format(
|
| 304 |
+
dataset_name, meta.category_to_class_mapping
|
| 305 |
+
)
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _maybe_add_class_to_mesh_name_map_to_metadata(dataset_names: List[str], cfg: CfgNode) -> None:
|
| 310 |
+
for dataset_name in dataset_names:
|
| 311 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 312 |
+
if not hasattr(meta, "class_to_mesh_name"):
|
| 313 |
+
meta.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _merge_categories(dataset_names: Collection[str]) -> _MergedCategoriesT:
|
| 317 |
+
merged_categories = defaultdict(list)
|
| 318 |
+
category_names = {}
|
| 319 |
+
for dataset_name in dataset_names:
|
| 320 |
+
meta = MetadataCatalog.get(dataset_name)
|
| 321 |
+
whitelisted_categories = meta.get("whitelisted_categories")
|
| 322 |
+
category_map = meta.get("category_map", {})
|
| 323 |
+
cat_ids = (
|
| 324 |
+
whitelisted_categories if whitelisted_categories is not None else meta.categories.keys()
|
| 325 |
+
)
|
| 326 |
+
for cat_id in cat_ids:
|
| 327 |
+
cat_name = meta.categories[cat_id]
|
| 328 |
+
cat_id_mapped = category_map.get(cat_id, cat_id)
|
| 329 |
+
if cat_id_mapped == cat_id or cat_id_mapped in cat_ids:
|
| 330 |
+
category_names[cat_id] = cat_name
|
| 331 |
+
else:
|
| 332 |
+
category_names[cat_id] = str(cat_id_mapped)
|
| 333 |
+
# assign temporary mapped category name, this name can be changed
|
| 334 |
+
# during the second pass, since mapped ID can correspond to a category
|
| 335 |
+
# from a different dataset
|
| 336 |
+
cat_name_mapped = meta.categories[cat_id_mapped]
|
| 337 |
+
merged_categories[cat_id_mapped].append(
|
| 338 |
+
_DatasetCategory(
|
| 339 |
+
id=cat_id,
|
| 340 |
+
name=cat_name,
|
| 341 |
+
mapped_id=cat_id_mapped,
|
| 342 |
+
mapped_name=cat_name_mapped,
|
| 343 |
+
dataset_name=dataset_name,
|
| 344 |
+
)
|
| 345 |
+
)
|
| 346 |
+
# second pass to assign proper mapped category names
|
| 347 |
+
for cat_id, categories in merged_categories.items():
|
| 348 |
+
for cat in categories:
|
| 349 |
+
if cat_id in category_names and cat.mapped_name != category_names[cat_id]:
|
| 350 |
+
cat.mapped_name = category_names[cat_id]
|
| 351 |
+
|
| 352 |
+
return merged_categories
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def _warn_if_merged_different_categories(merged_categories: _MergedCategoriesT) -> None:
|
| 356 |
+
logger = logging.getLogger(__name__)
|
| 357 |
+
for cat_id in merged_categories:
|
| 358 |
+
merged_categories_i = merged_categories[cat_id]
|
| 359 |
+
first_cat_name = merged_categories_i[0].name
|
| 360 |
+
if len(merged_categories_i) > 1 and not all(
|
| 361 |
+
cat.name == first_cat_name for cat in merged_categories_i[1:]
|
| 362 |
+
):
|
| 363 |
+
cat_summary_str = ", ".join(
|
| 364 |
+
[f"{cat.id} ({cat.name}) from {cat.dataset_name}" for cat in merged_categories_i]
|
| 365 |
+
)
|
| 366 |
+
logger.warning(
|
| 367 |
+
f"Merged category {cat_id} corresponds to the following categories: "
|
| 368 |
+
f"{cat_summary_str}"
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def combine_detection_dataset_dicts(
|
| 373 |
+
dataset_names: Collection[str],
|
| 374 |
+
keep_instance_predicate: Optional[InstancePredicate] = None,
|
| 375 |
+
proposal_files: Optional[Collection[str]] = None,
|
| 376 |
+
) -> List[Instance]:
|
| 377 |
+
"""
|
| 378 |
+
Load and prepare dataset dicts for training / testing
|
| 379 |
+
|
| 380 |
+
Args:
|
| 381 |
+
dataset_names (Collection[str]): a list of dataset names
|
| 382 |
+
keep_instance_predicate (Callable: Dict[str, Any] -> bool): predicate
|
| 383 |
+
applied to instance dicts which defines whether to keep the instance
|
| 384 |
+
proposal_files (Collection[str]): if given, a list of object proposal files
|
| 385 |
+
that match each dataset in `dataset_names`.
|
| 386 |
+
"""
|
| 387 |
+
assert len(dataset_names)
|
| 388 |
+
if proposal_files is None:
|
| 389 |
+
proposal_files = [None] * len(dataset_names)
|
| 390 |
+
assert len(dataset_names) == len(proposal_files)
|
| 391 |
+
# load datasets and metadata
|
| 392 |
+
dataset_name_to_dicts = {}
|
| 393 |
+
for dataset_name in dataset_names:
|
| 394 |
+
dataset_name_to_dicts[dataset_name] = DatasetCatalog.get(dataset_name)
|
| 395 |
+
assert len(dataset_name_to_dicts), f"Dataset '{dataset_name}' is empty!"
|
| 396 |
+
# merge categories, requires category metadata to be loaded
|
| 397 |
+
# cat_id -> [(orig_cat_id, cat_name, dataset_name)]
|
| 398 |
+
merged_categories = _merge_categories(dataset_names)
|
| 399 |
+
_warn_if_merged_different_categories(merged_categories)
|
| 400 |
+
merged_category_names = [
|
| 401 |
+
merged_categories[cat_id][0].mapped_name for cat_id in sorted(merged_categories)
|
| 402 |
+
]
|
| 403 |
+
# map to contiguous category IDs
|
| 404 |
+
_add_category_id_to_contiguous_id_maps_to_metadata(merged_categories)
|
| 405 |
+
# load annotations and dataset metadata
|
| 406 |
+
for dataset_name, proposal_file in zip(dataset_names, proposal_files):
|
| 407 |
+
dataset_dicts = dataset_name_to_dicts[dataset_name]
|
| 408 |
+
assert len(dataset_dicts), f"Dataset '{dataset_name}' is empty!"
|
| 409 |
+
if proposal_file is not None:
|
| 410 |
+
dataset_dicts = load_proposals_into_dataset(dataset_dicts, proposal_file)
|
| 411 |
+
dataset_dicts = _maybe_filter_and_map_categories(dataset_name, dataset_dicts)
|
| 412 |
+
print_instances_class_histogram(dataset_dicts, merged_category_names)
|
| 413 |
+
dataset_name_to_dicts[dataset_name] = dataset_dicts
|
| 414 |
+
|
| 415 |
+
if keep_instance_predicate is not None:
|
| 416 |
+
all_datasets_dicts_plain = [
|
| 417 |
+
d
|
| 418 |
+
for d in itertools.chain.from_iterable(dataset_name_to_dicts.values())
|
| 419 |
+
if keep_instance_predicate(d)
|
| 420 |
+
]
|
| 421 |
+
else:
|
| 422 |
+
all_datasets_dicts_plain = list(
|
| 423 |
+
itertools.chain.from_iterable(dataset_name_to_dicts.values())
|
| 424 |
+
)
|
| 425 |
+
return all_datasets_dicts_plain
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def build_detection_train_loader(cfg: CfgNode, mapper=None):
|
| 429 |
+
"""
|
| 430 |
+
A data loader is created in a way similar to that of Detectron2.
|
| 431 |
+
The main differences are:
|
| 432 |
+
- it allows to combine datasets with different but compatible object category sets
|
| 433 |
+
|
| 434 |
+
The data loader is created by the following steps:
|
| 435 |
+
1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
|
| 436 |
+
2. Start workers to work on the dicts. Each worker will:
|
| 437 |
+
* Map each metadata dict into another format to be consumed by the model.
|
| 438 |
+
* Batch them by simply putting dicts into a list.
|
| 439 |
+
The batched ``list[mapped_dict]`` is what this dataloader will return.
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
cfg (CfgNode): the config
|
| 443 |
+
mapper (callable): a callable which takes a sample (dict) from dataset and
|
| 444 |
+
returns the format to be consumed by the model.
|
| 445 |
+
By default it will be `DatasetMapper(cfg, True)`.
|
| 446 |
+
|
| 447 |
+
Returns:
|
| 448 |
+
an infinite iterator of training data
|
| 449 |
+
"""
|
| 450 |
+
|
| 451 |
+
_add_category_whitelists_to_metadata(cfg)
|
| 452 |
+
_add_category_maps_to_metadata(cfg)
|
| 453 |
+
_maybe_add_class_to_mesh_name_map_to_metadata(cfg.DATASETS.TRAIN, cfg)
|
| 454 |
+
dataset_dicts = combine_detection_dataset_dicts(
|
| 455 |
+
cfg.DATASETS.TRAIN,
|
| 456 |
+
keep_instance_predicate=_get_train_keep_instance_predicate(cfg),
|
| 457 |
+
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
|
| 458 |
+
)
|
| 459 |
+
if mapper is None:
|
| 460 |
+
mapper = DatasetMapper(cfg, True)
|
| 461 |
+
return d2_build_detection_train_loader(cfg, dataset=dataset_dicts, mapper=mapper)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def build_detection_test_loader(cfg, dataset_name, mapper=None):
|
| 465 |
+
"""
|
| 466 |
+
Similar to `build_detection_train_loader`.
|
| 467 |
+
But this function uses the given `dataset_name` argument (instead of the names in cfg),
|
| 468 |
+
and uses batch size 1.
|
| 469 |
+
|
| 470 |
+
Args:
|
| 471 |
+
cfg: a detectron2 CfgNode
|
| 472 |
+
dataset_name (str): a name of the dataset that's available in the DatasetCatalog
|
| 473 |
+
mapper (callable): a callable which takes a sample (dict) from dataset
|
| 474 |
+
and returns the format to be consumed by the model.
|
| 475 |
+
By default it will be `DatasetMapper(cfg, False)`.
|
| 476 |
+
|
| 477 |
+
Returns:
|
| 478 |
+
DataLoader: a torch DataLoader, that loads the given detection
|
| 479 |
+
dataset, with test-time transformation and batching.
|
| 480 |
+
"""
|
| 481 |
+
_add_category_whitelists_to_metadata(cfg)
|
| 482 |
+
_add_category_maps_to_metadata(cfg)
|
| 483 |
+
_maybe_add_class_to_mesh_name_map_to_metadata([dataset_name], cfg)
|
| 484 |
+
dataset_dicts = combine_detection_dataset_dicts(
|
| 485 |
+
[dataset_name],
|
| 486 |
+
keep_instance_predicate=_get_test_keep_instance_predicate(cfg),
|
| 487 |
+
proposal_files=(
|
| 488 |
+
[cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]]
|
| 489 |
+
if cfg.MODEL.LOAD_PROPOSALS
|
| 490 |
+
else None
|
| 491 |
+
),
|
| 492 |
+
)
|
| 493 |
+
sampler = None
|
| 494 |
+
if not cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE:
|
| 495 |
+
sampler = torch.utils.data.SequentialSampler(dataset_dicts)
|
| 496 |
+
if mapper is None:
|
| 497 |
+
mapper = DatasetMapper(cfg, False)
|
| 498 |
+
return d2_build_detection_test_loader(
|
| 499 |
+
dataset_dicts, mapper=mapper, num_workers=cfg.DATALOADER.NUM_WORKERS, sampler=sampler
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def build_frame_selector(cfg: CfgNode):
|
| 504 |
+
strategy = FrameSelectionStrategy(cfg.STRATEGY)
|
| 505 |
+
if strategy == FrameSelectionStrategy.RANDOM_K:
|
| 506 |
+
frame_selector = RandomKFramesSelector(cfg.NUM_IMAGES)
|
| 507 |
+
elif strategy == FrameSelectionStrategy.FIRST_K:
|
| 508 |
+
frame_selector = FirstKFramesSelector(cfg.NUM_IMAGES)
|
| 509 |
+
elif strategy == FrameSelectionStrategy.LAST_K:
|
| 510 |
+
frame_selector = LastKFramesSelector(cfg.NUM_IMAGES)
|
| 511 |
+
elif strategy == FrameSelectionStrategy.ALL:
|
| 512 |
+
frame_selector = None
|
| 513 |
+
# pyre-fixme[61]: `frame_selector` may not be initialized here.
|
| 514 |
+
return frame_selector
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def build_transform(cfg: CfgNode, data_type: str):
|
| 518 |
+
if cfg.TYPE == "resize":
|
| 519 |
+
if data_type == "image":
|
| 520 |
+
return ImageResizeTransform(cfg.MIN_SIZE, cfg.MAX_SIZE)
|
| 521 |
+
raise ValueError(f"Unknown transform {cfg.TYPE} for data type {data_type}")
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def build_combined_loader(cfg: CfgNode, loaders: Collection[Loader], ratios: Sequence[float]):
    """
    Combine several data loaders into a single one that samples from them
    with the given ratios.
    """
    batch_size = _compute_num_images_per_worker(cfg)
    return CombinedDataLoader(loaders, batch_size, ratios)
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def build_bootstrap_dataset(dataset_name: str, cfg: CfgNode) -> Sequence[torch.Tensor]:
    """
    Build dataset that provides data to bootstrap on

    Args:
        dataset_name (str): Name of the dataset, needs to have associated metadata
            to load the data
        cfg (CfgNode): bootstrapping config
    Returns:
        Sequence[Tensor] - dataset that provides image batches, Tensors of size
        [N, C, H, W] of type float32, or None if no factory is registered for
        the dataset type (a warning is logged in that case)
    """
    logger = logging.getLogger(__name__)
    _add_category_info_to_bootstrapping_metadata(dataset_name, cfg)
    meta = MetadataCatalog.get(dataset_name)
    # Look up the factory registered for this dataset type, if any
    factory = BootstrapDatasetFactoryCatalog.get(meta.dataset_type)
    dataset = factory(meta, cfg) if factory is not None else None
    if dataset is None:
        logger.warning(f"Failed to create dataset {dataset_name} of type {meta.dataset_type}")
    return dataset
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def build_data_sampler(cfg: CfgNode, sampler_cfg: CfgNode, embedder: Optional[torch.nn.Module]):
    """
    Build a sampler that converts DensePose predictions into ground-truth data.

    Args:
        cfg (CfgNode): global config (required by the CSE sampler types)
        sampler_cfg (CfgNode): sampler config; TYPE selects the sampler and
            COUNT_PER_CLASS / USE_GROUND_TRUTH_CATEGORIES provide options
        embedder (Optional[torch.nn.Module]): CSE embedder, required for the
            "densepose_cse_*" sampler types
    Returns:
        A PredictionToGroundTruthSampler with a densepose sampler
        (pred_densepose -> gt_densepose) and a mask sampler
        (pred_densepose -> gt_masks) registered
    Raises:
        ValueError: if sampler_cfg.TYPE is not a recognized sampler type
    """
    # Confidence channel used by each chart-based confidence sampler type.
    # The six original branches were identical except for the densepose
    # sampler they constructed; the common registration code is factored out.
    confidence_channels = {
        "densepose_UV_confidence": "sigma_2",
        "densepose_fine_segm_confidence": "fine_segm_confidence",
        "densepose_coarse_segm_confidence": "coarse_segm_confidence",
    }
    sampler_type = sampler_cfg.TYPE
    if sampler_type == "densepose_uniform":
        densepose_sampler = DensePoseUniformSampler(count_per_class=sampler_cfg.COUNT_PER_CLASS)
    elif sampler_type in confidence_channels:
        densepose_sampler = DensePoseConfidenceBasedSampler(
            confidence_channel=confidence_channels[sampler_type],
            count_per_class=sampler_cfg.COUNT_PER_CLASS,
            search_proportion=0.5,
        )
    elif sampler_type == "densepose_cse_uniform":
        assert embedder is not None
        densepose_sampler = DensePoseCSEUniformSampler(
            cfg=cfg,
            use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
            embedder=embedder,
            count_per_class=sampler_cfg.COUNT_PER_CLASS,
        )
    elif sampler_type == "densepose_cse_coarse_segm_confidence":
        assert embedder is not None
        densepose_sampler = DensePoseCSEConfidenceBasedSampler(
            cfg=cfg,
            use_gt_categories=sampler_cfg.USE_GROUND_TRUTH_CATEGORIES,
            embedder=embedder,
            confidence_channel="coarse_segm_confidence",
            count_per_class=sampler_cfg.COUNT_PER_CLASS,
            search_proportion=0.5,
        )
    else:
        raise ValueError(f"Unknown data sampler type {sampler_cfg.TYPE}")

    data_sampler = PredictionToGroundTruthSampler()
    # transform densepose pred -> gt
    data_sampler.register_sampler("pred_densepose", "gt_densepose", densepose_sampler)
    data_sampler.register_sampler("pred_densepose", "gt_masks", MaskFromDensePoseSampler())
    return data_sampler
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
def build_data_filter(cfg: CfgNode):
    """
    Build a data filter from config.

    Args:
        cfg (CfgNode): filter config with TYPE and, for "detection_score",
            MIN_VALUE (minimum detection score to keep)
    Raises:
        ValueError: if cfg.TYPE is not a known filter type
    """
    if cfg.TYPE != "detection_score":
        raise ValueError(f"Unknown data filter type {cfg.TYPE}")
    return ScoreBasedFilter(min_score=cfg.MIN_VALUE)
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
def build_inference_based_loader(
    cfg: CfgNode,
    dataset_cfg: CfgNode,
    model: torch.nn.Module,
    embedder: Optional[torch.nn.Module] = None,
) -> InferenceBasedLoader:
    """
    Constructs data loader based on inference results of a model.

    The bootstrap dataset provides raw image batches; the returned loader
    runs `model` on them and converts its predictions into training data.
    """
    bootstrap_dataset = build_bootstrap_dataset(dataset_cfg.DATASET, dataset_cfg.IMAGE_LOADER)
    metadata = MetadataCatalog.get(dataset_cfg.DATASET)
    sampler = TrainingSampler(len(bootstrap_dataset))
    image_loader = torch.utils.data.DataLoader(
        bootstrap_dataset,  # pyre-ignore[6]
        batch_size=dataset_cfg.IMAGE_LOADER.BATCH_SIZE,
        sampler=sampler,
        num_workers=dataset_cfg.IMAGE_LOADER.NUM_WORKERS,
        collate_fn=trivial_batch_collator,
        worker_init_fn=worker_init_reset_seed,
    )
    return InferenceBasedLoader(
        model,
        data_loader=image_loader,
        data_sampler=build_data_sampler(cfg, dataset_cfg.DATA_SAMPLER, embedder),
        data_filter=build_data_filter(dataset_cfg.FILTER),
        shuffle=True,
        batch_size=dataset_cfg.INFERENCE.OUTPUT_BATCH_SIZE,
        inference_batch_size=dataset_cfg.INFERENCE.INPUT_BATCH_SIZE,
        category_to_class_mapping=metadata.category_to_class_mapping,
    )
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def has_inference_based_loaders(cfg: CfgNode) -> bool:
    """
    Returns True, if at least one inference-based loader must
    be instantiated for training
    """
    return bool(cfg.BOOTSTRAP_DATASETS)
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def build_inference_based_loaders(
    cfg: CfgNode, model: torch.nn.Module
) -> Tuple[List[InferenceBasedLoader], List[float]]:
    """
    Build one inference-based loader per entry of cfg.BOOTSTRAP_DATASETS.

    Returns:
        (loaders, ratios): the constructed loaders and their sampling ratios
        (each entry's RATIO field)
    """
    # The embedder is shared by all loaders and placed on the model's device
    embedder = build_densepose_embedder(cfg).to(device=model.device)  # pyre-ignore[16]
    loaders: List[InferenceBasedLoader] = []
    ratios: List[float] = []
    for dataset_spec in cfg.BOOTSTRAP_DATASETS:
        dataset_cfg = get_bootstrap_dataset_config().clone()
        dataset_cfg.merge_from_other_cfg(CfgNode(dataset_spec))
        loaders.append(build_inference_based_loader(cfg, dataset_cfg, model, embedder))
        ratios.append(dataset_cfg.RATIO)
    return loaders, ratios
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def build_video_list_dataset(meta: Metadata, cfg: CfgNode):
    """
    Build a dataset over the list of videos described by dataset metadata.

    Args:
        meta (Metadata): metadata with video_list_fpath, video_base_path and
            category attributes
        cfg (CfgNode): dataset config with TYPE and, for "video_keyframe",
            SELECT / TRANSFORM and optional KEYFRAME_HELPER options
    Returns:
        A VideoKeyframeDataset over the videos listed in the metadata
    Raises:
        ValueError: if cfg.TYPE is not a supported dataset type
    """
    video_list_fpath = meta.video_list_fpath
    video_base_path = meta.video_base_path
    category = meta.category
    if cfg.TYPE == "video_keyframe":
        frame_selector = build_frame_selector(cfg.SELECT)
        transform = build_transform(cfg.TRANSFORM, data_type="image")
        video_list = video_list_from_file(video_list_fpath, video_base_path)
        keyframe_helper_fpath = getattr(cfg, "KEYFRAME_HELPER", None)
        return VideoKeyframeDataset(
            video_list, category, frame_selector, transform, keyframe_helper_fpath
        )
    # Previously an unrecognized TYPE silently fell through and returned None;
    # fail loudly instead, consistent with build_transform / build_data_filter.
    raise ValueError(f"Unknown video list dataset type {cfg.TYPE}")
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
class _BootstrapDatasetFactoryCatalog(UserDict):
    """
    A global dictionary that stores information about bootstrapped datasets creation functions
    from metadata and config, for diverse DatasetType
    """

    def register(self, dataset_type: DatasetType, factory: Callable[[Metadata, CfgNode], Dataset]):
        """
        Args:
            dataset_type (DatasetType): a DatasetType e.g. DatasetType.VIDEO_LIST
            factory (Callable[Metadata, CfgNode]): a callable which takes Metadata and cfg
                arguments and returns a dataset object.
        Raises:
            ValueError: if a factory is already registered for `dataset_type`
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently allow duplicate registrations.
        if dataset_type in self:
            raise ValueError("Dataset '{}' is already registered!".format(dataset_type))
        self[dataset_type] = factory
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
# Global catalog instance mapping DatasetType -> dataset factory; the
# VIDEO_LIST factory is registered here at import time.
BootstrapDatasetFactoryCatalog = _BootstrapDatasetFactoryCatalog()
BootstrapDatasetFactoryCatalog.register(DatasetType.VIDEO_LIST, build_video_list_dataset)
|
CatVTON/densepose/data/combined_loader.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
from collections import deque
|
| 7 |
+
from typing import Any, Collection, Deque, Iterable, Iterator, List, Sequence
|
| 8 |
+
|
| 9 |
+
Loader = Iterable[Any]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _pooled_next(iterator: Iterator[Any], pool: Deque[Any]):
|
| 13 |
+
if not pool:
|
| 14 |
+
pool.extend(next(iterator))
|
| 15 |
+
return pool.popleft()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CombinedDataLoader:
    """
    Combines data loaders using the provided sampling ratios.

    Each underlying loader is expected to yield batches (iterables of
    elements); the combined loader yields mixed batches of `batch_size`
    elements, each element drawn from a loader chosen at random with
    probability proportional to its ratio. Iteration stops as soon as any
    chosen loader is exhausted.
    """

    # Number of batches worth of loader indices drawn at a time; any multiple
    # of batch_size works.
    BATCH_COUNT = 100

    def __init__(self, loaders: Collection[Iterable[Any]], batch_size: int, ratios: Sequence[float]):
        self.loaders = loaders
        self.batch_size = batch_size
        self.ratios = ratios

    def __iter__(self) -> Iterator[List[Any]]:
        iters = [iter(loader) for loader in self.loaders]
        indices: List[int] = []
        # BUGFIX: the previous implementation used `[deque()] * len(iters)`,
        # which aliases a SINGLE deque across all loaders, so buffered
        # elements from different loaders were mixed together. Each loader
        # now gets its own independent pool.
        pools = [deque() for _ in iters]

        def _next_pooled(i: int):
            # Refill pool i from loader i if empty, then take its oldest element
            if not pools[i]:
                pools[i].extend(next(iters[i]))
            return pools[i].popleft()

        # infinite iterator, as in D2
        while True:
            if not indices:
                # just a buffer of indices, its size doesn't matter
                # as long as it's a multiple of batch_size
                k = self.batch_size * self.BATCH_COUNT
                indices = random.choices(range(len(self.loaders)), self.ratios, k=k)
            try:
                batch = [_next_pooled(i) for i in indices[: self.batch_size]]
            except StopIteration:
                break
            indices = indices[self.batch_size :]
            yield batch
|
CatVTON/densepose/data/dataset_mapper.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
|
| 4 |
+
# pyre-unsafe
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Any, Dict, List, Tuple
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from detectron2.data import MetadataCatalog
|
| 12 |
+
from detectron2.data import detection_utils as utils
|
| 13 |
+
from detectron2.data import transforms as T
|
| 14 |
+
from detectron2.layers import ROIAlign
|
| 15 |
+
from detectron2.structures import BoxMode
|
| 16 |
+
from detectron2.utils.file_io import PathManager
|
| 17 |
+
|
| 18 |
+
from densepose.structures import DensePoseDataRelative, DensePoseList, DensePoseTransformData
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_augmentation(cfg, is_train):
    """
    Build the standard detectron2 augmentations and, during training, append
    a DensePose-specific random rotation drawn from cfg.INPUT.ROTATION_ANGLES.
    """
    logger = logging.getLogger(__name__)
    augmentations = utils.build_augmentation(cfg, is_train)
    if is_train:
        rotation = T.RandomRotation(
            cfg.INPUT.ROTATION_ANGLES, expand=False, sample_style="choice"
        )
        augmentations.append(rotation)
        logger.info("DensePose-specific augmentation used in training: " + str(rotation))
    return augmentations
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class DatasetMapper:
    """
    A customized version of `detectron2.data.DatasetMapper` that additionally
    loads DensePose annotations, applies the image transforms to them, and can
    derive instance segmentation masks from the DensePose segmentation.
    """

    def __init__(self, cfg, is_train=True):
        # Image-level augmentations (standard d2 ones + DensePose rotation)
        self.augmentation = build_augmentation(cfg, is_train)

        # fmt: off
        self.img_format = cfg.INPUT.FORMAT
        # Masks are needed either for the standard mask head, or when the
        # DensePose coarse segmentation is trained from mask annotations
        self.mask_on = (
            cfg.MODEL.MASK_ON or (
                cfg.MODEL.DENSEPOSE_ON
                and cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS)
        )
        self.keypoint_on = cfg.MODEL.KEYPOINT_ON
        self.densepose_on = cfg.MODEL.DENSEPOSE_ON
        assert not cfg.MODEL.LOAD_PROPOSALS, "not supported yet"
        # fmt: on
        if self.keypoint_on and is_train:
            # Flip only makes sense in training
            self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
        else:
            self.keypoint_hflip_indices = None

        if self.densepose_on:
            # Transform data (e.g. for horizontal flips) is taken from the
            # first dataset's metadata only
            densepose_transform_srcs = [
                MetadataCatalog.get(ds).densepose_transform_src
                for ds in cfg.DATASETS.TRAIN + cfg.DATASETS.TEST
            ]
            assert len(densepose_transform_srcs) > 0
            # TODO: check that DensePose transformation data is the same for
            # all the datasets. Otherwise one would have to pass DB ID with
            # each entry to select proper transformation data. For now, since
            # all DensePose annotated data uses the same data semantics, we
            # omit this check.
            densepose_transform_data_fpath = PathManager.get_local_path(densepose_transform_srcs[0])
            self.densepose_transform_data = DensePoseTransformData.load(
                densepose_transform_data_fpath
            )

        self.is_train = is_train

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        image, transforms = T.apply_transform_gens(self.augmentation, image)
        image_shape = image.shape[:2]  # h, w
        # HWC -> CHW float32 tensor expected by d2 models
        dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))

        if not self.is_train:
            # Inference: annotations are not needed
            dataset_dict.pop("annotations", None)
            return dataset_dict

        # Drop annotation fields for disabled heads before transforming
        for anno in dataset_dict["annotations"]:
            if not self.mask_on:
                anno.pop("segmentation", None)
            if not self.keypoint_on:
                anno.pop("keypoints", None)

        # USER: Implement additional transformations if you have other types of data
        # USER: Don't call transpose_densepose if you don't need
        annos = [
            self._transform_densepose(
                utils.transform_instance_annotations(
                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
                ),
                transforms,
            )
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]

        if self.mask_on:
            # Fill in missing "segmentation" fields from DensePose data
            self._add_densepose_masks_as_segmentation(annos, image_shape)

        instances = utils.annotations_to_instances(annos, image_shape, mask_format="bitmask")
        densepose_annotations = [obj.get("densepose") for obj in annos]
        # Attach DensePose GT only when at least one instance actually has it
        if densepose_annotations and not all(v is None for v in densepose_annotations):
            instances.gt_densepose = DensePoseList(
                densepose_annotations, instances.gt_boxes, image_shape
            )

        # Keep only instances with non-empty boxes
        dataset_dict["instances"] = instances[instances.gt_boxes.nonempty()]
        return dataset_dict

    def _transform_densepose(self, annotation, transforms):
        """
        Validate and transform the DensePose part of a single annotation.

        On success, stores a transformed DensePoseDataRelative under
        annotation["densepose"]; otherwise cleans up the raw DensePose keys
        and stores None.
        """
        if not self.densepose_on:
            return annotation

        # Handle densepose annotations
        is_valid, reason_not_valid = DensePoseDataRelative.validate_annotation(annotation)
        if is_valid:
            densepose_data = DensePoseDataRelative(annotation, cleanup=True)
            densepose_data.apply_transform(transforms, self.densepose_transform_data)
            annotation["densepose"] = densepose_data
        else:
            # logger = logging.getLogger(__name__)
            # logger.debug("Could not load DensePose annotation: {}".format(reason_not_valid))
            DensePoseDataRelative.cleanup_annotation(annotation)
            # NOTE: annotations for certain instances may be unavailable.
            # 'None' is accepted by the DensePostList data structure.
            annotation["densepose"] = None
        return annotation

    def _add_densepose_masks_as_segmentation(
        self, annotations: List[Dict[str, Any]], image_shape_hw: Tuple[int, int]
    ):
        """
        For each annotation that carries DensePose data but no "segmentation",
        rasterize the binarized DensePose coarse segmentation into an
        image-sized boolean mask stored under obj["segmentation"].
        """
        for obj in annotations:
            if ("densepose" not in obj) or ("segmentation" in obj):
                continue
            # DP segmentation: torch.Tensor [S, S] of float32, S=256
            segm_dp = torch.zeros_like(obj["densepose"].segm)
            segm_dp[obj["densepose"].segm > 0] = 1
            segm_h, segm_w = segm_dp.shape
            bbox_segm_dp = torch.tensor((0, 0, segm_h - 1, segm_w - 1), dtype=torch.float32)
            # image bbox
            # NOTE(review): the slicing below assumes the converted bbox
            # coordinates are integral — confirm against the annotation source
            x0, y0, x1, y1 = (
                v.item() for v in BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS)
            )
            # Resample the DP segmentation to the instance's bbox size
            segm_aligned = (
                ROIAlign((y1 - y0, x1 - x0), 1.0, 0, aligned=True)
                .forward(segm_dp.view(1, 1, *segm_dp.shape), bbox_segm_dp)
                .squeeze()
            )
            image_mask = torch.zeros(*image_shape_hw, dtype=torch.float32)
            image_mask[y0:y1, x0:x1] = segm_aligned
            # segmentation for BitMask: np.array [H, W] of bool
            obj["segmentation"] = image_mask >= 0.5
|
CatVTON/densepose/data/datasets/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
|
| 5 |
+
from . import builtin # ensure the builtin datasets are registered
|
| 6 |
+
|
| 7 |
+
# Public API: everything defined at module level except private names and the
# `builtin` module (imported above only for its registration side effects).
__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
|
CatVTON/densepose/data/datasets/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (571 Bytes). View file
|
|
|
CatVTON/densepose/data/datasets/__pycache__/builtin.cpython-311.pyc
ADDED
|
Binary file (805 Bytes). View file
|
|
|
CatVTON/densepose/data/datasets/__pycache__/chimpnsee.cpython-311.pyc
ADDED
|
Binary file (1.46 kB). View file
|
|
|
CatVTON/densepose/data/datasets/__pycache__/coco.cpython-311.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
CatVTON/densepose/data/datasets/__pycache__/dataset_type.cpython-311.pyc
ADDED
|
Binary file (601 Bytes). View file
|
|
|
CatVTON/densepose/data/datasets/__pycache__/lvis.cpython-311.pyc
ADDED
|
Binary file (12.5 kB). View file
|
|
|
CatVTON/densepose/data/datasets/builtin.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
# pyre-unsafe
|
| 4 |
+
from .chimpnsee import register_dataset as register_chimpnsee_dataset
|
| 5 |
+
from .coco import BASE_DATASETS as BASE_COCO_DATASETS
|
| 6 |
+
from .coco import DATASETS as COCO_DATASETS
|
| 7 |
+
from .coco import register_datasets as register_coco_datasets
|
| 8 |
+
from .lvis import DATASETS as LVIS_DATASETS
|
| 9 |
+
from .lvis import register_datasets as register_lvis_datasets
|
| 10 |
+
|
| 11 |
+
# Root directory under which all builtin dataset files are expected
DEFAULT_DATASETS_ROOT = "datasets"


# Register all builtin DensePose datasets (COCO, base COCO, LVIS, Chimp&See)
# at import time so they are available through the Detectron2 catalogs.
register_coco_datasets(COCO_DATASETS, DEFAULT_DATASETS_ROOT)
register_coco_datasets(BASE_COCO_DATASETS, DEFAULT_DATASETS_ROOT)
register_lvis_datasets(LVIS_DATASETS, DEFAULT_DATASETS_ROOT)

register_chimpnsee_dataset(DEFAULT_DATASETS_ROOT)  # pyre-ignore[19]
|