dqj5182 committed
Commit 5732928 · 1 Parent(s): 6db77af
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. LICENSE +399 -0
  2. README.md +201 -13
  3. data/MOW/__pycache__/dataset.cpython-38.pyc +0 -0
  4. data/MOW/dataset.py +156 -0
  5. data/dataset.py +40 -0
  6. demo.py +122 -0
  7. demo_video.py +132 -0
  8. lib/core/__pycache__/config.cpython-38.pyc +0 -0
  9. lib/core/__pycache__/logger.cpython-38.pyc +0 -0
  10. lib/core/config.py +93 -0
  11. lib/core/logger.py +55 -0
  12. lib/models/__pycache__/model.cpython-38.pyc +0 -0
  13. lib/models/backbone/__pycache__/backbone_hamer_style.cpython-38.pyc +0 -0
  14. lib/models/backbone/__pycache__/resnet.cpython-38.pyc +0 -0
  15. lib/models/backbone/__pycache__/vit.cpython-38.pyc +0 -0
  16. lib/models/backbone/backbone_hamer_style.py +273 -0
  17. lib/models/backbone/fpn.py +282 -0
  18. lib/models/backbone/hrnet.py +518 -0
  19. lib/models/backbone/resnet.py +95 -0
  20. lib/models/backbone/vit.py +33 -0
  21. lib/models/decoder/__pycache__/decoder_hamer_style.cpython-38.pyc +0 -0
  22. lib/models/decoder/decoder_hamer_style.py +637 -0
  23. lib/models/model.py +100 -0
  24. lib/utils/__pycache__/contact_utils.cpython-38.pyc +0 -0
  25. lib/utils/__pycache__/eval_utils.cpython-38.pyc +0 -0
  26. lib/utils/__pycache__/func_utils.cpython-38.pyc +0 -0
  27. lib/utils/__pycache__/human_models.cpython-38.pyc +0 -0
  28. lib/utils/__pycache__/log_utils.cpython-38.pyc +0 -0
  29. lib/utils/__pycache__/mano_utils.cpython-38.pyc +0 -0
  30. lib/utils/__pycache__/mesh_utils.cpython-38.pyc +0 -0
  31. lib/utils/__pycache__/preprocessing.cpython-38.pyc +0 -0
  32. lib/utils/__pycache__/train_utils.cpython-38.pyc +0 -0
  33. lib/utils/__pycache__/transforms.cpython-38.pyc +0 -0
  34. lib/utils/__pycache__/vis_utils.cpython-38.pyc +0 -0
  35. lib/utils/contact_utils.py +55 -0
  36. lib/utils/demo_utils.py +105 -0
  37. lib/utils/eval_utils.py +50 -0
  38. lib/utils/func_utils.py +65 -0
  39. lib/utils/human_models.py +49 -0
  40. lib/utils/log_utils.py +12 -0
  41. lib/utils/mano_utils.py +136 -0
  42. lib/utils/mesh_utils.py +74 -0
  43. lib/utils/preprocessing.py +330 -0
  44. lib/utils/smplx/LICENSE +58 -0
  45. lib/utils/smplx/README.md +186 -0
  46. lib/utils/smplx/examples/demo.py +180 -0
  47. lib/utils/smplx/examples/demo_layers.py +181 -0
  48. lib/utils/smplx/examples/vis_flame_vertices.py +92 -0
  49. lib/utils/smplx/examples/vis_mano_vertices.py +99 -0
  50. lib/utils/smplx/setup.py +79 -0
LICENSE ADDED
@@ -0,0 +1,399 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+ Section 1 -- Definitions.
71
+
72
+ a. Adapted Material means material subject to Copyright and Similar
73
+ Rights that is derived from or based upon the Licensed Material
74
+ and in which the Licensed Material is translated, altered,
75
+ arranged, transformed, or otherwise modified in a manner requiring
76
+ permission under the Copyright and Similar Rights held by the
77
+ Licensor. For purposes of this Public License, where the Licensed
78
+ Material is a musical work, performance, or sound recording,
79
+ Adapted Material is always produced where the Licensed Material is
80
+ synched in timed relation with a moving image.
81
+
82
+ b. Adapter's License means the license You apply to Your Copyright
83
+ and Similar Rights in Your contributions to Adapted Material in
84
+ accordance with the terms and conditions of this Public License.
85
+
86
+ c. Copyright and Similar Rights means copyright and/or similar rights
87
+ closely related to copyright including, without limitation,
88
+ performance, broadcast, sound recording, and Sui Generis Database
89
+ Rights, without regard to how the rights are labeled or
90
+ categorized. For purposes of this Public License, the rights
91
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
+ Rights.
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. NonCommercial means not primarily intended for or directed towards
116
+ commercial advantage or monetary compensation. For purposes of
117
+ this Public License, the exchange of the Licensed Material for
118
+ other material subject to Copyright and Similar Rights by digital
119
+ file-sharing or similar means is NonCommercial provided there is
120
+ no payment of monetary compensation in connection with the
121
+ exchange.
122
+
123
+ j. Share means to provide material to the public by any means or
124
+ process that requires permission under the Licensed Rights, such
125
+ as reproduction, public display, public performance, distribution,
126
+ dissemination, communication, or importation, and to make material
127
+ available to the public including in ways that members of the
128
+ public may access the material from a place and at a time
129
+ individually chosen by them.
130
+
131
+ k. Sui Generis Database Rights means rights other than copyright
132
+ resulting from Directive 96/9/EC of the European Parliament and of
133
+ the Council of 11 March 1996 on the legal protection of databases,
134
+ as amended and/or succeeded, as well as other essentially
135
+ equivalent rights anywhere in the world.
136
+
137
+ l. You means the individual or entity exercising the Licensed Rights
138
+ under this Public License. Your has a corresponding meaning.
139
+
140
+ Section 2 -- Scope.
141
+
142
+ a. License grant.
143
+
144
+ 1. Subject to the terms and conditions of this Public License,
145
+ the Licensor hereby grants You a worldwide, royalty-free,
146
+ non-sublicensable, non-exclusive, irrevocable license to
147
+ exercise the Licensed Rights in the Licensed Material to:
148
+
149
+ a. reproduce and Share the Licensed Material, in whole or
150
+ in part, for NonCommercial purposes only; and
151
+
152
+ b. produce, reproduce, and Share Adapted Material for
153
+ NonCommercial purposes only.
154
+
155
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
156
+ Exceptions and Limitations apply to Your use, this Public
157
+ License does not apply, and You do not need to comply with
158
+ its terms and conditions.
159
+
160
+ 3. Term. The term of this Public License is specified in Section
161
+ 6(a).
162
+
163
+ 4. Media and formats; technical modifications allowed. The
164
+ Licensor authorizes You to exercise the Licensed Rights in
165
+ all media and formats whether now known or hereafter created,
166
+ and to make technical modifications necessary to do so. The
167
+ Licensor waives and/or agrees not to assert any right or
168
+ authority to forbid You from making technical modifications
169
+ necessary to exercise the Licensed Rights, including
170
+ technical modifications necessary to circumvent Effective
171
+ Technological Measures. For purposes of this Public License,
172
+ simply making modifications authorized by this Section 2(a)
173
+ (4) never produces Adapted Material.
174
+
175
+ 5. Downstream recipients.
176
+
177
+ a. Offer from the Licensor -- Licensed Material. Every
178
+ recipient of the Licensed Material automatically
179
+ receives an offer from the Licensor to exercise the
180
+ Licensed Rights under the terms and conditions of this
181
+ Public License.
182
+
183
+ b. No downstream restrictions. You may not offer or impose
184
+ any additional or different terms or conditions on, or
185
+ apply any Effective Technological Measures to, the
186
+ Licensed Material if doing so restricts exercise of the
187
+ Licensed Rights by any recipient of the Licensed
188
+ Material.
189
+
190
+ 6. No endorsement. Nothing in this Public License constitutes or
191
+ may be construed as permission to assert or imply that You
192
+ are, or that Your use of the Licensed Material is, connected
193
+ with, or sponsored, endorsed, or granted official status by,
194
+ the Licensor or others designated to receive attribution as
195
+ provided in Section 3(a)(1)(A)(i).
196
+
197
+ b. Other rights.
198
+
199
+ 1. Moral rights, such as the right of integrity, are not
200
+ licensed under this Public License, nor are publicity,
201
+ privacy, and/or other similar personality rights; however, to
202
+ the extent possible, the Licensor waives and/or agrees not to
203
+ assert any such rights held by the Licensor to the limited
204
+ extent necessary to allow You to exercise the Licensed
205
+ Rights, but not otherwise.
206
+
207
+ 2. Patent and trademark rights are not licensed under this
208
+ Public License.
209
+
210
+ 3. To the extent possible, the Licensor waives any right to
211
+ collect royalties from You for the exercise of the Licensed
212
+ Rights, whether directly or through a collecting society
213
+ under any voluntary or waivable statutory or compulsory
214
+ licensing scheme. In all other cases the Licensor expressly
215
+ reserves any right to collect such royalties, including when
216
+ the Licensed Material is used other than for NonCommercial
217
+ purposes.
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material (including in modified
227
+ form), You must:
228
+
229
+ a. retain the following if it is supplied by the Licensor
230
+ with the Licensed Material:
231
+
232
+ i. identification of the creator(s) of the Licensed
233
+ Material and any others designated to receive
234
+ attribution, in any reasonable manner requested by
235
+ the Licensor (including by pseudonym if
236
+ designated);
237
+
238
+ ii. a copyright notice;
239
+
240
+ iii. a notice that refers to this Public License;
241
+
242
+ iv. a notice that refers to the disclaimer of
243
+ warranties;
244
+
245
+ v. a URI or hyperlink to the Licensed Material to the
246
+ extent reasonably practicable;
247
+
248
+ b. indicate if You modified the Licensed Material and
249
+ retain an indication of any previous modifications; and
250
+
251
+ c. indicate the Licensed Material is licensed under this
252
+ Public License, and include the text of, or the URI or
253
+ hyperlink to, this Public License.
254
+
255
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
256
+ reasonable manner based on the medium, means, and context in
257
+ which You Share the Licensed Material. For example, it may be
258
+ reasonable to satisfy the conditions by providing a URI or
259
+ hyperlink to a resource that includes the required
260
+ information.
261
+
262
+ 3. If requested by the Licensor, You must remove any of the
263
+ information required by Section 3(a)(1)(A) to the extent
264
+ reasonably practicable.
265
+
266
+ 4. If You Share Adapted Material You produce, the Adapter's
267
+ License You apply must not prevent recipients of the Adapted
268
+ Material from complying with this Public License.
269
+
270
+ Section 4 -- Sui Generis Database Rights.
271
+
272
+ Where the Licensed Rights include Sui Generis Database Rights that
273
+ apply to Your use of the Licensed Material:
274
+
275
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
+ to extract, reuse, reproduce, and Share all or a substantial
277
+ portion of the contents of the database for NonCommercial purposes
278
+ only;
279
+
280
+ b. if You include all or a substantial portion of the database
281
+ contents in a database in which You have Sui Generis Database
282
+ Rights, then the database in which You have Sui Generis Database
283
+ Rights (but not its individual contents) is Adapted Material; and
284
+
285
+ c. You must comply with the conditions in Section 3(a) if You Share
286
+ all or a substantial portion of the contents of the database.
287
+
288
+ For the avoidance of doubt, this Section 4 supplements and does not
289
+ replace Your obligations under this Public License where the Licensed
290
+ Rights include other Copyright and Similar Rights.
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+ Section 6 -- Term and Termination.
321
+
322
+ a. This Public License applies for the term of the Copyright and
323
+ Similar Rights licensed here. However, if You fail to comply with
324
+ this Public License, then Your rights under this Public License
325
+ terminate automatically.
326
+
327
+ b. Where Your right to use the Licensed Material has terminated under
328
+ Section 6(a), it reinstates:
329
+
330
+ 1. automatically as of the date the violation is cured, provided
331
+ it is cured within 30 days of Your discovery of the
332
+ violation; or
333
+
334
+ 2. upon express reinstatement by the Licensor.
335
+
336
+ For the avoidance of doubt, this Section 6(b) does not affect any
337
+ right the Licensor may have to seek remedies for Your violations
338
+ of this Public License.
339
+
340
+ c. For the avoidance of doubt, the Licensor may also offer the
341
+ Licensed Material under separate terms or conditions or stop
342
+ distributing the Licensed Material at any time; however, doing so
343
+ will not terminate this Public License.
344
+
345
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
+ License.
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+ Section 8 -- Interpretation.
358
+
359
+ a. For the avoidance of doubt, this Public License does not, and
360
+ shall not be interpreted to, reduce, limit, restrict, or impose
361
+ conditions on any use of the Licensed Material that could lawfully
362
+ be made without permission under this Public License.
363
+
364
+ b. To the extent possible, if any provision of this Public License is
365
+ deemed unenforceable, it shall be automatically reformed to the
366
+ minimum extent necessary to make it enforceable. If the provision
367
+ cannot be reformed, it shall be severed from this Public License
368
+ without affecting the enforceability of the remaining terms and
369
+ conditions.
370
+
371
+ c. No term or condition of this Public License will be waived and no
372
+ failure to comply consented to unless expressly agreed to by the
373
+ Licensor.
374
+
375
+ d. Nothing in this Public License constitutes or may be interpreted
376
+ as a limitation upon, or waiver of, any privileges and immunities
377
+ that apply to the Licensor or You, including from the legal
378
+ processes of any jurisdiction or authority.
379
+
380
+ =======================================================================
381
+
382
+ Creative Commons is not a party to its public
383
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
384
+ its public licenses to material it publishes and in those instances
385
+ will be considered the “Licensor.” The text of the Creative Commons
386
+ public licenses is dedicated to the public domain under the CC0 Public
387
+ Domain Dedication. Except for the limited purpose of indicating that
388
+ material is shared under a Creative Commons public license or as
389
+ otherwise permitted by the Creative Commons policies published at
390
+ creativecommons.org/policies, Creative Commons does not authorize the
391
+ use of the trademark "Creative Commons" or any other trademark or logo
392
+ of Creative Commons without its prior written consent including,
393
+ without limitation, in connection with any unauthorized modifications
394
+ to any of its public licenses or any other arrangements,
395
+ understandings, or agreements concerning use of licensed material. For
396
+ the avoidance of doubt, this paragraph does not form part of the
397
+ public licenses.
398
+
399
+ Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -1,13 +1,201 @@
1
- ---
2
- title: HACO
3
- emoji: 🦀
4
- colorFrom: yellow
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.29.1
8
- app_file: app.py
9
- pinned: false
10
- license: cc-by-nc-sa-4.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <div align="center">
2
+
3
+ # HACO: Learning Dense Hand Contact Estimation <br> from Imbalanced Data
4
+
5
+ <b>[Daniel Sungho Jung](https://dqj5182.github.io/)</b>, <b>[Kyoung Mu Lee](https://cv.snu.ac.kr/index.php/~kmlee/)</b>
6
+
7
+ <p align="center">
8
+ <img src="asset/logo_cvlab.png" height=55>
9
+ </p>
10
+
11
+ <b>Seoul National University</b>
12
+
13
+ <a>![Python 3.8+](https://img.shields.io/badge/Python-3.8%2B-brightgreen.svg)</a>
14
+ <a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
15
+ [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-CC%20BY--NC%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc/4.0/)
16
+ <a href='https://haco-release.github.io/'><img src='https://img.shields.io/badge/Project_Page-HACO-green' alt='Project Page'></a>
17
+ <a href="https://arxiv.org/pdf/2505.11152"><img src='https://img.shields.io/badge/Paper-HACO-blue' alt='Paper PDF'></a>
18
+ <a href="https://arxiv.org/abs/2505.11152"><img src='https://img.shields.io/badge/arXiv-HACO-red' alt='Paper PDF'></a>
19
+
20
+
21
+ <h2>ArXiv 2025</h2>
22
+
23
+ <img src="./asset/teaser.png" alt="Logo" width="75%">
24
+
25
+ </div>
26
+
27
+ _**HACO** is a framework for **dense hand contact estimation** that addresses **class and spatial imbalance issues** in training on large-scale datasets. Based on **14 datasets** that span **hand-object**, **hand-hand**, **hand-scene**, and **hand-body interaction**, we build a powerful model that learns dense hand contact in diverse scenarios._
28
+
29
+
30
+
31
+ ## Installation
32
+ * We recommend using an [Anaconda](https://www.anaconda.com/) virtual environment with Python >= 3.8.0 and PyTorch >= 1.11.0. Our latest HACO model is tested with Python 3.8.20, PyTorch 1.11.0, and CUDA 11.3.
33
+ * Set up the environment.
34
+ ```
35
+ # Initialize conda environment
36
+ conda create -n haco python=3.8 -y
37
+ conda activate haco
38
+
39
+ # Install PyTorch
40
+ conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
41
+
42
+ # Install all remaining packages
43
+ pip install -r requirements.txt
44
+ ```
45
+ * Download our checkpoints from [OneDrive](https://1drv.ms/u/c/bf7e2a9a100f1dba/Ef18aU5ItbFDgW1sSv3P0l0BGTzN6PlsCnm0q5ecpTWIfQ?e=Y40qsN).
46
+
47
+
48
+
49
+ ## Quick demo (Image)
50
+ To run HACO on demo images using the [Mediapipe](https://ai.google.dev/edge/mediapipe/solutions/guide) hand detector, please run:
51
+ ```
52
+ python demo.py --backbone {BACKBONE_TYPE} --checkpoint {CKPT_PATH} --input_path {INPUT_PATH}
53
+ ```
54
+
55
+ For example,
56
+ ```
57
+ # ViT-H (Default, HaMeR initialized) backbone
58
+ python demo.py --backbone hamer --checkpoint release_checkpoint/haco_final_hamer_checkpoint.ckpt --input_path asset/example_images
59
+
60
+ # ViT-B (ImageNet initialized) backbone
61
+ python demo.py --backbone vit-b-16 --checkpoint release_checkpoint/haco_final_vit_b_checkpoint.ckpt --input_path asset/example_images
62
+ ```
63
+
64
+ > Note: The demo includes post-processing to reduce noise in small or sparse contact areas.
65
+
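The filtering is done by `remove_small_contact_components` from `lib/utils/demo_utils.py`; in `demo.py` (below) it is called on the MANO mesh with `min_size=20`. As a rough, illustrative sketch of such a filter — the exact criterion and signature are assumptions, not the repository implementation — one can drop contact blobs that span fewer than `min_size` connected vertices:

```
# Illustrative sketch only: remove per-vertex contact components smaller
# than `min_size`, using mesh connectivity from the face list.
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

def drop_small_contact_components(contact_mask, faces, min_size=20):
    """contact_mask: (V,) bool per-vertex contact; faces: (F, 3) vertex indices."""
    contact_mask = contact_mask.astype(bool)
    num_v = contact_mask.shape[0]
    faces = np.asarray(faces, dtype=np.int64)
    # Undirected edges between vertices that share a face
    edges = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0)
    # Keep only edges whose endpoints are both in contact
    edges = edges[contact_mask[edges[:, 0]] & contact_mask[edges[:, 1]]]
    adj = coo_matrix((np.ones(len(edges)), (edges[:, 0], edges[:, 1])), shape=(num_v, num_v))
    _, labels = connected_components(adj, directed=False)
    cleaned = contact_mask.copy()
    for lbl in np.unique(labels[contact_mask]):
        component = (labels == lbl) & contact_mask
        if component.sum() < min_size:  # too few connected vertices: treat as noise
            cleaned[component] = False
    return cleaned
```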
66
+ ## Quick demo (Video)
67
+ Before running the demo, please download the example videos from [OneDrive](https://1drv.ms/u/c/bf7e2a9a100f1dba/ERsk_D-EubxBi1Usu2bW2hABwy9nxzRxAHutXDxmv85TLw?e=rIjOI7) and save them to `asset/example_videos`.<br>
68
+
69
+ To run HACO on demo videos using the [Mediapipe](https://ai.google.dev/edge/mediapipe/solutions/guide) hand detector, please run:
70
+ ```
71
+ python demo_video.py --backbone {BACKBONE_TYPE} --checkpoint {CKPT_PATH} --input_path {INPUT_PATH}
72
+ ```
73
+
74
+ For example,
75
+ ```
76
+ # ViT-H (Default, HaMeR initialized) backbone
77
+ python demo_video.py --backbone hamer --checkpoint release_checkpoint/haco_final_hamer_checkpoint.ckpt --input_path asset/example_videos
78
+
79
+ # ViT-B (ImageNet initialized) backbone
80
+ python demo_video.py --backbone vit-b-16 --checkpoint release_checkpoint/haco_final_vit_b_checkpoint.ckpt --input_path asset/example_videos
81
+ ```
82
+
83
+ > Note: The demo includes post-processing for both spatial smoothing of small contact areas and temporal smoothing across frames to ensure stable contact predictions and hand detections.
84
+
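In `demo_video.py` (below), the per-frame bounding box and contact mask are blended with `smooth_bbox` and `smooth_contact_mask` from `lib/utils/demo_utils.py`, both called with `alpha=0.8`, which suggests an exponential moving average. A minimal sketch of the contact part under that assumption (the helper name and blending rule here are not the repository implementation):

```
# Assumed behavior (illustrative only): exponentially average per-vertex
# contact scores across frames, then threshold at 0.5 as demo_video.py does.
import numpy as np

def smooth_contact(prev, raw, alpha=0.8):
    """prev: (V,) running average or None; raw: (V,) bool contact of the current frame."""
    raw = raw.astype(np.float32)
    if prev is None:
        return raw
    return alpha * prev + (1.0 - alpha) * raw

# Per-frame loop (mirroring demo_video.py):
#   smoothed = smooth_contact(smoothed, raw_contact)
#   contact_mask = smoothed > 0.5
```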
85
+
86
+ ## Data
87
+ The `data` and `release_checkpoint` directories need to follow the structure below.
88
+ ```
89
+ ${ROOT}
90
+ |-- data
91
+ | |-- base_data
92
+ | | |-- demo_data
93
+ | | | |-- hand_landmarker.task
94
+ | | |-- human_models
95
+ | | | |-- mano
96
+ | | | | |-- MANO_LEFT.pkl
97
+ | | | | |-- MANO_RIGHT.pkl
98
+ | | | | |-- V_regressor_84.npy
99
+ | | | | |-- V_regressor_336.npy
100
+ | | |-- pretrained_models
101
+ | | | |-- hamer
102
+ | | | |-- handoccnet
103
+ | | | |-- hrnet
104
+ | | | |-- pose2pose
105
+ | |-- MOW
106
+ | | |-- data
107
+ | | | |-- images
108
+ | | | |-- masks
109
+ | | | |-- models
110
+ | | | |-- poses.json
111
+ | | | |-- watertight_models
112
+ | | |-- preprocessed_data
113
+ | | | |-- test
114
+ | | | | |-- contact_data
115
+ | | |-- splits
116
+ | | |-- dataset.py
117
+ |-- release_checkpoint
118
+ ```
119
+ * Download base_data from [OneDrive](https://1drv.ms/u/c/bf7e2a9a100f1dba/EUmlgxCPqwpEvIhma80VZsoBnHrIPXzbsmJzoQpP-saj-A?e=fSxPEi).
120
+ * Download [MOW](https://zhec.github.io/rhoi/) data from GitHub ([images](https://github.com/ZheC/MOW), [models](https://github.com/ZheC/MOW), [poses.json](https://github.com/ZheC/MOW)) and OneDrive ([masks](https://1drv.ms/u/c/bf7e2a9a100f1dba/Ef2YhwccS4tPt1WrAAP4-iMBjcaSUgawDMnf_HDpqoTeNw?e=eQYJ4e), [watertight_models](https://1drv.ms/u/c/bf7e2a9a100f1dba/EW5YXeXtk3NBnX9PcvJtGIABj_9c1FW2RdrcppDgRzqHhg?e=ryUqCf), [preprocessed_data](https://1drv.ms/u/c/bf7e2a9a100f1dba/ESkqLhHk9gFHo4HH2uA9akABgYuS2wLgWfr4YJMRmagezQ?e=DoGFso), [splits](https://1drv.ms/u/c/bf7e2a9a100f1dba/EW60jCPiuNNOjkmCUdqlBbEBact_Ums22dwBoQoFMkUV6w?e=2lxpJd)). For GitHub data, you can directly download them by running:
121
+ ```
122
+ bash scripts/download_official_mow.sh
123
+ ```
124
+ * Download initial checkpoints by running:
125
+ ```
126
+ bash scripts/download_initial_checkpoints.sh
127
+ ```
128
+
129
+ ## Running HACO
130
+ ### Train
131
+ TBA by June 2025.
132
+
133
+ ### Test
134
+ To evaluate HACO on the [MOW](https://github.com/ZheC/MOW) dataset, please run:
135
+ ```
136
+ python test.py --backbone {BACKBONE_TYPE} --checkpoint {CKPT_PATH}
137
+ ```
138
+
139
+ For example,
140
+ ```
141
+ # ViT-H (Default, HaMeR initialized) backbone
142
+ python test.py --backbone hamer --checkpoint release_checkpoint/haco_final_hamer_checkpoint.ckpt
143
+
144
+ # ViT-L (ImageNet initialized) backbone
145
+ python test.py --backbone vit-l-16 --checkpoint release_checkpoint/haco_final_vit_l_checkpoint.ckpt
146
+
147
+ # ViT-B (ImageNet initialized) backbone
148
+ python test.py --backbone vit-b-16 --checkpoint release_checkpoint/haco_final_vit_b_checkpoint.ckpt
149
+
150
+ # ViT-S (ImageNet initialized) backbone
151
+ python test.py --backbone vit-s-16 --checkpoint release_checkpoint/haco_final_vit_s_checkpoint.ckpt
152
+
153
+ # FPN (HandOccNet initialized) backbone
154
+ python test.py --backbone handoccnet --checkpoint release_checkpoint/haco_final_handoccnet_checkpoint.ckpt
155
+
156
+ # HRNet-W48 (ImageNet initialized) backbone
157
+ python test.py --backbone hrnet-w48 --checkpoint release_checkpoint/haco_final_hrnet_w48_checkpoint.ckpt
158
+
159
+ # HRNet-W32 (ImageNet initialized) backbone
160
+ python test.py --backbone hrnet-w32 --checkpoint release_checkpoint/haco_final_hrnet_w32_checkpoint.ckpt
161
+
162
+ # ResNet-152 (ImageNet initialized) backbone
163
+ python test.py --backbone resnet-152 --checkpoint release_checkpoint/haco_final_resnet_152_checkpoint.ckpt
164
+
165
+ # ResNet-101 (ImageNet initialized) backbone
166
+ python test.py --backbone resnet-101 --checkpoint release_checkpoint/haco_final_resnet_101_checkpoint.ckpt
167
+
168
+ # ResNet-50 (ImageNet initialized) backbone
169
+ python test.py --backbone resnet-50 --checkpoint release_checkpoint/haco_final_resnet_50_checkpoint.ckpt
170
+
171
+ # ResNet-34 (ImageNet initialized) backbone
172
+ python test.py --backbone resnet-34 --checkpoint release_checkpoint/haco_final_resnet_34_checkpoint.ckpt
173
+
174
+ # ResNet-18 (ImageNet initialized) backbone
175
+ python test.py --backbone resnet-18 --checkpoint release_checkpoint/haco_final_resnet_18_checkpoint.ckpt
176
+ ```
177
+
178
+
179
+
180
+ ## Technical Q&A
181
+ * `ImportError: cannot import name 'bool' from 'numpy'`: comment out the line `from numpy import bool, int, float, complex, object, unicode, str, nan, inf` in the package file that raises the error.
182
+ * `AttributeError` on `np.int`: `np.int` was a deprecated alias for the builtin `int` and has been removed from NumPy. Replace `np.int` with `int`, or with `np.int64`/`np.int32` when you need an explicit precision; see [this issue](https://github.com/scikit-optimize/scikit-optimize/issues/1171) and the example below.
183
+
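A minimal before/after example of that replacement (illustrative, not tied to any specific file in this repository):

```
import numpy as np

x = 3.7
value = int(x)                               # instead of the removed np.int(x)
arr = np.array([1.0, 2.0]).astype(np.int64)  # use an explicit width when precision matters
```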
184
+
185
+ ## Acknowledgement
186
+ We thank:
187
+ * [DECO](https://openaccess.thecvf.com/content/ICCV2023/papers/Tripathi_DECO_Dense_Estimation_of_3D_Human-Scene_Contact_In_The_Wild_ICCV_2023_paper.pdf) for human-scene contact estimation.
188
+ * [CB Loss](https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf) for inspiration on VCB Loss.
189
+ * [HaMeR](https://openaccess.thecvf.com/content/CVPR2024/papers/Pavlakos_Reconstructing_Hands_in_3D_with_Transformers_CVPR_2024_paper.pdf) for Transformer-based regression architecture.
190
+
191
+
192
+
193
+ ## Reference
194
+ ```
195
+ @article{jung2025haco,
196
+ title = {Learning Dense Hand Contact Estimation from Imbalanced Data},
197
+ author = {Jung, Daniel Sungho and Lee, Kyoung Mu},
198
+ journal = {arXiv preprint arXiv:2505.11152},
199
+ year = {2025}
200
+ }
201
+ ```
data/MOW/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (6.14 kB). View file
 
data/MOW/dataset.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import trimesh
5
+ import numpy as np
6
+ import point_cloud_utils as pcu
7
+
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+ from torchvision.transforms import Normalize
11
+
12
+ from lib.core.config import cfg
13
+ from lib.utils.preprocessing import augmentation_contact, process_human_model_output_orig, mask2bbox
14
+ from lib.utils.func_utils import load_img
15
+ from lib.utils.mesh_utils import center_vertices, load_obj_nr
16
+ from lib.utils.contact_utils import get_ho_contact_and_offset
17
+ from lib.utils.human_models import mano
18
+
19
+
20
+
21
+ class MOW(Dataset):
22
+ def __init__(self, transform, data_split):
23
+ super(MOW, self).__init__()
24
+ self.__dict__.update(locals())
25
+
26
+ self.transform = transform
27
+ dataset_name = 'mow'
28
+
29
+ self.data_split = data_split
30
+ self.root_path = root_path = 'data/MOW'
31
+
32
+ self.data_dir = os.path.join(self.root_path, 'data')
33
+ self.split_dir = os.path.join(self.root_path, 'splits') # This inherits IHOI
34
+ self.watertight_obj_model_dir = os.path.join(self.data_dir, 'watertight_models')
35
+ os.makedirs(self.watertight_obj_model_dir, exist_ok=True)
36
+
37
+ with open(os.path.join(self.data_dir, 'poses.json'), 'r') as f:
38
+ annos = json.load(f)
39
+
40
+ self.db = {}
41
+ for anno in annos:
42
+ self.db[anno['image_id']] = anno
43
+ del annos
44
+
45
+ self.split = {'train': np.load('data/MOW/splits/mow_train.npy').tolist(), 'test': np.load('data/MOW/splits/mow_test.npy').tolist()}
46
+ self.length = len(self.split[data_split])
47
+
48
+ self.use_preprocessed_data = True
49
+ self.use_preprocessed_watertight_mesh = True
50
+ self.contact_data_path = os.path.join(root_path, 'preprocessed_data', data_split, 'contact_data')
51
+ os.makedirs(self.contact_data_path, exist_ok=True)
52
+
53
+ def __len__(self):
54
+ return self.length
55
+
56
+ def __getitem__(self, index):
57
+ sample_id = self.split[self.data_split][index]
58
+ ann = self.db[sample_id]
59
+ image_id = ann['image_id']
60
+
61
+ img_path = os.path.join(self.data_dir, 'images', f'{image_id}.jpg')
62
+ orig_img = load_img(img_path)
63
+
64
+ mask_ho_path = os.path.join(self.data_dir, 'masks/both', f'{image_id}.jpg')
65
+ mask_ho = (cv2.imread(mask_ho_path) > 128)[:, :, 0]
66
+ bbox_ho = mask2bbox(mask_ho, expansion_factor=cfg.DATASET.ho_bbox_expand_ratio)
67
+
68
+
69
+ ############################### PROCESS CROP AND AUGMENTATION ################################
70
+ # Crop image
71
+ img, img2bb_trans, bb2img_trans, rot, do_flip, color_scale = augmentation_contact(orig_img.copy(), bbox_ho, self.data_split, enforce_flip=False)
72
+ crop_img = img.copy()
73
+
74
+ # Transform for 3D HMR
75
+ if ('resnet' in cfg.MODEL.backbone_type or 'hrnet' in cfg.MODEL.backbone_type or 'handoccnet' in cfg.MODEL.backbone_type):
76
+ img = self.transform(img.astype(np.float32)/255.0)
77
+ elif (cfg.MODEL.backbone_type in ['hamer']) or ('vit' in cfg.MODEL.backbone_type):
78
+ normalize_img = Normalize(mean=cfg.MODEL.img_mean, std=cfg.MODEL.img_std)
79
+ img = img.transpose(2, 0, 1) / 255.0
80
+ img = normalize_img(torch.from_numpy(img)).float()
81
+ else:
82
+ raise NotImplementedError
83
+ ############################### PROCESS CROP AND AUGMENTATION ################################
84
+
85
+
86
+ mano_valid = np.ones((1), dtype=np.float32)
87
+
88
+
89
+ if not self.use_preprocessed_data:
90
+ hand_t = ann['hand_t']
91
+ hand_pose = ann['hand_pose']
92
+ hand_R = ann['hand_R']
93
+ hand_s = ann['hand_s']
94
+ hand_trans = ann['trans']
95
+
96
+ obj_instance = ann['obj_url'].split('/')[-1].split('.obj')[0]
97
+ obj_rest_mesh_path = os.path.join(self.data_dir, 'models', f'{obj_instance}.obj')
98
+ obj_R = np.array(ann['R']).reshape(3, 3)
99
+ obj_t = np.array(ann['t']).reshape((1, 3))
100
+ obj_s = np.array(ann['s'], dtype=np.float32)
101
+ obj_name = ann['obj_name']
102
+
103
+ mano_param = {'pose': np.array(hand_pose), 'shape': np.zeros(1), 'trans': np.array(hand_trans), 'hand_type': 'right'}
104
+ mano_mesh_cam, mano_joint_cam, mano_pose, mano_shape, mano_trans = process_human_model_output_orig(mano_param, {}) # mano_mesh_cam matches output.vertices in the official MOW code
105
+
106
+ mano_mesh_cam = (mano_mesh_cam @ np.array(hand_R).reshape(3, 3))
107
+ mano_mesh_cam += np.array(hand_t)[:, None].transpose(1, 0)
108
+ mano_mesh_cam *= np.array(hand_s) # mano_mesh_cam matches hand.vertices in the official MOW code
109
+ hand_mesh = trimesh.Trimesh(mano_mesh_cam, mano.watertight_face['right'])
110
+
111
+ obj_rest_verts, obj_rest_faces = load_obj_nr(obj_rest_mesh_path)
112
+ obj_rest_verts, obj_rest_faces = obj_rest_verts.detach().cpu().numpy(), obj_rest_faces.detach().cpu().numpy()
113
+ obj_rest_mesh = trimesh.Trimesh(obj_rest_verts, obj_rest_faces)
114
+
115
+ # Make object mesh watertight
116
+ watertight_obj_model_path = os.path.join(self.watertight_obj_model_dir, f'{obj_instance}.obj')
117
+
118
+ if self.use_preprocessed_watertight_mesh and os.path.exists(watertight_obj_model_path):
119
+ mesh_obj_watertight = trimesh.load(watertight_obj_model_path)
120
+
121
+ # post-process
122
+ trimesh.repair.fix_normals(mesh_obj_watertight)
123
+ trimesh.repair.fix_inversion(mesh_obj_watertight)
124
+ trimesh.repair.fill_holes(mesh_obj_watertight)
125
+
126
+ obj_rest_mesh = mesh_obj_watertight
127
+ else:
128
+ print('Building new watertight mesh!!!!')
129
+ resolution = 50_000
130
+ obj_rest_mesh.vertices, obj_rest_mesh.faces = pcu.make_mesh_watertight(obj_rest_mesh.vertices, obj_rest_mesh.faces, resolution)
131
+ if not os.path.exists(watertight_obj_model_path):
132
+ _ = obj_rest_mesh.export(watertight_obj_model_path)
133
+
134
+ obj_rest_verts, obj_rest_faces = center_vertices(obj_rest_mesh.vertices, obj_rest_mesh.faces)
135
+ obj_verts = np.dot(obj_rest_verts, obj_R)
136
+ obj_verts += obj_t
137
+ obj_verts *= obj_s
138
+ obj_mesh = trimesh.Trimesh(obj_verts, obj_rest_faces)
139
+
140
+ # Contact data
141
+ contact_h, obj_coord_c, contact_valid, inter_coord_valid = get_ho_contact_and_offset(hand_mesh, obj_mesh, cfg.MODEL.c_thres_in_the_wild)
142
+ contact_h = contact_h.astype(np.float32)
143
+ contact_data = dict(contact_h=contact_h)
144
+
145
+ if True:  # always cache the freshly computed contact labels to contact_data_path
146
+ np.save(os.path.join(self.contact_data_path, f'{sample_id}.npy'), contact_h)
147
+ else:
148
+ contact_h = np.load(os.path.join(self.contact_data_path, f'{sample_id}.npy')).astype(np.float32)
149
+ contact_data = dict(contact_h=contact_h)
150
+
151
+
152
+ input_data = dict(image=img)
153
+ targets_data = dict(contact_data=contact_data)
154
+ meta_info = dict(sample_id=sample_id, orig_img=orig_img, mano_valid=mano_valid)
155
+
156
+ return dict(input_data=input_data, targets_data=targets_data, meta_info=meta_info)
data/dataset.py ADDED
@@ -0,0 +1,40 @@
1
+ import random
2
+ import numpy as np
3
+ from torch.utils.data.dataset import Dataset
4
+
5
+
6
+ class MultipleDatasets(Dataset):
7
+ def __init__(self, dbs, make_same_len=True):
8
+ self.dbs = dbs
9
+ self.db_num = len(self.dbs)
10
+ self.max_db_data_num = max([len(db) for db in dbs])
11
+ self.db_len_cumsum = np.cumsum([len(db) for db in dbs])
12
+ self.make_same_len = make_same_len
13
+
14
+ def __len__(self):
15
+ # all dbs have the same length
16
+ if self.make_same_len:
17
+ return self.max_db_data_num * self.db_num
18
+ # each db has different length
19
+ else:
20
+ return sum([len(db) for db in self.dbs])
21
+
22
+ def __getitem__(self, index):
23
+ if self.make_same_len:
24
+ db_idx = index // self.max_db_data_num
25
+ data_idx = index % self.max_db_data_num
26
+ if data_idx >= len(self.dbs[db_idx]) * (self.max_db_data_num // len(self.dbs[db_idx])): # last batch: random sampling
27
+ data_idx = random.randint(0,len(self.dbs[db_idx])-1)
28
+ else: # before last batch: use modular
29
+ data_idx = data_idx % len(self.dbs[db_idx])
30
+ else:
31
+ for i in range(self.db_num):
32
+ if index < self.db_len_cumsum[i]:
33
+ db_idx = i
34
+ break
35
+ if db_idx == 0:
36
+ data_idx = index
37
+ else:
38
+ data_idx = index - self.db_len_cumsum[db_idx-1]
39
+
40
+ return self.dbs[db_idx][data_idx]
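`MultipleDatasets` merges several datasets into one index space; with `make_same_len=True`, every dataset is drawn `max_db_data_num` times per epoch, so smaller datasets are effectively oversampled to match the largest one. A toy usage sketch (the `Toy` dataset and the sizes are illustrative assumptions, not part of this commit):

```
# Illustrative only: two toy datasets of different sizes sampled evenly.
from torch.utils.data import Dataset, DataLoader
from data.dataset import MultipleDatasets

class Toy(Dataset):
    def __init__(self, n, tag):
        self.n, self.tag = n, tag
    def __len__(self):
        return self.n
    def __getitem__(self, i):
        return {'tag': self.tag, 'idx': i}

merged = MultipleDatasets([Toy(100, 'a'), Toy(10, 'b')], make_same_len=True)
print(len(merged))  # 200: each dataset contributes max_db_data_num (=100) samples per epoch
loader = DataLoader(merged, batch_size=4, shuffle=True, num_workers=0)
```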
demo.py ADDED
@@ -0,0 +1,122 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import argparse
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+ import mediapipe as mp
9
+ from mediapipe.tasks.python import vision
10
+ from mediapipe.tasks.python import BaseOptions
11
+
12
+ from lib.core.config import cfg, update_config
13
+ from lib.models.model import HACO
14
+ from lib.utils.human_models import mano
15
+ from lib.utils.contact_utils import get_contact_thres
16
+ from lib.utils.vis_utils import ContactRenderer, draw_landmarks_on_image
17
+ from lib.utils.preprocessing import augmentation_contact
18
+ from lib.utils.demo_utils import remove_small_contact_components
19
+
20
+
21
+ parser = argparse.ArgumentParser(description='Demo HACO')
22
+ parser.add_argument('--backbone', type=str, default='hamer', choices=['hamer', 'vit-l-16', 'vit-b-16', 'vit-s-16', 'handoccnet', 'hrnet-w48', 'hrnet-w32', 'resnet-152', 'resnet-101', 'resnet-50', 'resnet-34', 'resnet-18'], help='backbone model')
23
+ parser.add_argument('--checkpoint', type=str, default='', help='model path for demo')
24
+ parser.add_argument('--input_path', type=str, default='asset/example_images', help='image path for demo')
25
+ args = parser.parse_args()
26
+
27
+
28
+ # Set device as CUDA
29
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
30
+
31
+
32
+ # Initialize directories
33
+ experiment_dir = 'experiments_demo_image'
34
+
35
+
36
+ # Load config
37
+ update_config(backbone_type=args.backbone, exp_dir=experiment_dir)
38
+
39
+
40
+ # Initialize renderer
41
+ contact_renderer = ContactRenderer()
42
+
43
+
44
+ # Load demo images
45
+ input_dir = args.input_path
46
+ images = [f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
47
+
48
+
49
+ # Initialize MediaPipe HandLandmarker
50
+ base_options = BaseOptions(model_asset_path=cfg.MODEL.hand_landmarker_path)
51
+ hand_options = vision.HandLandmarkerOptions(base_options=base_options, num_hands=2)
52
+ detector = vision.HandLandmarker.create_from_options(hand_options)
53
+
54
+
55
+ ############# Model #############
56
+ model = HACO().to(device)
57
+ model.eval()
58
+ ############# Model #############
59
+
60
+
61
+ # Load model checkpoint if provided
62
+ if args.checkpoint:
63
+ checkpoint = torch.load(args.checkpoint, map_location=device)
64
+ model.load_state_dict(checkpoint['state_dict'])
65
+
66
+
67
+ ############################### Demo Loop ###############################
68
+ for i, frame_name in tqdm(enumerate(images), total=len(images)):
69
+ print(f"Processing: {frame_name}")
70
+
71
+ # Load and convert image
72
+ frame_path = os.path.join(input_dir, frame_name)
73
+ frame = cv2.imread(frame_path)
74
+ orig_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
75
+ frame_name_base = os.path.splitext(frame_name)[0]
76
+
77
+ # Hand landmark detection
78
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=orig_img.copy())
79
+ detection_result = detector.detect(mp_image)
80
+ annotated_image, right_hand_bbox = draw_landmarks_on_image(orig_img.copy(), detection_result)
81
+
82
+ if right_hand_bbox is None:
83
+ print(f"Skipping {frame_name} - no hand detected.")
84
+ continue
85
+
86
+ print(f"Frame {i}: Right hand bbox: {right_hand_bbox}")
87
+
88
+ # Image preprocessing
89
+ crop_img, img2bb_trans, bb2img_trans, rot, do_flip, color_scale = augmentation_contact(orig_img.copy(), right_hand_bbox, 'test', enforce_flip=False)
90
+
91
+ # Convert to model input format
92
+ if args.backbone in ['handoccnet'] or 'resnet' in cfg.MODEL.backbone_type or 'hrnet' in cfg.MODEL.backbone_type:
93
+ from torchvision import transforms
94
+ img_tensor = transforms.ToTensor()(crop_img.astype(np.float32) / 255.0)
95
+ elif args.backbone in ['hamer'] or 'vit' in cfg.MODEL.backbone_type:
96
+ from torchvision.transforms import Normalize
97
+ normalize = Normalize(mean=cfg.MODEL.img_mean, std=cfg.MODEL.img_std)
98
+ img_tensor = crop_img.transpose(2, 0, 1) / 255.0
99
+ img_tensor = normalize(torch.from_numpy(img_tensor)).float()
100
+ else:
101
+ raise NotImplementedError(f"Unsupported backbone: {args.backbone}")
102
+
103
+ ############# Run model #############
104
+ with torch.no_grad():
105
+ outputs = model({'input': {'image': img_tensor[None].to(device)}}, mode="test")
106
+ ############# Run model #############
107
+
108
+ # Save result
109
+ os.makedirs('outputs', exist_ok=True)
110
+ os.makedirs('outputs/detection', exist_ok=True)
111
+ os.makedirs('outputs/crop_img', exist_ok=True)
112
+ os.makedirs('outputs/contact', exist_ok=True)
113
+
114
+ cv2.imwrite(f'outputs/detection/{frame_name_base}.png', cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR))
115
+ cv2.imwrite(f'outputs/crop_img/{frame_name_base}.png', crop_img[..., ::-1])
116
+
117
+ eval_thres = get_contact_thres(args.backbone)
118
+ contact_mask = (outputs['contact_out'][0] > eval_thres).detach().cpu().numpy()
119
+ contact_mask = remove_small_contact_components(contact_mask, faces=mano.watertight_face['right'], min_size=20)
120
+ contact_rendered = contact_renderer.render_contact(crop_img[..., ::-1], contact_mask)
121
+ cv2.imwrite(f'outputs/contact/{frame_name_base}.png', contact_rendered)
122
+ ############################### Demo Loop ###############################
demo_video.py ADDED
@@ -0,0 +1,132 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import argparse
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+ import mediapipe as mp
9
+ from mediapipe.tasks.python import vision
10
+ from mediapipe.tasks.python import BaseOptions
11
+
12
+ from lib.core.config import cfg, update_config
13
+ from lib.models.model import HACO
14
+ from lib.utils.human_models import mano
15
+ from lib.utils.contact_utils import get_contact_thres
16
+ from lib.utils.vis_utils import ContactRenderer, draw_landmarks_on_image
17
+ from lib.utils.preprocessing import augmentation_contact
18
+ from lib.utils.demo_utils import smooth_bbox, smooth_contact_mask, remove_small_contact_components, initialize_video_writer, extract_frames_with_hand, find_longest_continuous_segment
19
+
20
+
21
+ parser = argparse.ArgumentParser(description='Demo HACO')
22
+ parser.add_argument('--backbone', type=str, default='hamer', choices=['hamer', 'vit-l-16', 'vit-b-16', 'vit-s-16', 'handoccnet', 'hrnet-w48', 'hrnet-w32', 'resnet-152', 'resnet-101', 'resnet-50', 'resnet-34', 'resnet-18'], help='backbone model')
23
+ parser.add_argument('--checkpoint', type=str, default='', help='model path for demo')
24
+ parser.add_argument('--input_path', type=str, default='asset/example_videos', help='video path for demo')
25
+ args = parser.parse_args()
26
+
27
+
28
+ # Set device as CUDA
29
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
30
+
31
+
32
+ # Initialize directories
33
+ experiment_dir = 'experiments_demo_video'
34
+
35
+
36
+ # Load config
37
+ update_config(backbone_type=args.backbone, exp_dir=experiment_dir)
38
+
39
+
40
+ # Initialize renderer
41
+ contact_renderer = ContactRenderer()
42
+
43
+
44
+ # Load demo videos
45
+ input_dir = args.input_path
46
+ video_files = [f for f in os.listdir(input_dir) if f.lower().endswith(('.mp4', '.avi', '.mov'))]
47
+
48
+
49
+ # Initialize MediaPipe HandLandmarker
50
+ base_options = BaseOptions(model_asset_path=cfg.MODEL.hand_landmarker_path)
51
+ hand_options = vision.HandLandmarkerOptions(base_options=base_options, num_hands=2)
52
+ detector = vision.HandLandmarker.create_from_options(hand_options)
53
+
54
+
55
+ ############# Model #############
56
+ model = HACO().to(device)
57
+ model.eval()
58
+ ############# Model #############
59
+
60
+
61
+ # Load model checkpoint if provided
62
+ if args.checkpoint:
63
+ checkpoint = torch.load(args.checkpoint, map_location=device)
64
+ model.load_state_dict(checkpoint['state_dict'])
65
+
66
+
67
+ ############################### Demo Loop ###############################
68
+ for i, video_name in tqdm(enumerate(video_files), total=len(video_files)):
69
+ print(f"Processing: {video_name}")
70
+
71
+ # Organize input and output path
72
+ video_path = os.path.join(input_dir, video_name)
73
+ os.makedirs("outputs_video", exist_ok=True)
74
+ output_path = os.path.join("outputs_video", f"{os.path.splitext(video_name)[0]}_out.mp4")
75
+
76
+ # Load and convert video
77
+ cap = cv2.VideoCapture(video_path)
78
+ fps = cap.get(cv2.CAP_PROP_FPS)
79
+ fps = 30 if fps == 0 or np.isnan(fps) else fps
80
+
81
+ # Extract meaningful video segment
82
+ frames_with_hand = extract_frames_with_hand(cap, detector)
83
+ longest_segment = find_longest_continuous_segment(frames_with_hand)
84
+
85
+ if not longest_segment:
86
+ print(f"No hand detected in any continuous segment for {video_name}")
87
+ continue
88
+
89
+ writer = None
90
+ smoothed_bbox = None
91
+ smoothed_contact = None
92
+
93
+ for _, frame, bbox in longest_segment:
94
+ # Image preprocessing
95
+ smoothed_bbox = smooth_bbox(smoothed_bbox, bbox, alpha=0.8)
96
+ orig_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
97
+ crop_img, img2bb_trans, bb2img_trans, rot, do_flip, color_scale = augmentation_contact(orig_img.copy(), smoothed_bbox, 'test', enforce_flip=False, bkg_color='white')
98
+
99
+ # Convert to model input format
100
+ if args.backbone in ['handoccnet'] or 'resnet' in cfg.MODEL.backbone_type or 'hrnet' in cfg.MODEL.backbone_type:
101
+ from torchvision import transforms
102
+ img_tensor = transforms.ToTensor()(crop_img.astype(np.float32) / 255.0)
103
+ elif args.backbone in ['hamer'] or 'vit' in cfg.MODEL.backbone_type:
104
+ from torchvision.transforms import Normalize
105
+ normalize = Normalize(mean=cfg.MODEL.img_mean, std=cfg.MODEL.img_std)
106
+ img_tensor = crop_img.transpose(2, 0, 1) / 255.0
107
+ img_tensor = normalize(torch.from_numpy(img_tensor)).float()
108
+ else:
109
+ raise NotImplementedError(f"Unsupported backbone: {args.backbone}")
110
+
111
+ ############# Run model #############
112
+ with torch.no_grad():
113
+ outputs = model({'input': {'image': img_tensor[None].to(device)}}, mode="test")
114
+ ############# Run model #############
115
+
116
+ # Save result
117
+ eval_thres = get_contact_thres(args.backbone)
118
+ raw_contact = (outputs['contact_out'][0] > eval_thres).detach().cpu().numpy()
119
+ smoothed_contact = smooth_contact_mask(smoothed_contact, raw_contact, alpha=0.8)
120
+ contact_mask = smoothed_contact > 0.5
121
+ contact_mask = remove_small_contact_components(contact_mask, faces=mano.watertight_face['right'], min_size=20)
122
+ contact_rendered = contact_renderer.render_contact(crop_img, contact_mask, mode='demo')
123
+
124
+ if writer is None:
125
+ ch, cw = contact_rendered.shape[:2]
126
+ writer = initialize_video_writer(output_path, fps, (cw, ch))
127
+
128
+ writer.write(cv2.cvtColor(contact_rendered, cv2.COLOR_RGB2BGR))
129
+
130
+ if writer:
131
+ writer.release()
132
+ ############################### Demo Loop ###############################
lib/core/__pycache__/config.cpython-38.pyc ADDED
Binary file (3.07 kB). View file
 
lib/core/__pycache__/logger.cpython-38.pyc ADDED
Binary file (2.03 kB). View file
 
lib/core/config.py ADDED
@@ -0,0 +1,93 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from easydict import EasyDict as edict
5
+
6
+ from lib.core.logger import ColorLogger
7
+ from lib.utils.log_utils import init_dirs
8
+
9
+
10
+ cfg = edict()
11
+
12
+
13
+ """ Dataset """
14
+ cfg.DATASET = edict()
15
+ cfg.DATASET.train_name = ['ObMan', 'DexYCB', 'HO3D', 'MOW', 'H2O3D', 'HOI4D', 'H2O', 'ARCTIC', 'InterHand26M', 'HIC', 'PROX', 'RICH', 'Decaf', 'Hi4D']
16
+ cfg.DATASET.test_name = 'MOW' # ONLY TEST ONE DATASET AT A TIME
17
+ cfg.DATASET.workers = 2
18
+ cfg.DATASET.random_seed = 314
19
+ cfg.DATASET.ho_bbox_expand_ratio = 1.3
20
+ cfg.DATASET.hand_bbox_expand_ratio = 1.3
21
+ cfg.DATASET.ho_big_bbox_expand_ratio = 2.0
22
+ cfg.DATASET.hand_scene_bbox_expand_ratio = 2.5
23
+ cfg.DATASET.obj_bbox_expand_ratio = 1.5
24
+
25
+
26
+ """ Model - HMR """
27
+ cfg.MODEL = edict()
28
+ cfg.MODEL.seed = 314
29
+ cfg.MODEL.input_img_shape = (256, 256)
30
+ cfg.MODEL.img_mean = (0.485, 0.456, 0.406)
31
+ cfg.MODEL.img_std = (0.229, 0.224, 0.225)
32
+ # MANO
33
+ cfg.MODEL.human_model_path = 'data/base_data/human_models'
34
+ # Contact
35
+ cfg.MODEL.contact_means_path = 'data/base_data/contact_data/dexycb/contact_means_dexycb.npy'
36
+ # Backbone
37
+ cfg.MODEL.backbone_type = ''
38
+ cfg.MODEL.hamer_backbone_pretrained_path = 'data/base_data/pretrained_models/hamer/hamer.ckpt'
39
+ cfg.MODEL.hrnet_w32_backbone_config_path = 'data/base_data/pretrained_models/hrnet/cls_hrnet_w32_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
40
+ cfg.MODEL.hrnet_w32_backbone_pretrained_path = 'data/base_data/pretrained_models/hrnet/hrnet_w32-36af842e.pth'
41
+ cfg.MODEL.hrnet_w48_backbone_config_path = 'data/base_data/pretrained_models/hrnet/cls_hrnet_w48_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
42
+ cfg.MODEL.hrnet_w48_backbone_pretrained_path = 'data/base_data/pretrained_models/hrnet/hrnet_w48-8ef0771d.pth'
43
+ cfg.MODEL.handoccnet_backbone_pretrained_path = 'data/base_data/pretrained_models/handoccnet/snapshot_demo.pth.tar'
44
+ # Multi-level joint regressor
45
+ cfg.MODEL.V_regressor_336_path = 'data/base_data/human_models/mano/V_regressor_336.npy'
46
+ cfg.MODEL.V_regressor_84_path = 'data/base_data/human_models/mano/V_regressor_84.npy'
47
+ # Hand Detector
48
+ cfg.MODEL.hand_landmarker_path = 'data/base_data/demo_data/hand_landmarker.task'
49
+
50
+
51
+ """ Train Detail """
52
+ cfg.TRAIN = edict()
53
+ cfg.TRAIN.batch = 24
54
+ cfg.TRAIN.epoch = 10
55
+ cfg.TRAIN.lr = 1e-5
56
+ cfg.TRAIN.weight_decay = 0.0001
57
+ cfg.TRAIN.milestones = (5, 10)
58
+ cfg.TRAIN.step_size = 10
59
+ cfg.TRAIN.gamma = 0.9
60
+ cfg.TRAIN.betas = (0.9, 0.95)
61
+ cfg.TRAIN.print_freq = 5
62
+
63
+ cfg.TRAIN.loss_weight = 1.0
64
+
65
+
66
+ """ Test Detail """
67
+ cfg.TEST = edict()
68
+ cfg.TEST.batch = 1
69
+
70
+
71
+ """ CAMERA """
72
+ cfg.CAMERA = edict()
73
+
74
+ np.random.seed(cfg.DATASET.random_seed)
75
+ torch.manual_seed(cfg.DATASET.random_seed)
76
+ torch.backends.cudnn.benchmark = True
77
+ logger = None
78
+
79
+
80
+ def update_config(backbone_type='', exp_dir='', ckpt_path=''):
81
+ if backbone_type == '':
82
+ backbone_type = 'hamer'
83
+ cfg.MODEL.backbone_type = backbone_type
84
+
85
+ global logger
86
+ log_dir = os.path.join(exp_dir, 'log')
87
+ try:
88
+ init_dirs([log_dir])
89
+ logger = ColorLogger(log_dir)
90
+ logger.info("Logger initialized successfully!")
91
+ except Exception as e:
92
+ print(f"Failed to initialize logger: {e}")
93
+ logger = None
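A minimal usage sketch for the config module above (the experiment directory below is a placeholder; update_config falls back to the 'hamer' backbone when none is given):

    from lib.core.config import cfg, update_config

    update_config(backbone_type='hamer', exp_dir='output/exp_demo')
    print(cfg.MODEL.backbone_type)    # 'hamer'
    print(cfg.MODEL.input_img_shape)  # (256, 256)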
lib/core/logger.py ADDED
@@ -0,0 +1,55 @@
1
+ import logging
2
+ import os.path as osp
3
+ import warnings
4
+
5
+
6
+ warnings.filterwarnings("ignore")
7
+
8
+ OK = '\033[92m'
9
+ WARNING = '\033[93m'
10
+ FAIL = '\033[91m'
11
+ END = '\033[0m'
12
+
13
+ PINK = '\033[95m'
14
+ BLUE = '\033[94m'
15
+ GREEN = OK
16
+ RED = FAIL
17
+ WHITE = END
18
+ YELLOW = WARNING
19
+
20
+
21
+ class ColorLogger():
22
+ def __init__(self, log_dir, log_name='log.txt'):
23
+ # set log
24
+ self._logger = logging.getLogger(log_name)
25
+ self._logger.setLevel(logging.INFO)
26
+ log_file = osp.join(log_dir, log_name)
27
+ file_log = logging.FileHandler(log_file, mode='a')
28
+ file_log.setLevel(logging.INFO)
29
+ console_log = logging.StreamHandler()
30
+ console_log.setLevel(logging.INFO)
31
+ file_formatter = logging.Formatter(
32
+ "%(asctime)s %(message)s",
33
+ "%m-%d %H:%M:%S")
34
+ console_formatter = logging.Formatter(
35
+ "{}%(asctime)s{} %(message)s".format(GREEN, END),
36
+ "%m-%d %H:%M:%S")
37
+ file_log.setFormatter(file_formatter)
38
+ console_log.setFormatter(console_formatter)
39
+ self._logger.addHandler(file_log)
40
+ self._logger.addHandler(console_log)
41
+
42
+ def debug(self, msg):
43
+ self._logger.debug(str(msg))
44
+
45
+ def info(self, msg):
46
+ self._logger.info(str(msg))
47
+
48
+ def warning(self, msg):
49
+ self._logger.warning(WARNING + 'WRN: ' + str(msg) + END)
50
+
51
+ def critical(self, msg):
52
+ self._logger.critical(RED + 'CRI: ' + str(msg) + END)
53
+
54
+ def error(self, msg):
55
+ self._logger.error(RED + 'ERR: ' + str(msg) + END)
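A minimal usage sketch, assuming the log directory already exists (update_config in lib/core/config.py creates it via init_dirs):

    from lib.core.logger import ColorLogger

    logger = ColorLogger('output/exp_demo/log')
    logger.info('training started')            # green timestamp on console, plain text in log.txt
    logger.warning('skipped a corrupt sample')  # printed in yellow with a 'WRN:' prefix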
lib/models/__pycache__/model.cpython-38.pyc ADDED
Binary file (3.75 kB).
 
lib/models/backbone/__pycache__/backbone_hamer_style.cpython-38.pyc ADDED
Binary file (9.19 kB).
 
lib/models/backbone/__pycache__/resnet.cpython-38.pyc ADDED
Binary file (3.02 kB).
 
lib/models/backbone/__pycache__/vit.cpython-38.pyc ADDED
Binary file (1.33 kB).
 
lib/models/backbone/backbone_hamer_style.py ADDED
@@ -0,0 +1,273 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint as checkpoint
5
+
6
+ import numpy as np
7
+ from functools import partial
8
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
9
+
10
+
11
+ # This module is from HaMeR (https://github.com/geopavlakos/hamer). Initial configurations follow the cfg of their final model.
12
+ class ViT_HaMeR(nn.Module):
13
+ def __init__(self,
14
+ img_size=(256, 192), patch_size=16, in_chans=3, num_classes=80, embed_dim=1280, depth=32,
15
+ num_heads=16, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
16
+ drop_path_rate=0.55, hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
17
+ frozen_stages=-1, ratio=1, last_norm=True,
18
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
19
+ ):
20
+ # Protect mutable default arguments
21
+ super(ViT_HaMeR, self).__init__()
22
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
23
+ self.num_classes = num_classes
24
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
25
+ self.frozen_stages = frozen_stages
26
+ self.use_checkpoint = use_checkpoint
27
+ self.patch_padding = patch_padding
28
+ self.freeze_attn = freeze_attn
29
+ self.freeze_ffn = freeze_ffn
30
+ self.depth = depth
31
+
32
+ if hybrid_backbone is not None:
33
+ self.patch_embed = HybridEmbed(
34
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
35
+ else:
36
+ self.patch_embed = PatchEmbed(
37
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
38
+ num_patches = self.patch_embed.num_patches
39
+
40
+         # since the pretrained model has a class token
41
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
42
+
43
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
44
+
45
+ self.blocks = nn.ModuleList([
46
+ Block(
47
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
48
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
49
+ )
50
+ for i in range(depth)])
51
+
52
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
53
+
54
+ if self.pos_embed is not None:
55
+ trunc_normal_(self.pos_embed, std=.02)
56
+
57
+ self._freeze_stages()
58
+
59
+ def _freeze_stages(self):
60
+ """Freeze parameters."""
61
+ if self.frozen_stages >= 0:
62
+ self.patch_embed.eval()
63
+ for param in self.patch_embed.parameters():
64
+ param.requires_grad = False
65
+
66
+ for i in range(1, self.frozen_stages + 1):
67
+ m = self.blocks[i]
68
+ m.eval()
69
+ for param in m.parameters():
70
+ param.requires_grad = False
71
+
72
+ if self.freeze_attn:
73
+ for i in range(0, self.depth):
74
+ m = self.blocks[i]
75
+ m.attn.eval()
76
+ m.norm1.eval()
77
+ for param in m.attn.parameters():
78
+ param.requires_grad = False
79
+ for param in m.norm1.parameters():
80
+ param.requires_grad = False
81
+
82
+ if self.freeze_ffn:
83
+ self.pos_embed.requires_grad = False
84
+ self.patch_embed.eval()
85
+ for param in self.patch_embed.parameters():
86
+ param.requires_grad = False
87
+ for i in range(0, self.depth):
88
+ m = self.blocks[i]
89
+ m.mlp.eval()
90
+ m.norm2.eval()
91
+ for param in m.mlp.parameters():
92
+ param.requires_grad = False
93
+ for param in m.norm2.parameters():
94
+ param.requires_grad = False
95
+
96
+ def init_weights(self):
97
+ """Initialize the weights in backbone.
98
+ Args:
99
+ pretrained (str, optional): Path to pre-trained weights.
100
+ Defaults to None.
101
+ """
102
+ def _init_weights(m):
103
+ if isinstance(m, nn.Linear):
104
+ trunc_normal_(m.weight, std=.02)
105
+ if isinstance(m, nn.Linear) and m.bias is not None:
106
+ nn.init.constant_(m.bias, 0)
107
+ elif isinstance(m, nn.LayerNorm):
108
+ nn.init.constant_(m.bias, 0)
109
+ nn.init.constant_(m.weight, 1.0)
110
+
111
+ self.apply(_init_weights)
112
+
113
+ def get_num_layers(self):
114
+ return len(self.blocks)
115
+
116
+ @torch.jit.ignore
117
+ def no_weight_decay(self):
118
+ return {'pos_embed', 'cls_token'}
119
+
120
+ def forward_features(self, x):
121
+ B, C, H, W = x.shape
122
+ x, (Hp, Wp) = self.patch_embed(x)
123
+
124
+ if self.pos_embed is not None:
125
+ # fit for multiple GPU training
126
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
127
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
128
+
129
+ for blk in self.blocks:
130
+ if self.use_checkpoint:
131
+ x = checkpoint.checkpoint(blk, x)
132
+ else:
133
+ x = blk(x)
134
+
135
+ x = self.last_norm(x)
136
+
137
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
138
+
139
+ return xp
140
+
141
+ def forward(self, x):
142
+         x = x[:,:,:,32:-32] # Crop the 256x256 input to its central 256x192 region inside the backbone rather than in model.py (follows the HaMeR model code)
143
+ x = self.forward_features(x)
144
+ return x
145
+
146
+ def train(self, mode=True):
147
+ """Convert the model into training mode."""
148
+ super().train(mode)
149
+ self._freeze_stages()
150
+
151
+
152
+
153
+ class Attention_for_vit(nn.Module):
154
+ def __init__(
155
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
156
+ proj_drop=0., attn_head_dim=None,):
157
+ super().__init__()
158
+ self.num_heads = num_heads
159
+ head_dim = dim // num_heads
160
+ self.dim = dim
161
+
162
+ if attn_head_dim is not None:
163
+ head_dim = attn_head_dim
164
+ all_head_dim = head_dim * self.num_heads
165
+
166
+ self.scale = qk_scale or head_dim ** -0.5
167
+
168
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
169
+
170
+ self.attn_drop = nn.Dropout(attn_drop)
171
+ self.proj = nn.Linear(all_head_dim, dim)
172
+ self.proj_drop = nn.Dropout(proj_drop)
173
+
174
+ def forward(self, x):
175
+ B, N, C = x.shape
176
+ qkv = self.qkv(x)
177
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
178
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
179
+
180
+ q = q * self.scale
181
+ attn = (q @ k.transpose(-2, -1))
182
+
183
+ attn = attn.softmax(dim=-1)
184
+ attn = self.attn_drop(attn)
185
+
186
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
187
+ x = self.proj(x)
188
+ x = self.proj_drop(x)
189
+
190
+ return x
191
+
192
+
193
+ class Mlp(nn.Module):
194
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
195
+ super().__init__()
196
+ out_features = out_features or in_features
197
+ hidden_features = hidden_features or in_features
198
+ self.fc1 = nn.Linear(in_features, hidden_features)
199
+ self.act = act_layer()
200
+ self.fc2 = nn.Linear(hidden_features, out_features)
201
+ self.drop = nn.Dropout(drop)
202
+
203
+ def forward(self, x):
204
+ x = self.fc1(x)
205
+ x = self.act(x)
206
+ x = self.fc2(x)
207
+ x = self.drop(x)
208
+ return x
209
+
210
+
211
+ class DropPath(nn.Module):
212
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
213
+ """
214
+ def __init__(self, drop_prob=None):
215
+ super(DropPath, self).__init__()
216
+ self.drop_prob = drop_prob
217
+
218
+ def forward(self, x):
219
+ return drop_path(x, self.drop_prob, self.training)
220
+
221
+ def extra_repr(self):
222
+ return 'p={}'.format(self.drop_prob)
223
+
224
+
225
+ class Block(nn.Module):
226
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
227
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
228
+ norm_layer=nn.LayerNorm, attn_head_dim=None
229
+ ):
230
+ super().__init__()
231
+
232
+ self.norm1 = norm_layer(dim)
233
+ self.attn = Attention_for_vit(
234
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
235
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
236
+ )
237
+
238
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
239
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
240
+ self.norm2 = norm_layer(dim)
241
+ mlp_hidden_dim = int(dim * mlp_ratio)
242
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
243
+
244
+ def forward(self, x):
245
+ x = x + self.drop_path(self.attn(self.norm1(x)))
246
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
247
+ return x
248
+
249
+
250
+
251
+ class PatchEmbed(nn.Module):
252
+ """ Image to Patch Embedding
253
+ """
254
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
255
+ super().__init__()
256
+ img_size = to_2tuple(img_size)
257
+ patch_size = to_2tuple(patch_size)
258
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
259
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
260
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
261
+ self.img_size = img_size
262
+ self.patch_size = patch_size
263
+ self.num_patches = num_patches
264
+
265
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
266
+
267
+ def forward(self, x, **kwargs):
268
+ B, C, H, W = x.shape
269
+ x = self.proj(x)
270
+ Hp, Wp = x.shape[2], x.shape[3]
271
+
272
+ x = x.flatten(2).transpose(1, 2)
273
+ return x, (Hp, Wp)
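A quick shape check for the backbone above. The defaults match HaMeR's ViT-H configuration (embed_dim=1280, depth=32); the sketch shrinks the depth only to keep it light, and assumes the 256x256 crops from cfg.MODEL.input_img_shape:

    import torch

    vit = ViT_HaMeR(img_size=(256, 192), patch_size=16, embed_dim=1280, depth=2, num_heads=16)
    feat = vit(torch.randn(1, 3, 256, 256))  # the center 256x192 region is cropped internally
    print(feat.shape)  # torch.Size([1, 1280, 16, 12]) -- one token per 16x16 patch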
lib/models/backbone/fpn.py ADDED
@@ -0,0 +1,282 @@
1
+ # This code is from HandOccNet (https://github.com/namepllet/HandOccNet)
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.model_zoo as model_zoo
6
+
7
+
8
+ class BasicConv(nn.Module):
9
+ def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
10
+ super(BasicConv, self).__init__()
11
+ self.out_channels = out_planes
12
+ self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
13
+ self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
14
+ self.relu = nn.ReLU() if relu else None
15
+
16
+ def forward(self, x):
17
+ x = self.conv(x)
18
+ if self.bn is not None:
19
+ x = self.bn(x)
20
+ if self.relu is not None:
21
+ x = self.relu(x)
22
+ return x
23
+
24
+ class Flatten(nn.Module):
25
+ def forward(self, x):
26
+ return x.view(x.size(0), -1)
27
+
28
+ class ChannelGate(nn.Module):
29
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
30
+ super(ChannelGate, self).__init__()
31
+ self.gate_channels = gate_channels
32
+ self.mlp = nn.Sequential(
33
+ Flatten(),
34
+ nn.Linear(gate_channels, gate_channels // reduction_ratio),
35
+ nn.ReLU(),
36
+ nn.Linear(gate_channels // reduction_ratio, gate_channels)
37
+ )
38
+ self.pool_types = pool_types
39
+ def forward(self, x):
40
+ channel_att_sum = None
41
+ for pool_type in self.pool_types:
42
+ if pool_type=='avg':
43
+ avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
44
+ channel_att_raw = self.mlp( avg_pool )
45
+ elif pool_type=='max':
46
+ max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
47
+ channel_att_raw = self.mlp( max_pool )
48
+ elif pool_type=='lp':
49
+ lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
50
+ channel_att_raw = self.mlp( lp_pool )
51
+ elif pool_type=='lse':
52
+ # LSE pool only
53
+ lse_pool = logsumexp_2d(x)
54
+ channel_att_raw = self.mlp( lse_pool )
55
+
56
+ if channel_att_sum is None:
57
+ channel_att_sum = channel_att_raw
58
+ else:
59
+ channel_att_sum = channel_att_sum + channel_att_raw
60
+
61
+ scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
62
+ return x * scale
63
+
64
+ def logsumexp_2d(tensor):
65
+ tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
66
+ s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
67
+ outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
68
+ return outputs
69
+
70
+ class ChannelPool(nn.Module):
71
+ def forward(self, x):
72
+ return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
73
+
74
+ class SpatialGate(nn.Module):
75
+ def __init__(self):
76
+ super(SpatialGate, self).__init__()
77
+ kernel_size = 7
78
+ self.compress = ChannelPool()
79
+ self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
80
+ def forward(self, x):
81
+ x_compress = self.compress(x)
82
+ x_out = self.spatial(x_compress)
83
+ scale = F.sigmoid(x_out) # broadcasting
84
+ return x*scale, x*(1-scale)
85
+
86
+
87
+ class FPN(nn.Module):
88
+ def __init__(self, pretrained=True):
89
+ super(FPN, self).__init__()
90
+ self.in_planes = 64
91
+
92
+ resnet = resnet50(pretrained=pretrained)
93
+
94
+ self.toplayer = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0) # Reduce channels
95
+
96
+ self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.leakyrelu, resnet.maxpool)
97
+ self.layer1 = nn.Sequential(resnet.layer1)
98
+ self.layer2 = nn.Sequential(resnet.layer2)
99
+ self.layer3 = nn.Sequential(resnet.layer3)
100
+ self.layer4 = nn.Sequential(resnet.layer4)
101
+
102
+ # Smooth layers
103
+ #self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
104
+ self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
105
+ self.smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
106
+
107
+ # Lateral layers
108
+ self.latlayer1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
109
+ self.latlayer2 = nn.Conv2d( 512, 256, kernel_size=1, stride=1, padding=0)
110
+ self.latlayer3 = nn.Conv2d( 256, 256, kernel_size=1, stride=1, padding=0)
111
+
112
+ # Attention Module
113
+ self.attention_module = SpatialGate()
114
+
115
+ self.pool = nn.AvgPool2d(2, stride=2)
116
+
117
+ def _upsample_add(self, x, y):
118
+ _, _, H, W = y.size()
119
+ return F.interpolate(x, size=(H,W), mode='bilinear', align_corners=False) + y
120
+
121
+ def forward(self, x):
122
+ # Bottom-up
123
+ c1 = self.layer0(x)
124
+ c2 = self.layer1(c1)
125
+ c3 = self.layer2(c2)
126
+ c4 = self.layer3(c3)
127
+ c5 = self.layer4(c4)
128
+ # Top-down
129
+ p5 = self.toplayer(c5)
130
+ p4 = self._upsample_add(p5, self.latlayer1(c4))
131
+ p3 = self._upsample_add(p4, self.latlayer2(c3))
132
+ p2 = self._upsample_add(p3, self.latlayer3(c2))
133
+ # Smooth
134
+ #p4 = self.smooth1(p4)
135
+ p3 = self.smooth2(p3)
136
+ p2 = self.smooth3(p2)
137
+
138
+ # Attention
139
+ p2 = self.pool(p2)
140
+ primary_feats, secondary_feats = self.attention_module(p2)
141
+
142
+ return primary_feats #, secondary_feats
143
+
144
+
145
+ class ResNet(nn.Module):
146
+ def __init__(self, block, layers, num_classes=1000):
147
+ self.inplanes = 64
148
+ super(ResNet, self).__init__()
149
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
150
+ self.bn1 = nn.BatchNorm2d(64)
151
+ self.leakyrelu = nn.LeakyReLU(inplace=True)
152
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
153
+ self.layer1 = self._make_layer(block, 64, layers[0])
154
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
155
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
156
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
157
+ self.avgpool = nn.AvgPool2d(7, stride=1)
158
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
159
+
160
+ for m in self.modules():
161
+ if isinstance(m, nn.Conv2d):
162
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="leaky_relu")
163
+ elif isinstance(m, nn.BatchNorm2d):
164
+ nn.init.constant_(m.weight, 1)
165
+ nn.init.constant_(m.bias, 0)
166
+
167
+ def _make_layer(self, block, planes, blocks, stride=1):
168
+ downsample = None
169
+ if stride != 1 or self.inplanes != planes * block.expansion:
170
+ downsample = nn.Sequential(
171
+ nn.Conv2d(self.inplanes, planes * block.expansion,
172
+ kernel_size=1, stride=stride, bias=False),
173
+ nn.BatchNorm2d(planes * block.expansion))
174
+ layers = []
175
+ layers.append(block(self.inplanes, planes, stride, downsample))
176
+ self.inplanes = planes * block.expansion
177
+ for i in range(1, blocks):
178
+ layers.append(block(self.inplanes, planes))
179
+
180
+ return nn.Sequential(*layers)
181
+
182
+ def forward(self, x):
183
+ x = self.conv1(x)
184
+ x = self.bn1(x)
185
+ x = self.leakyrelu(x)
186
+ x = self.maxpool(x)
187
+
188
+ x = self.layer1(x)
189
+ x = self.layer2(x)
190
+ x = self.layer3(x)
191
+ x = self.layer4(x)
192
+
193
+ x = x.mean(3).mean(2)
194
+ x = x.view(x.size(0), -1)
195
+ x = self.fc(x)
196
+ return x
197
+
198
+
199
+ def resnet50(pretrained=False, **kwargs):
200
+     """Constructs a ResNet-50 encoder."""
201
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
202
+ if pretrained:
203
+ model.load_state_dict(model_zoo.load_url("https://download.pytorch.org/models/resnet50-19c8e357.pth"))
204
+ return model
205
+
206
+
207
+ def conv3x3(in_planes, out_planes, stride=1):
208
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
209
+
210
+
211
+ class BasicBlock(nn.Module):
212
+ expansion = 1
213
+
214
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
215
+ super(BasicBlock, self).__init__()
216
+ self.conv1 = conv3x3(inplanes, planes, stride)
217
+ self.bn1 = nn.BatchNorm2d(planes)
218
+ self.leakyrelu = nn.LeakyReLU(inplace=True)
219
+ self.conv2 = conv3x3(planes, planes)
220
+ self.bn2 = nn.BatchNorm2d(planes)
221
+ self.downsample = downsample
222
+ self.stride = stride
223
+
224
+ def forward(self, x):
225
+ residual = x
226
+
227
+ out = self.conv1(x)
228
+ out = self.bn1(out)
229
+ out = self.leakyrelu(out)
230
+
231
+ out = self.conv2(out)
232
+ out = self.bn2(out)
233
+
234
+ if self.downsample is not None:
235
+ residual = self.downsample(x)
236
+
237
+ out += residual
238
+ out = self.leakyrelu(out)
239
+
240
+ return out
241
+
242
+
243
+ class Bottleneck(nn.Module):
244
+ expansion = 4
245
+
246
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
247
+ super(Bottleneck, self).__init__()
248
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
249
+ self.bn1 = nn.BatchNorm2d(planes)
250
+ self.conv2 = nn.Conv2d(
251
+ planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
252
+ )
253
+ self.bn2 = nn.BatchNorm2d(planes)
254
+ self.conv3 = nn.Conv2d(
255
+ planes, planes * self.expansion, kernel_size=1, bias=False
256
+ )
257
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
258
+ self.leakyrelu = nn.LeakyReLU(inplace=True)
259
+ self.downsample = downsample
260
+ self.stride = stride
261
+
262
+ def forward(self, x):
263
+ residual = x
264
+
265
+ out = self.conv1(x)
266
+ out = self.bn1(out)
267
+ out = self.leakyrelu(out)
268
+
269
+ out = self.conv2(out)
270
+ out = self.bn2(out)
271
+ out = self.leakyrelu(out)
272
+
273
+ out = self.conv3(out)
274
+ out = self.bn3(out)
275
+
276
+ if self.downsample is not None:
277
+ residual = self.downsample(x)
278
+
279
+ out += residual
280
+ out = self.leakyrelu(out)
281
+
282
+ return out
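A quick shape check for the FPN above, assuming the 256x256 crops used elsewhere in this repo (pretrained=False skips the ResNet-50 download):

    import torch

    fpn = FPN(pretrained=False)
    feats = fpn(torch.randn(1, 3, 256, 256))
    print(feats.shape)  # torch.Size([1, 256, 32, 32]): p2 at 1/4 scale, then average-pooled by 2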
lib/models/backbone/hrnet.py ADDED
@@ -0,0 +1,518 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5
+ # Modified by Ke Sun (sunk@mail.ustc.edu.cn)
6
+ # Modified by Kevin Lin (keli@microsoft.com)
7
+ # ------------------------------------------------------------------------------
8
+
9
+ from __future__ import absolute_import
10
+ from __future__ import division
11
+ from __future__ import print_function
12
+
13
+ import os
14
+ import logging
15
+ import numpy as np
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch._utils
20
+ import torch.nn.functional as F
21
+ BN_MOMENTUM = 0.1
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def conv3x3(in_planes, out_planes, stride=1):
26
+ """3x3 convolution with padding"""
27
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
28
+ padding=1, bias=False)
29
+
30
+
31
+ class BasicBlock(nn.Module):
32
+ expansion = 1
33
+
34
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
35
+ super(BasicBlock, self).__init__()
36
+ self.conv1 = conv3x3(inplanes, planes, stride)
37
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
38
+ self.relu = nn.ReLU(inplace=True)
39
+ self.conv2 = conv3x3(planes, planes)
40
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
41
+ self.downsample = downsample
42
+ self.stride = stride
43
+
44
+ def forward(self, x):
45
+ residual = x
46
+
47
+ out = self.conv1(x)
48
+ out = self.bn1(out)
49
+ out = self.relu(out)
50
+
51
+ out = self.conv2(out)
52
+ out = self.bn2(out)
53
+
54
+ if self.downsample is not None:
55
+ residual = self.downsample(x)
56
+
57
+ out += residual
58
+ out = self.relu(out)
59
+
60
+ return out
61
+
62
+
63
+ class Bottleneck(nn.Module):
64
+ expansion = 4
65
+
66
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
67
+ super(Bottleneck, self).__init__()
68
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
69
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
70
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
71
+ padding=1, bias=False)
72
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
73
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
74
+ bias=False)
75
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
76
+ momentum=BN_MOMENTUM)
77
+ self.relu = nn.ReLU(inplace=True)
78
+ self.downsample = downsample
79
+ self.stride = stride
80
+
81
+ def forward(self, x):
82
+ residual = x
83
+
84
+ out = self.conv1(x)
85
+ out = self.bn1(out)
86
+ out = self.relu(out)
87
+
88
+ out = self.conv2(out)
89
+ out = self.bn2(out)
90
+ out = self.relu(out)
91
+
92
+ out = self.conv3(out)
93
+ out = self.bn3(out)
94
+
95
+ if self.downsample is not None:
96
+ residual = self.downsample(x)
97
+
98
+ out += residual
99
+ out = self.relu(out)
100
+
101
+ return out
102
+
103
+
104
+ class HighResolutionModule(nn.Module):
105
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
106
+ num_channels, fuse_method, multi_scale_output=True):
107
+ super(HighResolutionModule, self).__init__()
108
+ self._check_branches(
109
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
110
+
111
+ self.num_inchannels = num_inchannels
112
+ self.fuse_method = fuse_method
113
+ self.num_branches = num_branches
114
+
115
+ self.multi_scale_output = multi_scale_output
116
+
117
+ self.branches = self._make_branches(
118
+ num_branches, blocks, num_blocks, num_channels)
119
+ self.fuse_layers = self._make_fuse_layers()
120
+ self.relu = nn.ReLU(False)
121
+
122
+ def _check_branches(self, num_branches, blocks, num_blocks,
123
+ num_inchannels, num_channels):
124
+ if num_branches != len(num_blocks):
125
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
126
+ num_branches, len(num_blocks))
127
+ logger.error(error_msg)
128
+ raise ValueError(error_msg)
129
+
130
+ if num_branches != len(num_channels):
131
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
132
+ num_branches, len(num_channels))
133
+ logger.error(error_msg)
134
+ raise ValueError(error_msg)
135
+
136
+ if num_branches != len(num_inchannels):
137
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
138
+ num_branches, len(num_inchannels))
139
+ logger.error(error_msg)
140
+ raise ValueError(error_msg)
141
+
142
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
143
+ stride=1):
144
+ downsample = None
145
+ if stride != 1 or \
146
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
147
+ downsample = nn.Sequential(
148
+ nn.Conv2d(self.num_inchannels[branch_index],
149
+ num_channels[branch_index] * block.expansion,
150
+ kernel_size=1, stride=stride, bias=False),
151
+ nn.BatchNorm2d(num_channels[branch_index] * block.expansion,
152
+ momentum=BN_MOMENTUM),
153
+ )
154
+
155
+ layers = []
156
+ layers.append(block(self.num_inchannels[branch_index],
157
+ num_channels[branch_index], stride, downsample))
158
+ self.num_inchannels[branch_index] = \
159
+ num_channels[branch_index] * block.expansion
160
+ for i in range(1, num_blocks[branch_index]):
161
+ layers.append(block(self.num_inchannels[branch_index],
162
+ num_channels[branch_index]))
163
+
164
+ return nn.Sequential(*layers)
165
+
166
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
167
+ branches = []
168
+
169
+ for i in range(num_branches):
170
+ branches.append(
171
+ self._make_one_branch(i, block, num_blocks, num_channels))
172
+
173
+ return nn.ModuleList(branches)
174
+
175
+ def _make_fuse_layers(self):
176
+ if self.num_branches == 1:
177
+ return None
178
+
179
+ num_branches = self.num_branches
180
+ num_inchannels = self.num_inchannels
181
+ fuse_layers = []
182
+ for i in range(num_branches if self.multi_scale_output else 1):
183
+ fuse_layer = []
184
+ for j in range(num_branches):
185
+ if j > i:
186
+ fuse_layer.append(nn.Sequential(
187
+ nn.Conv2d(num_inchannels[j],
188
+ num_inchannels[i],
189
+ 1,
190
+ 1,
191
+ 0,
192
+ bias=False),
193
+ nn.BatchNorm2d(num_inchannels[i],
194
+ momentum=BN_MOMENTUM),
195
+ nn.Upsample(scale_factor=2**(j-i), mode='nearest')))
196
+ elif j == i:
197
+ fuse_layer.append(None)
198
+ else:
199
+ conv3x3s = []
200
+ for k in range(i-j):
201
+ if k == i - j - 1:
202
+ num_outchannels_conv3x3 = num_inchannels[i]
203
+ conv3x3s.append(nn.Sequential(
204
+ nn.Conv2d(num_inchannels[j],
205
+ num_outchannels_conv3x3,
206
+ 3, 2, 1, bias=False),
207
+ nn.BatchNorm2d(num_outchannels_conv3x3,
208
+ momentum=BN_MOMENTUM)))
209
+ else:
210
+ num_outchannels_conv3x3 = num_inchannels[j]
211
+ conv3x3s.append(nn.Sequential(
212
+ nn.Conv2d(num_inchannels[j],
213
+ num_outchannels_conv3x3,
214
+ 3, 2, 1, bias=False),
215
+ nn.BatchNorm2d(num_outchannels_conv3x3,
216
+ momentum=BN_MOMENTUM),
217
+ nn.ReLU(False)))
218
+ fuse_layer.append(nn.Sequential(*conv3x3s))
219
+ fuse_layers.append(nn.ModuleList(fuse_layer))
220
+
221
+ return nn.ModuleList(fuse_layers)
222
+
223
+ def get_num_inchannels(self):
224
+ return self.num_inchannels
225
+
226
+ def forward(self, x):
227
+ if self.num_branches == 1:
228
+ return [self.branches[0](x[0])]
229
+
230
+ for i in range(self.num_branches):
231
+ x[i] = self.branches[i](x[i])
232
+
233
+ x_fuse = []
234
+ for i in range(len(self.fuse_layers)):
235
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
236
+ for j in range(1, self.num_branches):
237
+ if i == j:
238
+ y = y + x[j]
239
+ else:
240
+ y = y + self.fuse_layers[i][j](x[j])
241
+ x_fuse.append(self.relu(y))
242
+
243
+ return x_fuse
244
+
245
+
246
+ blocks_dict = {
247
+ 'BASIC': BasicBlock,
248
+ 'BOTTLENECK': Bottleneck
249
+ }
250
+
251
+
252
+ class HighResolutionNet(nn.Module):
253
+ def __init__(self, cfg, **kwargs):
254
+ super(HighResolutionNet, self).__init__()
255
+
256
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
257
+ bias=False)
258
+ self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
259
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
260
+ bias=False)
261
+ self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
262
+ self.relu = nn.ReLU(inplace=True)
263
+
264
+ self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1']
265
+ num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
266
+ block = blocks_dict[self.stage1_cfg['BLOCK']]
267
+ num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
268
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
269
+ stage1_out_channel = block.expansion*num_channels
270
+
271
+ self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
272
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
273
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
274
+ num_channels = [
275
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
276
+ self.transition1 = self._make_transition_layer(
277
+ [stage1_out_channel], num_channels)
278
+ self.stage2, pre_stage_channels = self._make_stage(
279
+ self.stage2_cfg, num_channels)
280
+
281
+ self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
282
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
283
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
284
+ num_channels = [
285
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
286
+ self.transition2 = self._make_transition_layer(
287
+ pre_stage_channels, num_channels)
288
+ self.stage3, pre_stage_channels = self._make_stage(
289
+ self.stage3_cfg, num_channels)
290
+
291
+ self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
292
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
293
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
294
+ num_channels = [
295
+ num_channels[i] * block.expansion for i in range(len(num_channels))]
296
+ self.transition3 = self._make_transition_layer(
297
+ pre_stage_channels, num_channels)
298
+ self.stage4, pre_stage_channels = self._make_stage(
299
+ self.stage4_cfg, num_channels, multi_scale_output=True)
300
+
301
+ # Classification Head
302
+ self.incre_modules, self.downsamp_modules, \
303
+ self.final_layer = self._make_head(pre_stage_channels)
304
+
305
+ self.classifier = nn.Linear(2048, 1000)
306
+
307
+ def _make_head(self, pre_stage_channels):
308
+ head_block = Bottleneck
309
+ head_channels = [32, 64, 128, 256]
310
+
311
+ # Increasing the #channels on each resolution
312
+ # from C, 2C, 4C, 8C to 128, 256, 512, 1024
313
+ incre_modules = []
314
+ for i, channels in enumerate(pre_stage_channels):
315
+ incre_module = self._make_layer(head_block,
316
+ channels,
317
+ head_channels[i],
318
+ 1,
319
+ stride=1)
320
+ incre_modules.append(incre_module)
321
+ incre_modules = nn.ModuleList(incre_modules)
322
+
323
+ # downsampling modules
324
+ downsamp_modules = []
325
+ for i in range(len(pre_stage_channels)-1):
326
+ in_channels = head_channels[i] * head_block.expansion
327
+ out_channels = head_channels[i+1] * head_block.expansion
328
+
329
+ downsamp_module = nn.Sequential(
330
+ nn.Conv2d(in_channels=in_channels,
331
+ out_channels=out_channels,
332
+ kernel_size=3,
333
+ stride=2,
334
+ padding=1),
335
+ nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM),
336
+ nn.ReLU(inplace=True)
337
+ )
338
+
339
+ downsamp_modules.append(downsamp_module)
340
+ downsamp_modules = nn.ModuleList(downsamp_modules)
341
+
342
+ final_layer = nn.Sequential(
343
+ nn.Conv2d(
344
+ in_channels=head_channels[3] * head_block.expansion,
345
+ out_channels=2048,
346
+ kernel_size=1,
347
+ stride=1,
348
+ padding=0
349
+ ),
350
+ nn.BatchNorm2d(2048, momentum=BN_MOMENTUM),
351
+ nn.ReLU(inplace=True)
352
+ )
353
+
354
+ return incre_modules, downsamp_modules, final_layer
355
+
356
+ def _make_transition_layer(
357
+ self, num_channels_pre_layer, num_channels_cur_layer):
358
+ num_branches_cur = len(num_channels_cur_layer)
359
+ num_branches_pre = len(num_channels_pre_layer)
360
+
361
+ transition_layers = []
362
+ for i in range(num_branches_cur):
363
+ if i < num_branches_pre:
364
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
365
+ transition_layers.append(nn.Sequential(
366
+ nn.Conv2d(num_channels_pre_layer[i],
367
+ num_channels_cur_layer[i],
368
+ 3,
369
+ 1,
370
+ 1,
371
+ bias=False),
372
+ nn.BatchNorm2d(
373
+ num_channels_cur_layer[i], momentum=BN_MOMENTUM),
374
+ nn.ReLU(inplace=True)))
375
+ else:
376
+ transition_layers.append(None)
377
+ else:
378
+ conv3x3s = []
379
+ for j in range(i+1-num_branches_pre):
380
+ inchannels = num_channels_pre_layer[-1]
381
+ outchannels = num_channels_cur_layer[i] \
382
+ if j == i-num_branches_pre else inchannels
383
+ conv3x3s.append(nn.Sequential(
384
+ nn.Conv2d(
385
+ inchannels, outchannels, 3, 2, 1, bias=False),
386
+ nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
387
+ nn.ReLU(inplace=True)))
388
+ transition_layers.append(nn.Sequential(*conv3x3s))
389
+
390
+ return nn.ModuleList(transition_layers)
391
+
392
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
393
+ downsample = None
394
+ if stride != 1 or inplanes != planes * block.expansion:
395
+ downsample = nn.Sequential(
396
+ nn.Conv2d(inplanes, planes * block.expansion,
397
+ kernel_size=1, stride=stride, bias=False),
398
+ nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
399
+ )
400
+
401
+ layers = []
402
+ layers.append(block(inplanes, planes, stride, downsample))
403
+ inplanes = planes * block.expansion
404
+ for i in range(1, blocks):
405
+ layers.append(block(inplanes, planes))
406
+
407
+ return nn.Sequential(*layers)
408
+
409
+ def _make_stage(self, layer_config, num_inchannels,
410
+ multi_scale_output=True):
411
+ num_modules = layer_config['NUM_MODULES']
412
+ num_branches = layer_config['NUM_BRANCHES']
413
+ num_blocks = layer_config['NUM_BLOCKS']
414
+ num_channels = layer_config['NUM_CHANNELS']
415
+ block = blocks_dict[layer_config['BLOCK']]
416
+ fuse_method = layer_config['FUSE_METHOD']
417
+
418
+ modules = []
419
+ for i in range(num_modules):
420
+             # multi_scale_output is only used in the last module
421
+ if not multi_scale_output and i == num_modules - 1:
422
+ reset_multi_scale_output = False
423
+ else:
424
+ reset_multi_scale_output = True
425
+
426
+ modules.append(
427
+ HighResolutionModule(num_branches,
428
+ block,
429
+ num_blocks,
430
+ num_inchannels,
431
+ num_channels,
432
+ fuse_method,
433
+ reset_multi_scale_output)
434
+ )
435
+ num_inchannels = modules[-1].get_num_inchannels()
436
+
437
+ return nn.Sequential(*modules), num_inchannels
438
+
439
+ def forward(self, x):
440
+ x = self.conv1(x)
441
+ x = self.bn1(x)
442
+ x = self.relu(x)
443
+ x = self.conv2(x)
444
+ x = self.bn2(x)
445
+ x = self.relu(x)
446
+ x = self.layer1(x)
447
+
448
+ x_list = []
449
+ for i in range(self.stage2_cfg['NUM_BRANCHES']):
450
+ if self.transition1[i] is not None:
451
+ x_list.append(self.transition1[i](x))
452
+ else:
453
+ x_list.append(x)
454
+ y_list = self.stage2(x_list)
455
+
456
+ x_list = []
457
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
458
+ if self.transition2[i] is not None:
459
+ x_list.append(self.transition2[i](y_list[-1]))
460
+ else:
461
+ x_list.append(y_list[i])
462
+ y_list = self.stage3(x_list)
463
+
464
+ x_list = []
465
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
466
+ if self.transition3[i] is not None:
467
+ x_list.append(self.transition3[i](y_list[-1]))
468
+ else:
469
+ x_list.append(y_list[i])
470
+ y_list = self.stage4(x_list)
471
+
472
+ # Classification Head
473
+ y = self.incre_modules[0](y_list[0])
474
+ for i in range(len(self.downsamp_modules)):
475
+ y = self.incre_modules[i+1](y_list[i+1]) + \
476
+ self.downsamp_modules[i](y)
477
+
478
+ y = self.final_layer(y)
479
+
480
+ # if torch._C._get_tracing_state():
481
+ # y = y.flatten(start_dim=2).mean(dim=2)
482
+ # else:
483
+ # y = F.avg_pool2d(y, kernel_size=y.size()
484
+ # [2:]).view(y.size(0), -1)
485
+
486
+ # y = self.classifier(y)
487
+
488
+ return y
489
+
490
+ def init_weights(self, pretrained='',):
491
+ logger.info('=> init weights from normal distribution')
492
+ for m in self.modules():
493
+ if isinstance(m, nn.Conv2d):
494
+ nn.init.kaiming_normal_(
495
+ m.weight, mode='fan_out', nonlinearity='relu')
496
+ elif isinstance(m, nn.BatchNorm2d):
497
+ nn.init.constant_(m.weight, 1)
498
+ nn.init.constant_(m.bias, 0)
499
+ if os.path.isfile(pretrained):
500
+ pretrained_dict = torch.load(pretrained)
501
+ logger.info('=> loading pretrained model {}'.format(pretrained))
502
+ print('=> loading pretrained model {}'.format(pretrained))
503
+ model_dict = self.state_dict()
504
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
505
+ if k in model_dict.keys()}
506
+ # for k, _ in pretrained_dict.items():
507
+ # logger.info(
508
+ # '=> loading {} pretrained model {}'.format(k, pretrained))
509
+ # print('=> loading {} pretrained model {}'.format(k, pretrained))
510
+ model_dict.update(pretrained_dict)
511
+ self.load_state_dict(model_dict)
512
+ # code.interact(local=locals())
513
+
514
+
515
+ def get_cls_net(config, pretrained, **kwargs):
516
+ model = HighResolutionNet(config, **kwargs)
517
+ model.init_weights(pretrained=pretrained)
518
+ return model
lib/models/backbone/resnet.py ADDED
@@ -0,0 +1,95 @@
1
+ # This code is from Hand4Whole (https://github.com/mks0601/Hand4Whole_RELEASE/blob/main/common/nets/resnet.py)
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision.models.resnet import BasicBlock, Bottleneck
5
+
6
+
7
+ class ResNetBackbone(nn.Module):
8
+ def __init__(self, resnet_type):
9
+
10
+ resnet_spec = {18: (BasicBlock, [2, 2, 2, 2], [64, 64, 128, 256, 512], 'resnet18'),
11
+ 34: (BasicBlock, [3, 4, 6, 3], [64, 64, 128, 256, 512], 'resnet34'),
12
+ 50: (Bottleneck, [3, 4, 6, 3], [64, 256, 512, 1024, 2048], 'resnet50'),
13
+ 101: (Bottleneck, [3, 4, 23, 3], [64, 256, 512, 1024, 2048], 'resnet101'),
14
+ 152: (Bottleneck, [3, 8, 36, 3], [64, 256, 512, 1024, 2048], 'resnet152')}
15
+ block, layers, channels, name = resnet_spec[resnet_type]
16
+
17
+ self.name = name
18
+ self.inplanes = 64
19
+ super(ResNetBackbone, self).__init__()
20
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
21
+ bias=False)
22
+ self.bn1 = nn.BatchNorm2d(64)
23
+ self.relu = nn.ReLU(inplace=True)
24
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
25
+
26
+ self.layer1 = self._make_layer(block, 64, layers[0])
27
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
28
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
29
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
30
+
31
+ for m in self.modules():
32
+ if isinstance(m, nn.Conv2d):
33
+ nn.init.normal_(m.weight, mean=0, std=0.001)
34
+ elif isinstance(m, nn.BatchNorm2d):
35
+ nn.init.constant_(m.weight, 1)
36
+ nn.init.constant_(m.bias, 0)
37
+
38
+ def _make_layer(self, block, planes, blocks, stride=1):
39
+ downsample = None
40
+ if stride != 1 or self.inplanes != planes * block.expansion:
41
+ downsample = nn.Sequential(
42
+ nn.Conv2d(self.inplanes, planes * block.expansion,
43
+ kernel_size=1, stride=stride, bias=False),
44
+ nn.BatchNorm2d(planes * block.expansion),
45
+ )
46
+
47
+ layers = []
48
+ layers.append(block(self.inplanes, planes, stride, downsample))
49
+ self.inplanes = planes * block.expansion
50
+ for i in range(1, blocks):
51
+ layers.append(block(self.inplanes, planes))
52
+
53
+ return nn.Sequential(*layers)
54
+
55
+ def forward(self, x):
56
+ x = self.conv1(x)
57
+ x = self.bn1(x)
58
+ x = self.relu(x)
59
+ x = self.maxpool(x)
60
+
61
+ x = self.layer1(x)
62
+ x = self.layer2(x)
63
+ x = self.layer3(x)
64
+ x = self.layer4(x)
65
+ return x
66
+
67
+ def init_weights(self):
68
+ import torchvision.models as models
69
+
70
+ if self.name == 'resnet18':
71
+ org_resnet = models.resnet18(pretrained=True)
72
+ elif self.name == 'resnet34':
73
+ org_resnet = models.resnet34(pretrained=True)
74
+ elif self.name == 'resnet50':
75
+ org_resnet = models.resnet50(pretrained=True)
76
+ elif self.name == 'resnet101':
77
+ org_resnet = models.resnet101(pretrained=True)
78
+ elif self.name == 'resnet152':
79
+ org_resnet = models.resnet152(pretrained=True)
80
+ else:
81
+ raise ValueError(f"Unsupported model name: {self.name}")
82
+
83
+ # Drop the original fully connected layer
84
+ org_resnet.fc = None # Or you can set it to nn.Identity()
85
+
86
+ # If you're loading weights manually, extract the state_dict
87
+ org_resnet_state = org_resnet.state_dict()
88
+
89
+ # Remove FC layer weights to avoid mismatch
90
+ org_resnet_state.pop('fc.weight', None)
91
+ org_resnet_state.pop('fc.bias', None)
92
+
93
+ # Load into your model
94
+ self.load_state_dict(org_resnet_state, strict=False)
95
+ print("Initialized ResNet from torchvision with pretrained=True")
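A minimal usage sketch for the backbone above (init_weights() pulls the ImageNet weights from torchvision and can be skipped for a pure shape check):

    import torch

    backbone = ResNetBackbone(50)
    backbone.init_weights()
    feat = backbone(torch.randn(1, 3, 256, 256))
    print(feat.shape)  # torch.Size([1, 2048, 8, 8])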
lib/models/backbone/vit.py ADDED
@@ -0,0 +1,33 @@
1
+ import timm
2
+ import torch.nn as nn
3
+
4
+
5
+ class ViTBackbone(nn.Module):
6
+ def __init__(self, model_name='vit_base_patch16_224', pretrained=True, return_cls=False):
7
+ """
8
+ Args:
9
+ model_name (str): 'vit_base_patch16_224' or 'vit_large_patch16_224'
10
+ pretrained (bool): load pretrained weights from timm
11
+ return_cls (bool): if True, return CLS token instead of patch tokens
12
+ """
13
+ super().__init__()
14
+ self.return_cls = return_cls
15
+
16
+ # Load model with no classification head
17
+ self.vit = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
18
+
19
+ # Get dimensions
20
+ self.embed_dim = self.vit.embed_dim # 768 for B/16, 1024 for L/16
21
+ self.patch_size = self.vit.patch_embed.patch_size
22
+
23
+ def forward(self, x):
24
+         # forward_features returns CLS + patch tokens: [B, 1 + N, D]
25
+ x = self.vit.forward_features(x)
26
+
27
+ if self.return_cls:
28
+ return x[:, 0] # [B, D] – CLS token
29
+ else:
30
+ patch_tokens = x[:, 1:] # [B, N, D]
31
+ B, N, D = patch_tokens.shape
32
+ H = W = int(N ** 0.5)
33
+             return patch_tokens.transpose(1, 2).reshape(B, D, H, W)  # [B, D, H, W] feature map
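A minimal usage sketch, assuming timm's 'vit_base_patch16_224' definition and a 224x224 crop (pretrained=False skips the weight download):

    import torch

    backbone = ViTBackbone('vit_base_patch16_224', pretrained=False)
    feat = backbone(torch.randn(1, 3, 224, 224))
    print(feat.shape)  # torch.Size([1, 768, 14, 14]): 14x14 patch grid, embed_dim 768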
lib/models/decoder/__pycache__/decoder_hamer_style.cpython-38.pyc ADDED
Binary file (20.1 kB).
 
lib/models/decoder/decoder_hamer_style.py ADDED
@@ -0,0 +1,637 @@
1
+ import pickle
2
+ import numpy as np
3
+ from einops import rearrange
4
+ from inspect import isfunction
5
+ from typing import Callable, Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ import smplx
12
+ from smplx.lbs import vertices2joints
13
+ from smplx.utils import MANOOutput, to_tensor
14
+ from smplx.vertex_ids import vertex_ids
15
+
16
+ from lib.core.config import cfg
17
+ from lib.utils.human_models import mano
18
+
19
+
20
+ V_regressor_336 = np.load(cfg.MODEL.V_regressor_336_path)
21
+ V_regressor_84 = np.load(cfg.MODEL.V_regressor_84_path)
22
+
23
+
24
+ # This function is from HaMeR (https://github.com/geopavlakos/hamer).
25
+ def exists(val):
26
+ return val is not None
27
+
28
+
29
+ # This function is from HaMeR (https://github.com/geopavlakos/hamer).
30
+ def default(val, d):
31
+ if exists(val):
32
+ return val
33
+ return d() if isfunction(d) else d
34
+
35
+
36
+ # This class is from HaMeR (https://github.com/geopavlakos/hamer).
37
+ class Attention(nn.Module):
38
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
39
+ super().__init__()
40
+ inner_dim = dim_head * heads
41
+ project_out = not (heads == 1 and dim_head == dim)
42
+
43
+ self.heads = heads
44
+ self.scale = dim_head**-0.5
45
+
46
+ self.attend = nn.Softmax(dim=-1)
47
+ self.dropout = nn.Dropout(dropout)
48
+
49
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
50
+
51
+ self.to_out = (
52
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
53
+ if project_out
54
+ else nn.Identity()
55
+ )
56
+
57
+ def forward(self, x):
58
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
59
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
60
+
61
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
62
+
63
+ attn = self.attend(dots)
64
+ attn = self.dropout(attn)
65
+
66
+ out = torch.matmul(attn, v)
67
+ out = rearrange(out, "b h n d -> b n (h d)")
68
+ return self.to_out(out)
69
+
70
+
71
+ # This class is from HaMeR (https://github.com/geopavlakos/hamer).
72
+ class CrossAttention(nn.Module):
73
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
74
+ super().__init__()
75
+ inner_dim = dim_head * heads
76
+ project_out = not (heads == 1 and dim_head == dim)
77
+
78
+ self.heads = heads
79
+ self.scale = dim_head**-0.5
80
+
81
+ self.attend = nn.Softmax(dim=-1)
82
+ self.dropout = nn.Dropout(dropout)
83
+
84
+ context_dim = default(context_dim, dim)
85
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
86
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
87
+
88
+ self.to_out = (
89
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
90
+ if project_out
91
+ else nn.Identity()
92
+ )
93
+
94
+ def forward(self, x, context=None):
95
+ context = default(context, x)
96
+ k, v = self.to_kv(context).chunk(2, dim=-1)
97
+ q = self.to_q(x)
98
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
99
+
100
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
101
+
102
+ attn = self.attend(dots)
103
+ attn = self.dropout(attn)
104
+
105
+ out = torch.matmul(attn, v)
106
+ out = rearrange(out, "b h n d -> b n (h d)")
107
+ return self.to_out(out)
108
+
109
+
110
+ # This class is from HaMeR (https://github.com/geopavlakos/hamer).
111
+ class FeedForward(nn.Module):
112
+ def __init__(self, dim, hidden_dim, dropout=0.0):
113
+ super().__init__()
114
+ self.net = nn.Sequential(
115
+ nn.Linear(dim, hidden_dim),
116
+ nn.GELU(),
117
+ nn.Dropout(dropout),
118
+ nn.Linear(hidden_dim, dim),
119
+ nn.Dropout(dropout),
120
+ )
121
+
122
+ def forward(self, x):
123
+ return self.net(x)
124
+
125
+
126
+ # This class is from HaMeR (https://github.com/geopavlakos/hamer).
127
+ class Transformer(nn.Module):
128
+ def __init__(
129
+ self,
130
+ dim: int,
131
+ depth: int,
132
+ heads: int,
133
+ dim_head: int,
134
+ mlp_dim: int,
135
+ dropout: float = 0.0,
136
+ norm: str = "layer",
137
+ norm_cond_dim: int = -1,
138
+ ):
139
+ super().__init__()
140
+ self.layers = nn.ModuleList([])
141
+ for _ in range(depth):
142
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
143
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
144
+ self.layers.append(
145
+ nn.ModuleList(
146
+ [
147
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
148
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
149
+ ]
150
+ )
151
+ )
152
+
153
+ def forward(self, x: torch.Tensor, *args):
154
+ for attn, ff in self.layers:
155
+ x = attn(x, *args) + x
156
+ x = ff(x, *args) + x
157
+ return x
158
+
159
+
160
+ class AdaptiveLayerNorm1D(torch.nn.Module):
161
+ def __init__(self, data_dim: int, norm_cond_dim: int):
162
+ super().__init__()
163
+ if data_dim <= 0:
164
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
165
+ if norm_cond_dim <= 0:
166
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
167
+ self.norm = torch.nn.LayerNorm(
168
+ data_dim
169
+ ) # TODO: Check if elementwise_affine=True is correct
170
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
171
+ torch.nn.init.zeros_(self.linear.weight)
172
+ torch.nn.init.zeros_(self.linear.bias)
173
+
174
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
175
+ # x: (batch, ..., data_dim)
176
+ # t: (batch, norm_cond_dim)
177
+ # return: (batch, data_dim)
178
+ x = self.norm(x)
179
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
180
+
181
+ # Add singleton dimensions to alpha and beta
182
+ if x.dim() > 2:
183
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
184
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
185
+
186
+ return x * (1 + alpha) + beta
187
+
188
+
189
+ def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
190
+ if norm == "batch":
191
+ return torch.nn.BatchNorm1d(dim)
192
+ elif norm == "layer":
193
+ return torch.nn.LayerNorm(dim)
194
+ elif norm == "ada":
195
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
196
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
197
+ elif norm is None:
198
+ return torch.nn.Identity()
199
+ else:
200
+ raise ValueError(f"Unknown norm: {norm}")
201
+
202
+
203
+ class PreNorm(nn.Module):
204
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
205
+ super().__init__()
206
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
207
+ self.fn = fn
208
+
209
+ def forward(self, x: torch.Tensor, *args, **kwargs):
210
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
211
+ return self.fn(self.norm(x, *args), **kwargs)
212
+ else:
213
+ return self.fn(self.norm(x), **kwargs)
214
+
215
+
216
+ # This class is from HaMeR (https://github.com/geopavlakos/hamer).
217
+ class TransformerCrossAttn(nn.Module):
218
+ def __init__(
219
+ self,
220
+ dim: int,
221
+ depth: int,
222
+ heads: int,
223
+ dim_head: int,
224
+ mlp_dim: int,
225
+ dropout: float = 0.0,
226
+ norm: str = "layer",
227
+ norm_cond_dim: int = -1,
228
+ context_dim: Optional[int] = None,
229
+ ):
230
+ super().__init__()
231
+ self.layers = nn.ModuleList([])
232
+ for _ in range(depth):
233
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
234
+ ca = CrossAttention(
235
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
236
+ )
237
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
238
+ self.layers.append(
239
+ nn.ModuleList(
240
+ [
241
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
242
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
243
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
244
+ ]
245
+ )
246
+ )
247
+
248
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
249
+ if context_list is None:
250
+ context_list = [context] * len(self.layers)
251
+ if len(context_list) != len(self.layers):
252
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
253
+
254
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
255
+ x = self_attn(x, *args) + x
256
+ x = cross_attn(x, *args, context=context_list[i]) + x
257
+ x = ff(x, *args) + x
258
+ return x
259
+
260
+
261
+ # This class is from HaMeR (https://github.com/geopavlakos/hamer).
262
+ class DropTokenDropout(nn.Module):
263
+ def __init__(self, p: float = 0.1):
264
+ super().__init__()
265
+ if p < 0 or p > 1:
266
+ raise ValueError(
267
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
268
+ )
269
+ self.p = p
270
+
271
+ def forward(self, x: torch.Tensor):
272
+ # x: (batch_size, seq_len, dim)
273
+ if self.training and self.p > 0:
274
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
275
+ # TODO: permutation idx for each batch using torch.argsort
276
+ if zero_mask.any():
277
+ x = x[:, ~zero_mask, :]
278
+ return x
279
+
280
+
281
+ # This class is from HaMeR (https://github.com/geopavlakos/hamer).
282
+ class ZeroTokenDropout(nn.Module):
283
+ def __init__(self, p: float = 0.1):
284
+ super().__init__()
285
+ if p < 0 or p > 1:
286
+ raise ValueError(
287
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
288
+ )
289
+ self.p = p
290
+
291
+ def forward(self, x: torch.Tensor):
292
+ # x: (batch_size, seq_len, dim)
293
+ if self.training and self.p > 0:
294
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
295
+ # Zero-out the masked tokens
296
+ x[zero_mask, :] = 0
297
+ return x
298
+
299
+
300
+ # This class is from HaMeR (https://github.com/geopavlakos/hamer).
301
+ class TransformerDecoder(nn.Module):
302
+ def __init__(
303
+ self,
304
+ num_tokens: int,
305
+ token_dim: int,
306
+ dim: int,
307
+ depth: int,
308
+ heads: int,
309
+ mlp_dim: int,
310
+ dim_head: int = 64,
311
+ dropout: float = 0.0,
312
+ emb_dropout: float = 0.0,
313
+ emb_dropout_type: str = 'drop',
314
+ norm: str = "layer",
315
+ norm_cond_dim: int = -1,
316
+ context_dim: Optional[int] = None,
317
+ skip_token_embedding: bool = False,
318
+ ):
319
+ super().__init__()
320
+ if not skip_token_embedding:
321
+ self.to_token_embedding = nn.Linear(token_dim, dim)
322
+ else:
323
+ self.to_token_embedding = nn.Identity()
324
+ if token_dim != dim:
325
+ raise ValueError(
326
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
327
+ )
328
+
329
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
330
+ if emb_dropout_type == "drop":
331
+ self.dropout = DropTokenDropout(emb_dropout)
332
+ elif emb_dropout_type == "zero":
333
+ self.dropout = ZeroTokenDropout(emb_dropout)
334
+ elif emb_dropout_type == "normal":
335
+ self.dropout = nn.Dropout(emb_dropout)
336
+
337
+ self.transformer = TransformerCrossAttn(
338
+ dim,
339
+ depth,
340
+ heads,
341
+ dim_head,
342
+ mlp_dim,
343
+ dropout,
344
+ norm=norm,
345
+ norm_cond_dim=norm_cond_dim,
346
+ context_dim=context_dim,
347
+ )
348
+
349
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
350
+ x = self.to_token_embedding(inp)
351
+ b, n, _ = x.shape
352
+
353
+ x = self.dropout(x)
354
+ x += self.pos_embedding[:, :n]
355
+
356
+ x = self.transformer(x, *args, context=context, context_list=context_list)
357
+ return x
358
+
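A hypothetical usage sketch of the decoder above (not part of the commit): a single input token cross-attends over a flattened (H*W, C) backbone feature map, which is how the heads further below use it.

dec = TransformerDecoder(num_tokens=1, token_dim=1, dim=1024, depth=6, heads=8,
                         mlp_dim=1024, dim_head=64, context_dim=1280)
token = torch.zeros(2, 1, 1)           # zero input token
context = torch.randn(2, 192, 1280)    # flattened 16x12 ViT feature map
out = dec(token, context=context)      # (2, 1, 1024)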
359
+
360
+ def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
361
+ """
362
+ Convert 6D rotation representation to 3x3 rotation matrix.
363
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
364
+ Args:
365
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
366
+ Returns:
367
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
368
+ """
369
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
370
+ a1 = x[:, :, 0]
371
+ a2 = x[:, :, 1]
372
+ b1 = F.normalize(a1)
373
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
374
+ b3 = torch.cross(b1, b2)
375
+ return torch.stack((b1, b2, b3), dim=-1)
376
+
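A quick sanity check for the 6-D representation above (illustrative only): the recovered matrices should be orthonormal with determinant +1.

x6d = torch.randn(4, 6)
R = rot6d_to_rotmat(x6d)                                # (4, 3, 3)
eye = torch.eye(3).expand(4, 3, 3)
assert torch.allclose(R @ R.transpose(1, 2), eye, atol=1e-5)   # orthonormal
assert torch.allclose(torch.det(R), torch.ones(4), atol=1e-5)  # proper rotation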
377
+
378
+ def aa_to_rotmat(theta: torch.Tensor):
379
+ """
380
+ Convert axis-angle representation to rotation matrix.
381
+ Works by first converting it to a quaternion.
382
+ Args:
383
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
384
+ Returns:
385
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
386
+ """
387
+ norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
388
+ angle = torch.unsqueeze(norm, -1)
389
+ normalized = torch.div(theta, angle)
390
+ angle = angle * 0.5
391
+ v_cos = torch.cos(angle)
392
+ v_sin = torch.sin(angle)
393
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
394
+ return quat_to_rotmat(quat)
395
+
396
+
397
+ class MANO(smplx.MANOLayer):
398
+ def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
399
+ """
400
+ Extension of the official MANO implementation to support more joints.
401
+ Args:
402
+ Same as MANOLayer.
403
+ joint_regressor_extra (str): Path to extra joint regressor.
404
+ """
405
+ super(MANO, self).__init__(*args, **kwargs)
406
+ mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
407
+
408
+ #2, 3, 5, 4, 1
409
+ if joint_regressor_extra is not None:
410
+ self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
411
+ self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long))
412
+ self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long))
413
+
414
+ def forward(self, *args, **kwargs) -> MANOOutput:
415
+ """
416
+ Run forward pass. Same as MANO and also append an extra set of joints if joint_regressor_extra is specified.
417
+ """
418
+ mano_output = super(MANO, self).forward(*args, **kwargs)
419
+ extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs)
420
+ joints = torch.cat([mano_output.joints, extra_joints], dim=1)
421
+ joints = joints[:, self.joint_map, :]
422
+ if hasattr(self, 'joint_regressor_extra'):
423
+ extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices)
424
+ joints = torch.cat([joints, extra_joints], dim=1)
425
+ mano_output.joints = joints
426
+ return mano_output
427
+
428
+
429
+ class MANOTransformerDecoderHead(nn.Module):
430
+ """ Cross-attention based MANO Transformer decoder
431
+ """
432
+
433
+ def __init__(self):
434
+ super().__init__()
435
+ # self.cfg = cfg
436
+ self.joint_rep_type = '6d' #cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d')
437
+ self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
438
+ npose = self.joint_rep_dim * (cfg.MODEL.hamer_mano_num_hand_joints + 1)
439
+ self.npose = npose
440
+ self.input_is_mean_shape = False #cfg.MODEL.MANO_HEAD.get('TRANSFORMER_INPUT', 'zero') == 'mean_shape'
441
+ transformer_args = dict(
442
+ num_tokens=1,
443
+ token_dim=1,
444
+ dim=1024,
445
+ )
446
+ if cfg.MODEL.backbone_type in ['resnet-50', 'resnet-101', 'resnet-152', 'hrnet-w32', 'hrnet-w48']:
447
+ context_dim = 2048
448
+ elif cfg.MODEL.backbone_type in ['vit-l-16']:
449
+ context_dim = 1024
450
+ elif cfg.MODEL.backbone_type in ['vit-b-16']:
451
+ context_dim = 768
452
+ elif cfg.MODEL.backbone_type in ['resnet-18', 'resnet-34']:
453
+ context_dim = 512
454
+ elif cfg.MODEL.backbone_type in ['vit-s-16']:
455
+ context_dim = 384
456
+ elif cfg.MODEL.backbone_type in ['handoccnet']:
457
+ context_dim = 256
458
+ else:
459
+ context_dim = 1280
460
+
461
+ # transformer_args = (transformer_args | {'context_dim': 1280, 'depth': 6, 'dim_head': 64, 'dropout': 0.0, 'emb_dropout': 0.0, 'heads': 8, 'mlp_dim': 1024, 'norm': 'layer'})
462
+ transformer_args = {**transformer_args, 'context_dim': context_dim, 'depth': 6, 'dim_head': 64, 'dropout': 0.0, 'emb_dropout': 0.0, 'heads': 8, 'mlp_dim': 1024, 'norm': 'layer'}
463
+ self.transformer = TransformerDecoder(
464
+ **transformer_args
465
+ )
466
+ dim=transformer_args['dim']
467
+ self.decpose = nn.Linear(dim, npose)
468
+ self.decshape = nn.Linear(dim, 10)
469
+ self.deccam = nn.Linear(dim, 3)
470
+
471
+ mean_params = np.load(cfg.MODEL.hamer_mano_mean_params)
472
+ init_hand_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
473
+ init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
474
+ init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
475
+ self.register_buffer('init_hand_pose', init_hand_pose)
476
+ self.register_buffer('init_betas', init_betas)
477
+ self.register_buffer('init_cam', init_cam)
478
+
479
+ def forward(self, x, **kwargs):
480
+ batch_size = x.shape[0]
481
+ # vit pretrained backbone is channel-first. Change to token-first
482
+ x = rearrange(x, 'b c h w -> b (h w) c')
483
+
484
+ init_hand_pose = self.init_hand_pose.expand(batch_size, -1)
485
+ init_betas = self.init_betas.expand(batch_size, -1)
486
+ init_cam = self.init_cam.expand(batch_size, -1)
487
+
488
+ # TODO: Convert init_hand_pose to aa rep if needed
489
+ if self.joint_rep_type == 'aa':
490
+ raise NotImplementedError
491
+
492
+ pred_hand_pose = init_hand_pose
493
+ pred_betas = init_betas
494
+ pred_cam = init_cam
495
+ pred_hand_pose_list = []
496
+ pred_betas_list = []
497
+ pred_cam_list = []
498
+
499
+ # Input token to transformer is zero token
500
+ if self.input_is_mean_shape:
501
+ token = torch.cat([pred_hand_pose, pred_betas, pred_cam], dim=1)[:,None,:]
502
+ else:
503
+ token = torch.zeros(batch_size, 1, 1).to(x.device)
504
+
505
+ # Pass through transformer
506
+ token_out = self.transformer(token, context=x)
507
+ token_out = token_out.squeeze(1) # (B, C)
508
+
509
+ # Readout from token_out
510
+ pred_hand_pose = self.decpose(token_out) + pred_hand_pose
511
+ pred_betas = self.decshape(token_out) + pred_betas
512
+ pred_cam = self.deccam(token_out) + pred_cam
513
+ pred_hand_pose_list.append(pred_hand_pose)
514
+ pred_betas_list.append(pred_betas)
515
+ pred_cam_list.append(pred_cam)
516
+
517
+ # Convert self.joint_rep_type -> rotmat
518
+ joint_conversion_fn = {
519
+ '6d': rot6d_to_rotmat,
520
+ 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
521
+ }[self.joint_rep_type]
522
+
523
+ pred_mano_params_list = {}
524
+ pred_mano_params_list['hand_pose'] = torch.cat([joint_conversion_fn(pbp).view(batch_size, -1, 3, 3)[:, 1:, :, :] for pbp in pred_hand_pose_list], dim=0)
525
+ pred_mano_params_list['betas'] = torch.cat(pred_betas_list, dim=0)
526
+ pred_mano_params_list['cam'] = torch.cat(pred_cam_list, dim=0)
527
+ pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(batch_size, cfg.MODEL.hamer_mano_num_hand_joints+1, 3, 3)
528
+
529
+ pred_mano_params = {'global_orient': pred_hand_pose[:, [0]],
530
+ 'hand_pose': pred_hand_pose[:, 1:],
531
+ 'betas': pred_betas}
532
+ return pred_mano_params, pred_cam, pred_mano_params_list
533
+
534
+
535
+ def perspective_projection(points: torch.Tensor,
536
+ translation: torch.Tensor,
537
+ focal_length: torch.Tensor,
538
+ camera_center: Optional[torch.Tensor] = None,
539
+ rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
540
+ """
541
+ Computes the perspective projection of a set of 3D points.
542
+ Args:
543
+ points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
544
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
545
+ focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
546
+ camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
547
+ rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
548
+ Returns:
549
+ torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
550
+ """
551
+ batch_size = points.shape[0]
552
+ if rotation is None:
553
+ rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
554
+ if camera_center is None:
555
+ camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
556
+ # Populate intrinsic camera matrix K.
557
+ K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
558
+ K[:,0,0] = focal_length[:,0]
559
+ K[:,1,1] = focal_length[:,1]
560
+ K[:,2,2] = 1.
561
+ K[:,:-1, -1] = camera_center
562
+
563
+ # Transform points
564
+ points = torch.einsum('bij,bkj->bki', rotation, points)
565
+ points = points + translation.unsqueeze(1)
566
+
567
+ # Apply perspective distortion
568
+ projected_points = points / points[:,:,-1].unsqueeze(-1)
569
+
570
+ # Apply camera intrinsics
571
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
572
+
573
+ return projected_points[:, :, :-1]
574
+
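A small worked example of the pinhole model above (illustrative only): a point on the optical axis projects exactly onto the principal point.

points = torch.zeros(1, 1, 3)                      # one point at the camera-frame origin
translation = torch.tensor([[0.0, 0.0, 5.0]])      # moved 5 units in front of the camera
focal = torch.tensor([[1000.0, 1000.0]])
center = torch.tensor([[112.0, 112.0]])
uv = perspective_projection(points, translation, focal, camera_center=center)
# uv -> tensor([[[112., 112.]]]) since the point lies on the optical axis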
575
+
576
+ # This module is modified from MANOTransformerDecoderHead of HaMeR (https://github.com/geopavlakos/hamer). All cfg are directly initialized.
577
+ class ContactTransformerDecoderHead(nn.Module):
578
+ """ Cross-attention based MANO Transformer decoder
579
+ """
580
+ def __init__(self):
581
+ super().__init__()
582
+ transformer_args = dict(
583
+ num_tokens=1,
584
+ token_dim=1,
585
+ dim=1024,
586
+ )
587
+ if cfg.MODEL.backbone_type in ['resnet-50', 'resnet-101', 'resnet-152', 'hrnet-w32', 'hrnet-w48']:
588
+ context_dim = 2048
589
+ elif cfg.MODEL.backbone_type in ['vit-l-16']:
590
+ context_dim = 1024
591
+ elif cfg.MODEL.backbone_type in ['vit-b-16']:
592
+ context_dim = 768
593
+ elif cfg.MODEL.backbone_type in ['resnet-18', 'resnet-34']:
594
+ context_dim = 512
595
+ elif cfg.MODEL.backbone_type in ['vit-s-16']:
596
+ context_dim = 384
597
+ elif cfg.MODEL.backbone_type in ['handoccnet']:
598
+ context_dim = 256
599
+ else:
600
+ context_dim = 1280
601
+ MANO_HEAD_TRANSFORMER_DECODER_CONFIG = {'depth': 6, 'heads': 8, 'mlp_dim': 1024, 'dim_head': 64, 'dropout': 0.0, 'emb_dropout': 0.0, 'norm': 'layer', 'context_dim': context_dim}
602
+ transformer_args.update(dict(MANO_HEAD_TRANSFORMER_DECODER_CONFIG))
603
+ self.transformer = TransformerDecoder(
604
+ **transformer_args
605
+ )
606
+ self.deccontact = nn.Linear(1024, 778)
607
+
608
+ CONTACT_MEAN_DIR = cfg.MODEL.contact_means_path # TODO: REPLACE THIS WITH CONTACT MEAN OF ENTIRE DATASETS
609
+ init_contact = nn.Parameter(torch.randn(1, 778, requires_grad=True))
610
+ self.register_buffer('init_contact', init_contact)
611
+
612
+ def forward(self, x, **kwargs): # x: [b, 1280, 16, 12] (if resnet-50, x: [b, 2048, 8, 8], resnet-34: [b, 512, 8, 8], hrnet-w32: [b, 2048, 8, 8])
613
+ batch_size = x.shape[0]
614
+ device = x.device
615
+
616
+ # vit pretrained backbone is channel-first. Change to token-first
617
+ x = rearrange(x, 'b c h w -> b (h w) c')
618
+
619
+ init_contact = self.init_contact.expand(batch_size, -1)
620
+ pred_contact = init_contact
621
+
622
+ token = torch.zeros(batch_size, 1, 1).to(x.device)
623
+
624
+ # Pass through transformer
625
+ token_out = self.transformer(token, context=x) # x: [b, 192, 1280]
626
+ token_out = token_out[:, 0] # (B, C)
627
+
628
+ # Readout from token_out
629
+ pred_contact = self.deccontact(token_out) + pred_contact
630
+ # pred_contact = pred_contact.sigmoid()
631
+
632
+ # Joint contact
633
+ pred_joint_contact = (torch.tensor(mano.joint_regressor, dtype=torch.float32, device=device) @ pred_contact.T).T
634
+ pred_mesh_contact_336 = (torch.tensor(V_regressor_336, dtype=torch.float32, device=device) @ pred_contact.T).T
635
+ pred_mesh_contact_84 = (torch.tensor(V_regressor_84, dtype=torch.float32, device=device) @ pred_contact.T).T
636
+
637
+ return pred_contact, pred_mesh_contact_336, pred_mesh_contact_84, pred_joint_contact
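A minimal forward-pass sketch for the contact head above, assuming the repository config (cfg), the MANO wrapper (mano) and the V_regressor_* matrices have already been initialized by the repo:

import torch
head = ContactTransformerDecoderHead()
feat = torch.randn(2, 1280, 16, 12)   # ViT-style backbone feature map
contact, contact_336, contact_84, contact_joint = head(feat)
# contact: (2, 778) per-vertex logits; contact_336 / contact_84: coarser mesh resolutions;
# contact_joint: (2, 21) per-joint contact obtained through the MANO joint regressor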
lib/models/model.py ADDED
@@ -0,0 +1,100 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from lib.core.config import cfg
6
+
7
+
8
+
9
+ class HACO(nn.Module):
10
+ def __init__(self):
11
+ super(HACO, self).__init__()
12
+ # Select the device once; fall back to CPU when CUDA is unavailable
13
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ self.to(self.device)
15
+
16
+ # Load modules
17
+ self.backbone = get_backbone_network(type=cfg.MODEL.backbone_type)
18
+ self.decoder = get_decoder_network(type=cfg.MODEL.backbone_type)
19
+
20
+ def forward(self, inputs, mode='test'):
21
+ image = inputs['input']['image'].to(self.device)
22
+
23
+ if 'vit' in cfg.MODEL.backbone_type:
24
+ image = F.interpolate(image, size=(224, 224), mode='bilinear', align_corners=False)
25
+
26
+ img_feat = self.backbone(image)
27
+ contact_out, contact_336_out, contact_84_out, contact_joint_out = self.decoder(img_feat)
28
+
29
+ return dict(contact_out=contact_out, contact_336_out=contact_336_out, contact_84_out=contact_84_out, contact_joint_out=contact_joint_out)
30
+
31
+
32
+
33
+ def get_backbone_network(type='hamer'):
34
+ if type in ['hamer']:
35
+ from lib.models.backbone.backbone_hamer_style import ViT_HaMeR
36
+ backbone = ViT_HaMeR()
37
+ checkpoint = torch.load(cfg.MODEL.hamer_backbone_pretrained_path, map_location='cuda')['state_dict']
38
+ filtered_state_dict = {k[len("backbone."):]: v for k, v in checkpoint.items() if k.startswith("backbone.")}
39
+ backbone.load_state_dict(filtered_state_dict)
40
+ elif type in ['resnet-18']:
41
+ from lib.models.backbone.resnet import ResNetBackbone
42
+ backbone = ResNetBackbone(18) # ResNet
43
+ backbone.init_weights()
44
+ elif type in ['resnet-34']:
45
+ from lib.models.backbone.resnet import ResNetBackbone
46
+ backbone = ResNetBackbone(34) # ResNet
47
+ backbone.init_weights()
48
+ elif type in ['resnet-50']:
49
+ from lib.models.backbone.resnet import ResNetBackbone
50
+ backbone = ResNetBackbone(50) # ResNet
51
+ backbone.init_weights()
52
+ elif type in ['resnet-101']:
53
+ from lib.models.backbone.resnet import ResNetBackbone
54
+ backbone = ResNetBackbone(101) # ResNet
55
+ backbone.init_weights()
56
+ elif type in ['resnet-152']:
57
+ from lib.models.backbone.resnet import ResNetBackbone
58
+ backbone = ResNetBackbone(152) # ResNet
59
+ backbone.init_weights()
60
+ elif type in ['hrnet-w32']:
61
+ from lib.models.backbone.hrnet import HighResolutionNet
62
+ from lib.utils.func_utils import load_config
63
+ config = load_config(cfg.MODEL.hrnet_w32_backbone_config_path)
64
+ pretrained = cfg.MODEL.hrnet_w32_backbone_pretrained_path
65
+ backbone = HighResolutionNet(config)
66
+ backbone.init_weights(pretrained=pretrained)
67
+ elif type in ['hrnet-w48']:
68
+ from lib.models.backbone.hrnet import HighResolutionNet
69
+ from lib.utils.func_utils import load_config
70
+ config = load_config(cfg.MODEL.hrnet_w48_backbone_config_path)
71
+ pretrained = cfg.MODEL.hrnet_w48_backbone_pretrained_path
72
+ backbone = HighResolutionNet(config)
73
+ backbone.init_weights(pretrained=pretrained)
74
+ elif type in ['handoccnet']:
75
+ from lib.models.backbone.fpn import FPN
76
+ backbone = FPN(pretrained=False)
77
+ pretrained = cfg.MODEL.handoccnet_backbone_pretrained_path
78
+ state_dict = {k[len('module.backbone.'):]: v for k, v in torch.load(pretrained)['network'].items() if k.startswith('module.backbone.')}
79
+ backbone.load_state_dict(state_dict, strict=True)
80
+ elif type in ['vit-s-16']:
81
+ from lib.models.backbone.vit import ViTBackbone
82
+ backbone = ViTBackbone(model_name='vit_small_patch16_224', pretrained=True)
83
+ elif type in ['vit-b-16']:
84
+ from lib.models.backbone.vit import ViTBackbone
85
+ backbone = ViTBackbone(model_name='vit_base_patch16_224', pretrained=True)
86
+ elif type in ['vit-l-16']:
87
+ from lib.models.backbone.vit import ViTBackbone
88
+ backbone = ViTBackbone(model_name='vit_large_patch16_224', pretrained=True)
89
+ else:
90
+ raise NotImplementedError
91
+
92
+ return backbone
93
+
94
+
95
+
96
+ def get_decoder_network(type='hamer'):
97
+ from lib.models.decoder.decoder_hamer_style import ContactTransformerDecoderHead
98
+ decoder = ContactTransformerDecoderHead()
99
+
100
+ return decoder
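A rough inference sketch for the model above, assuming the cfg.MODEL.* paths and pretrained backbone weights are available; the crop size follows cfg.MODEL.input_img_shape:

import torch
from lib.core.config import cfg
model = HACO().eval()
inputs = {'input': {'image': torch.randn(1, 3, *cfg.MODEL.input_img_shape)}}
with torch.no_grad():
    out = model(inputs, mode='test')
contact_prob = out['contact_out'].sigmoid()   # (1, 778) per-vertex contact probability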
lib/utils/__pycache__/contact_utils.cpython-38.pyc ADDED
Binary file (1.3 kB). View file
 
lib/utils/__pycache__/eval_utils.cpython-38.pyc ADDED
Binary file (1.17 kB). View file
 
lib/utils/__pycache__/func_utils.cpython-38.pyc ADDED
Binary file (2.14 kB). View file
 
lib/utils/__pycache__/human_models.cpython-38.pyc ADDED
Binary file (3.61 kB). View file
 
lib/utils/__pycache__/log_utils.cpython-38.pyc ADDED
Binary file (471 Bytes). View file
 
lib/utils/__pycache__/mano_utils.cpython-38.pyc ADDED
Binary file (4.35 kB). View file
 
lib/utils/__pycache__/mesh_utils.cpython-38.pyc ADDED
Binary file (2.01 kB). View file
 
lib/utils/__pycache__/preprocessing.cpython-38.pyc ADDED
Binary file (7.88 kB). View file
 
lib/utils/__pycache__/train_utils.cpython-38.pyc ADDED
Binary file (679 Bytes). View file
 
lib/utils/__pycache__/transforms.cpython-38.pyc ADDED
Binary file (640 Bytes). View file
 
lib/utils/__pycache__/vis_utils.cpython-38.pyc ADDED
Binary file (6.27 kB). View file
 
lib/utils/contact_utils.py ADDED
@@ -0,0 +1,55 @@
1
+ import gc
2
+ import torch
3
+ import numpy as np
4
+ from trimesh.proximity import ProximityQuery
5
+
6
+ from lib.utils.human_models import mano
7
+
8
+
9
+ def get_ho_contact_and_offset(mesh_hand, mesh_obj, c_thres):
10
+ # Make sure that meshes are watertight and do not contain inverted faces
11
+ # Typically canonical space meshes are more stable
12
+
13
+ pq = ProximityQuery(mesh_obj)
14
+ obj_coord_c, dist, obj_coord_c_idx = pq.on_surface(mesh_hand.vertices.astype(np.float32))
15
+
16
+ is_contact_h = (dist < c_thres)
17
+ contact_h = (1. * is_contact_h).astype(np.float32)
18
+
19
+ contact_valid = np.ones((mano.vertex_num, 1))
20
+ inter_coord_valid = np.ones((mano.vertex_num))
21
+
22
+ # Explicit cleanup
23
+ del pq
24
+ gc.collect()
25
+
26
+ return np.array(contact_h), np.array(obj_coord_c), contact_valid, inter_coord_valid
27
+
28
+
29
+ def get_contact_thres(backbone_type='hamer'):
30
+ if backbone_type == 'hamer':
31
+ return 0.5
32
+ elif backbone_type == 'vit-l-16':
33
+ return 0.55
34
+ elif backbone_type == 'vit-b-16':
35
+ return 0.5
36
+ elif backbone_type == 'vit-s-16':
37
+ return 0.5
38
+ elif backbone_type == 'handoccnet':
39
+ return 0.95
40
+ elif backbone_type == 'hrnet-w48':
41
+ return 0.5
42
+ elif backbone_type == 'hrnet-w32':
43
+ return 0.5
44
+ elif backbone_type == 'resnet-152':
45
+ return 0.55
46
+ elif backbone_type == 'resnet-101':
47
+ return 0.5
48
+ elif backbone_type == 'resnet-50':
49
+ return 0.5
50
+ elif backbone_type == 'resnet-34':
51
+ return 0.5
52
+ elif backbone_type == 'resnet-18':
53
+ return 0.5
54
+ else:
55
+ raise NotImplementedError
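An illustrative use of get_ho_contact_and_offset above (hypothetical meshes, and assuming the MANO files required by lib.utils.human_models are present): two nearly touching spheres stand in for the hand and object meshes.

import trimesh
hand_like = trimesh.creation.icosphere(radius=0.05)
obj_like = trimesh.creation.icosphere(radius=0.05)
obj_like.apply_translation([0.0, 0.0, 0.095])   # closest surfaces roughly 5 mm apart
contact_h, obj_coord_c, contact_valid, inter_valid = get_ho_contact_and_offset(hand_like, obj_like, c_thres=0.01)
# contact_h marks vertices of hand_like lying within c_thres of obj_like's surface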
lib/utils/demo_utils.py ADDED
@@ -0,0 +1,105 @@
1
+ import cv2
2
+ import numpy as np
3
+ from collections import defaultdict, deque
4
+
5
+ import mediapipe as mp
6
+
7
+
8
+ from lib.utils.vis_utils import draw_landmarks_on_image
9
+
10
+
11
+ def smooth_bbox(prev_bbox, curr_bbox, alpha=0.8):
12
+ if prev_bbox is None:
13
+ return curr_bbox
14
+ return [alpha * p + (1 - alpha) * c for p, c in zip(prev_bbox, curr_bbox)]
15
+
16
+
17
+ def smooth_contact_mask(prev_mask, curr_mask, alpha=0.8):
18
+ if prev_mask is None:
19
+ return curr_mask.astype(np.float32)
20
+ return alpha * prev_mask + (1 - alpha) * curr_mask.astype(np.float32)
21
+
22
+
23
+ def remove_small_contact_components(contact_mask, faces, min_size=20):
24
+ vertex_to_faces = defaultdict(list)
25
+ for i, f in enumerate(faces):
26
+ for v in f:
27
+ vertex_to_faces[v].append(i)
28
+
29
+ visited = np.zeros(len(contact_mask), dtype=bool)
30
+ filtered_mask = np.zeros_like(contact_mask, dtype=bool)
31
+
32
+ for v in range(len(contact_mask)):
33
+ if visited[v] or not contact_mask[v]:
34
+ continue
35
+
36
+ queue = deque([v])
37
+ component = []
38
+ while queue:
39
+ curr = queue.popleft()
40
+ if visited[curr] or not contact_mask[curr]:
41
+ continue
42
+ visited[curr] = True
43
+ component.append(curr)
44
+ for f_idx in vertex_to_faces[curr]:
45
+ for neighbor in faces[f_idx]:
46
+ if not visited[neighbor] and contact_mask[neighbor]:
47
+ queue.append(neighbor)
48
+
49
+ if len(component) >= min_size:
50
+ filtered_mask[component] = True
51
+
52
+ return filtered_mask
53
+
54
+
55
+ def initialize_video_writer(output_path, fps, frame_size):
56
+ tried_codecs = ['avc1', 'H264', 'X264', 'MJPG', 'mp4v'] # we recommend using 'MJPG'
57
+ for codec in tried_codecs:
58
+ fourcc = cv2.VideoWriter_fourcc(*codec)
59
+ writer = cv2.VideoWriter(output_path, fourcc, fps, frame_size)
60
+ if writer.isOpened():
61
+ print(f"Using codec '{codec}' for {output_path}")
62
+ return writer
63
+ writer.release()
64
+ raise RuntimeError(f"Failed to initialize VideoWriter for {output_path}")
65
+
66
+
67
+ def extract_frames_with_hand(cap, detector):
68
+ frames_with_hand = []
69
+ frame_idx = 0
70
+
71
+ while cap.isOpened():
72
+ ret, frame = cap.read()
73
+ if not ret:
74
+ break
75
+
76
+ orig_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
77
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=orig_img)
78
+ detection_result = detector.detect(mp_image)
79
+ _, right_hand_bbox = draw_landmarks_on_image(orig_img.copy(), detection_result)
80
+
81
+ if right_hand_bbox is not None:
82
+ frames_with_hand.append((frame_idx, frame, right_hand_bbox))
83
+
84
+ frame_idx += 1
85
+
86
+ cap.release()
87
+ return frames_with_hand
88
+
89
+
90
+ def find_longest_continuous_segment(frames_with_hand):
91
+ longest_segment = []
92
+ current_segment = []
93
+
94
+ for i in range(len(frames_with_hand)):
95
+ if i == 0 or frames_with_hand[i][0] == frames_with_hand[i - 1][0] + 1:
96
+ current_segment.append(frames_with_hand[i])
97
+ else:
98
+ if len(current_segment) > len(longest_segment):
99
+ longest_segment = current_segment
100
+ current_segment = [frames_with_hand[i]]
101
+
102
+ if len(current_segment) > len(longest_segment):
103
+ longest_segment = current_segment
104
+
105
+ return longest_segment
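A tiny worked example for the segment search above (hypothetical frame indices):

frames = [(0, 'f0', 'bb'), (2, 'f2', 'bb'), (3, 'f3', 'bb'), (4, 'f4', 'bb'), (5, 'f5', 'bb')]
segment = find_longest_continuous_segment(frames)
# segment keeps the run of consecutive frame indices 2, 3, 4, 5 (length 4 beats the lone frame 0)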
lib/utils/eval_utils.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def evaluation(outputs, targets_data, meta_info, mode='val', thres=0.5):
6
+ eval_out = {}
7
+
8
+ # GT
9
+ mesh_valid = meta_info['mano_valid'] is not None
10
+
11
+ # Pred
12
+ contact_pred = outputs['contact_out'].sigmoid()[0].detach().cpu().numpy()
13
+
14
+ # Error Calculate
15
+ if mesh_valid:
16
+ # Contact Metrics
17
+ cont_pre, cont_rec, cont_f1 = compute_contact_metrics(targets_data['contact_data']['contact_h'][0].detach().cpu().numpy(), outputs['contact_out'][0].detach().cpu().numpy(), mesh_valid, thres=thres)
18
+ eval_out['cont_pre'] = cont_pre
19
+ eval_out['cont_rec'] = cont_rec
20
+ eval_out['cont_f1'] = cont_f1
21
+
22
+ return eval_out
23
+
24
+
25
+ def compute_contact_metrics(gt, pred, valid, thres=0.5):
26
+ """
27
+ Compute precision, recall, and f1 using NumPy
28
+ """
29
+ if valid:
30
+ # True Positives
31
+ tp_num = np.sum(gt[pred >= thres])
32
+
33
+ # Denominators for precision and recall
34
+ precision_denominator = np.sum(pred >= thres)
35
+ recall_denominator = np.sum(gt)
36
+
37
+ # Compute precision, recall, and F1 score
38
+ precision_ = tp_num / precision_denominator if precision_denominator > 0 else None
39
+ recall_ = tp_num / recall_denominator if recall_denominator > 0 else None
40
+ if precision_ is not None and recall_ is not None and (precision_ + recall_) > 0:
41
+ f1_ = 2 * precision_ * recall_ / (precision_ + recall_)
42
+ else:
43
+ f1_ = None
44
+ else:
45
+ # If not valid, return None for metrics
46
+ precision_ = None
47
+ recall_ = None
48
+ f1_ = None
49
+
50
+ return precision_, recall_, f1_
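A worked example of the contact metrics above (hypothetical labels): two of the three vertices predicted above threshold are true contacts.

import numpy as np
gt = np.array([1., 1., 0., 1., 0.])
pred = np.array([0.9, 0.8, 0.7, 0.2, 0.1])
p, r, f1 = compute_contact_metrics(gt, pred, valid=True, thres=0.5)
# precision = 2/3, recall = 2/3, F1 = 2/3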
lib/utils/func_utils.py ADDED
@@ -0,0 +1,65 @@
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+
5
+
6
+ def load_img(path, order='RGB'):
7
+ img = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
8
+ if not isinstance(img, np.ndarray):
9
+ raise IOError("Fail to read %s" % path)
10
+
11
+ if order=='RGB': img = img[:,:,::-1]
12
+ img = img.astype(np.float32)
13
+ return img
14
+
15
+
16
+ def get_bbox(joint_img, joint_valid, expansion_factor=1.0):
17
+ x_img, y_img = joint_img[:,0], joint_img[:,1]
18
+ x_img = x_img[joint_valid==1]; y_img = y_img[joint_valid==1];
19
+ xmin = min(x_img); ymin = min(y_img); xmax = max(x_img); ymax = max(y_img);
20
+
21
+ x_center = (xmin+xmax)/2.; width = (xmax-xmin)*expansion_factor;
22
+ xmin = x_center - 0.5*width
23
+ xmax = x_center + 0.5*width
24
+
25
+ y_center = (ymin+ymax)/2.; height = (ymax-ymin)*expansion_factor;
26
+ ymin = y_center - 0.5*height
27
+ ymax = y_center + 0.5*height
28
+
29
+ bbox = np.array([xmin, ymin, xmax - xmin, ymax - ymin]).astype(np.float32)
30
+ return bbox
31
+
32
+
33
+ def process_bbox(bbox, target_shape, original_img_shape):
34
+
35
+ # aspect ratio preserving bbox
36
+ w = bbox[2]
37
+ h = bbox[3]
38
+ c_x = bbox[0] + w/2.
39
+ c_y = bbox[1] + h/2.
40
+ aspect_ratio = target_shape[1]/target_shape[0]
41
+ if w > aspect_ratio * h:
42
+ h = w / aspect_ratio
43
+ elif w < aspect_ratio * h:
44
+ w = h * aspect_ratio
45
+ bbox[2] = w*1.25
46
+ bbox[3] = h*1.25
47
+ bbox[0] = c_x - bbox[2]/2.
48
+ bbox[1] = c_y - bbox[3]/2.
49
+
50
+ return bbox
51
+
52
+
53
+ import re
54
+ def atoi(text):
55
+ return int(text) if text.isdigit() else text
56
+ def natural_keys(text):
57
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
58
+
59
+
60
+ # Load config
61
+ import yaml
62
+ def load_config(cfg_path):
63
+ with open(cfg_path, 'r') as f:
64
+ cfg = yaml.safe_load(f)
65
+ return cfg
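A worked example for process_bbox above (hypothetical numbers): the box is first made to match the target aspect ratio, then both sides are enlarged by 1.25 around the original center.

import numpy as np
bbox = np.array([10., 20., 100., 50.])   # x, y, w, h
out = process_bbox(bbox, target_shape=(256, 256), original_img_shape=(480, 640))
# square target -> h is raised to 100, then w = h = 125, re-centered on (60, 45)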
lib/utils/human_models.py ADDED
@@ -0,0 +1,49 @@
1
+ import numpy as np
2
+ import torch
3
+ import os.path as osp
4
+ import pickle
5
+
6
+ from lib.core.config import cfg
7
+ from lib.utils.transforms import transform_joint_to_other_db
8
+ from lib.utils.smplx import smplx
9
+
10
+
11
+
12
+ class MANO(object):
13
+ def __init__(self):
14
+ self.layer_arg = {'create_global_orient': False, 'create_hand_pose': False, 'create_betas': False, 'create_transl': False}
15
+ self.layer = {'right': smplx.create(cfg.MODEL.human_model_path, 'mano', is_rhand=True, use_pca=False, flat_hand_mean=False, **self.layer_arg), 'left': smplx.create(cfg.MODEL.human_model_path, 'mano', is_rhand=False, use_pca=False, flat_hand_mean=False, **self.layer_arg)}
16
+ self.vertex_num = 778
17
+ self.face = {'right': self.layer['right'].faces, 'left': self.layer['left'].faces}
18
+ self.add_watertight_face = {'right': np.array([[92,38,122], [234,92,122], [239,234,122], [279,239,122], [215,279,122], [215,122,118], [215,118,117], [215,117,119], [215,119,120], [215,120,108], [215,108,79], [215,79,78], [215,78,121], [214,215,121]])}
19
+ self.watertight_face = {'right': np.concatenate((self.layer['right'].faces, self.add_watertight_face['right']), axis=0)}
20
+ self.shape_param_dim = 10
21
+
22
+ if torch.sum(torch.abs(self.layer['left'].shapedirs[:,0,:] - self.layer['right'].shapedirs[:,0,:])) < 1:
23
+ print('Fix shapedirs bug of MANO')
24
+ self.layer['left'].shapedirs[:,0,:] *= -1
25
+
26
+ # original MANO joint set
27
+ self.orig_joint_num = 16
28
+ self.orig_joints_name = ('Wrist', 'Index_1', 'Index_2', 'Index_3', 'Middle_1', 'Middle_2', 'Middle_3', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Ring_1', 'Ring_2', 'Ring_3', 'Thumb_1', 'Thumb_2', 'Thumb_3')
29
+ self.orig_root_joint_idx = self.orig_joints_name.index('Wrist')
30
+ self.orig_flip_pairs = ()
31
+ self.orig_joint_regressor = self.layer['right'].J_regressor.numpy() # same for the right and left hands
32
+
33
+ # changed MANO joint set
34
+ self.joint_num = 21 # manually added fingertips
35
+ self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
36
+ self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) )
37
+ self.root_joint_idx = self.joints_name.index('Wrist')
38
+ self.flip_pairs = ()
39
+ # add fingertips to joint_regressor
40
+ self.joint_regressor = transform_joint_to_other_db(self.orig_joint_regressor, self.orig_joints_name, self.joints_name)
41
+ self.joint_regressor[self.joints_name.index('Thumb_4')] = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
42
+ self.joint_regressor[self.joints_name.index('Index_4')] = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
43
+ self.joint_regressor[self.joints_name.index('Middle_4')] = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
44
+ self.joint_regressor[self.joints_name.index('Ring_4')] = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
45
+ self.joint_regressor[self.joints_name.index('Pinky_4')] = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
46
+
47
+
48
+
49
+ mano = MANO()
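A short sketch of how the regressor above is used (placeholder vertices): the 21-joint skeleton, fingertips included, is a linear map of the 778 mesh vertices.

import numpy as np
verts = np.zeros((mano.vertex_num, 3))     # placeholder (778, 3) right-hand mesh
joints = mano.joint_regressor @ verts      # (21, 3); e.g. 'Thumb_4' selects vertex 745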
lib/utils/log_utils.py ADDED
@@ -0,0 +1,12 @@
1
+ import os
2
+ import shutil
3
+
4
+ from datetime import datetime
5
+ from pytz import timezone
6
+
7
+
8
+ def init_dirs(dir_list):
9
+ for dir in dir_list:
10
+ if os.path.exists(dir) and os.path.isdir(dir):
11
+ shutil.rmtree(dir)
12
+ os.makedirs(dir)
lib/utils/mano_utils.py ADDED
@@ -0,0 +1,136 @@
1
+ '''
2
+ Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved.
3
+ This software is provided for research purposes only.
4
+ By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license
5
+
6
+ More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de.
7
+ For comments or questions, please email us at: mano@tue.mpg.de
8
+
9
+
10
+ About this file:
11
+ ================
12
+ This file defines a wrapper for the loading functions of the MANO model.
13
+
14
+ Modules included:
15
+ - load_model:
16
+ loads the MANO model from a given file location (i.e. a .pkl file location),
17
+ or a dictionary object.
18
+
19
+ '''
20
+ import os
21
+ import cv2
22
+ import torch
23
+ import numpy as np
24
+ import pickle
25
+ import chumpy as ch
26
+ from chumpy.ch import MatVecMult
27
+
28
+
29
+ class Rodrigues(ch.Ch):
30
+ dterms = 'rt'
31
+
32
+ def compute_r(self):
33
+ return cv2.Rodrigues(self.rt.r)[0]
34
+
35
+ def compute_dr_wrt(self, wrt):
36
+ if wrt is self.rt:
37
+ return cv2.Rodrigues(self.rt.r)[1].T
38
+
39
+
40
+ def lrotmin(p):
41
+ if isinstance(p, np.ndarray):
42
+ p = p.ravel()[3:]
43
+ return np.concatenate(
44
+ [(cv2.Rodrigues(np.array(pp))[0] - np.eye(3)).ravel()
45
+ for pp in p.reshape((-1, 3))]).ravel()
46
+ if p.ndim != 2 or p.shape[1] != 3:
47
+ p = p.reshape((-1, 3))
48
+ p = p[1:]
49
+ return ch.concatenate([(Rodrigues(pp) - ch.eye(3)).ravel()
50
+ for pp in p]).ravel()
51
+
52
+
53
+ def posemap(s):
54
+ if s == 'lrotmin':
55
+ return lrotmin
56
+ else:
57
+ raise Exception('Unknown posemapping: %s' % (str(s), ))
58
+
59
+
60
+ def ready_arguments(fname_or_dict, posekey4vposed='pose'):
61
+ if not isinstance(fname_or_dict, dict):
62
+ dd = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')
63
+ else:
64
+ dd = fname_or_dict
65
+
66
+ want_shapemodel = 'shapedirs' in dd
67
+ nposeparms = dd['kintree_table'].shape[1] * 3
68
+
69
+ if 'trans' not in dd:
70
+ dd['trans'] = np.zeros(3)
71
+ if 'pose' not in dd:
72
+ dd['pose'] = np.zeros(nposeparms)
73
+ if 'shapedirs' in dd and 'betas' not in dd:
74
+ dd['betas'] = np.zeros(dd['shapedirs'].shape[-1])
75
+
76
+ for s in [
77
+ 'v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs',
78
+ 'betas', 'J'
79
+ ]:
80
+ if (s in dd) and not hasattr(dd[s], 'dterms'):
81
+ dd[s] = ch.array(dd[s])
82
+
83
+ assert (posekey4vposed in dd)
84
+ if want_shapemodel:
85
+ dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template']
86
+ v_shaped = dd['v_shaped']
87
+ J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0])
88
+ J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1])
89
+ J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2])
90
+ dd['J'] = ch.vstack((J_tmpx, J_tmpy, J_tmpz)).T
91
+ pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed])
92
+ dd['v_posed'] = v_shaped + dd['posedirs'].dot(pose_map_res)
93
+ else:
94
+ pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed])
95
+ dd_add = dd['posedirs'].dot(pose_map_res)
96
+ dd['v_posed'] = dd['v_template'] + dd_add
97
+
98
+ return dd
99
+
100
+
101
+
102
+ def get_mano_pca_basis(ncomps=45, use_pca=True, side='right', mano_root='data/base_data/human_models/mano'):
103
+ if use_pca:
104
+ ncomps = ncomps
105
+ else:
106
+ ncomps = 45
107
+
108
+ if side == 'right':
109
+ mano_path = os.path.join(mano_root, 'MANO_RIGHT.pkl')
110
+ elif side == 'left':
111
+ mano_path = os.path.join(mano_root, 'MANO_LEFT.pkl')
112
+ smpl_data = ready_arguments(mano_path)
113
+ hands_components = smpl_data['hands_components']
114
+ selected_components = hands_components[:ncomps]
115
+ th_selected_comps = selected_components
116
+
117
+ return torch.tensor(th_selected_comps, dtype=torch.float32)
118
+
119
+
120
+
121
+ def change_flat_hand_mean(hand_pose, remove=True, side='right', mano_root='data/base_data/human_models/mano'):
122
+ if side == 'right':
123
+ mano_path = os.path.join(mano_root, 'MANO_RIGHT.pkl')
124
+ elif side == 'left':
125
+ mano_path = os.path.join(mano_root, 'MANO_LEFT.pkl')
126
+ smpl_data = ready_arguments(mano_path)
127
+
128
+ # Get hand mean
129
+ hands_mean = smpl_data['hands_mean']
130
+ hands_mean = hands_mean.copy() # hands_mean: (45)
131
+
132
+ if remove:
133
+ hand_pose[3:] = hand_pose[3:] - hands_mean
134
+ else:
135
+ hand_pose[3:] = hand_pose[3:] + hands_mean
136
+ return hand_pose
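A round-trip sketch for change_flat_hand_mean above, assuming the MANO .pkl files are available under data/base_data/human_models/mano:

import numpy as np
pose = np.random.randn(48).astype(np.float32)                         # 3 global orient + 45 hand pose
flat = change_flat_hand_mean(pose.copy(), remove=True, side='right')  # subtract hands_mean
back = change_flat_hand_mean(flat, remove=False, side='right')        # add it back
assert np.allclose(back, pose, atol=1e-5)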
lib/utils/mesh_utils.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch
2
+ import trimesh
3
+ import numpy as np
4
+ from plyfile import PlyData, PlyElement
5
+
6
+
7
+
8
+ def center_vertices(vertices, faces, flip_y=True): # This is for MOW dataset
9
+ """Centroid-align vertices."""
10
+ vertices = vertices - np.mean(vertices, axis=0, keepdims=True)
11
+ if flip_y:
12
+ vertices[:, 1] *= -1
13
+ faces = faces[:, [2, 1, 0]]
14
+ return vertices, faces
15
+
16
+
17
+
18
+ def load_obj_nr(filename_obj, normalization=True, texture_size=4, load_texture=False, # load_obj function from neural_renderer (https://github.com/daniilidis-group/neural_renderer) and MOW (https://github.com/ZheC/MOW)
19
+ texture_wrapping='REPEAT', use_bilinear=True):
20
+ """
21
+ Load Wavefront .obj file.
22
+ This function only supports vertices (v x x x) and faces (f x x x).
23
+ """
24
+
25
+ # load vertices
26
+ vertices = []
27
+ with open(filename_obj) as f:
28
+ lines = f.readlines()
29
+
30
+ for line in lines:
31
+ if len(line.split()) == 0:
32
+ continue
33
+ if line.split()[0] == 'v':
34
+ vertices.append([float(v) for v in line.split()[1:4]])
35
+ vertices = torch.from_numpy(np.vstack(vertices).astype(np.float32))
36
+
37
+ # load faces
38
+ faces = []
39
+ for line in lines:
40
+ if len(line.split()) == 0:
41
+ continue
42
+ if line.split()[0] == 'f':
43
+ vs = line.split()[1:]
44
+ nv = len(vs)
45
+ v0 = int(vs[0].split('/')[0])
46
+ for i in range(nv - 2):
47
+ v1 = int(vs[i + 1].split('/')[0])
48
+ v2 = int(vs[i + 2].split('/')[0])
49
+ faces.append((v0, v1, v2))
50
+ faces = torch.from_numpy(np.vstack(faces).astype(np.int32)) - 1
51
+
52
+ # load textures
53
+ textures = None
54
+ if load_texture:
55
+ for line in lines:
56
+ if line.startswith('mtllib'):
57
+ filename_mtl = os.path.join(os.path.dirname(filename_obj), line.split()[1])
58
+ textures = load_textures(filename_obj, filename_mtl, texture_size,
59
+ texture_wrapping=texture_wrapping,
60
+ use_bilinear=use_bilinear)
61
+ if textures is None:
62
+ raise Exception('Failed to load textures.')
63
+
64
+ # normalize into a unit cube centered zero
65
+ if normalization:
66
+ vertices -= vertices.min(0)[0][None, :]
67
+ vertices /= torch.abs(vertices).max()
68
+ vertices *= 2
69
+ vertices -= vertices.max(0)[0][None, :] / 2
70
+
71
+ if load_texture:
72
+ return vertices, faces, textures
73
+ else:
74
+ return vertices, faces
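A usage sketch for the loader above (hypothetical file path): with normalization=True the vertices come back centered near zero and scaled to a roughly unit-sized cube.

verts, faces = load_obj_nr('object.obj', normalization=True)
# verts: float32 tensor (V, 3); faces: int32 tensor (F, 3), zero-indexed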
lib/utils/preprocessing.py ADDED
@@ -0,0 +1,330 @@
1
+ import cv2
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+
7
+ from lib.core.config import cfg
8
+ from lib.utils.human_models import mano
9
+
10
+
11
+ def get_aug_config_contact():
12
+ # Augmentation intensity factors
13
+ scale_factor = 0.25
14
+ rot_factor = 30
15
+ color_factor = 0.2
16
+ trans_factor = 0.1 # Translation range (recommended 0.1 to 0.2)
17
+ noise_std = 0.02 # Gaussian noise strength
18
+ motion_blur_prob = 0.15 # Probability of applying motion blur
19
+ extreme_crop_prob = 0.1 # Probability for extreme cropping
20
+ extreme_crop_lvl = 0.3 # Crop intensity (recommended 0.2 to 0.4)
21
+ low_res_prob = 0.05 # Probability for applying low resolution
22
+ low_res_scale_range = (0.15, 0.5) # Range for low-res scaling
23
+
24
+ # Scaling augmentation
25
+ scale = np.clip(np.random.randn(), -1.0, 1.0) * scale_factor + 1.0
26
+
27
+ # Rotation augmentation
28
+ rot = np.clip(np.random.randn(), -2.0, 2.0) * rot_factor if random.random() <= 0.6 else 0
29
+
30
+ # Color augmentation
31
+ c_up = 1.0 + color_factor
32
+ c_low = 1.0 - color_factor
33
+ color_scale = np.array([
34
+ random.uniform(c_low, c_up),
35
+ random.uniform(c_low, c_up),
36
+ random.uniform(c_low, c_up)
37
+ ])
38
+
39
+ # Flipping augmentation
40
+ do_flip = random.random() <= 0.5
41
+
42
+ # Translation augmentation
43
+ tx = np.clip(np.random.randn(), -1.0, 1.0) * trans_factor
44
+ ty = np.clip(np.random.randn(), -1.0, 1.0) * trans_factor
45
+
46
+ # Extreme cropping augmentation
47
+ do_extreme_crop = random.random() <= extreme_crop_prob
48
+
49
+ # Noise augmentation (returns standard deviation for Gaussian noise injection)
50
+ add_noise = random.random() <= 0.3 # 30% chance of adding noise
51
+ noise_std = noise_std if add_noise else 0.0
52
+
53
+ # Motion blur augmentation
54
+ apply_motion_blur = random.random() <= motion_blur_prob
55
+ motion_blur_kernel_size = random.choice([3, 5, 7]) if apply_motion_blur else 0
56
+
57
+ # Low-resolution augmentation
58
+ apply_low_res = random.random() <= low_res_prob
59
+ low_res_scale = random.uniform(*low_res_scale_range) if apply_low_res else 1.0
60
+
61
+ return {
62
+ 'scale': scale,
63
+ 'rot': rot,
64
+ 'color_scale': color_scale,
65
+ 'do_flip': do_flip,
66
+ 'tx': tx,
67
+ 'ty': ty,
68
+ 'do_extreme_crop': do_extreme_crop,
69
+ 'extreme_crop_lvl': extreme_crop_lvl if do_extreme_crop else 0,
70
+ 'noise_std': noise_std,
71
+ 'motion_blur_kernel_size': motion_blur_kernel_size,
72
+ 'low_res_scale': low_res_scale # Added low-res scale parameter
73
+ }
74
+
75
+
76
+ def rotate_2d(pt_2d, rot_rad):
77
+ x = pt_2d[0]
78
+ y = pt_2d[1]
79
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
80
+ xx = x * cs - y * sn
81
+ yy = x * sn + y * cs
82
+ return np.array([xx, yy], dtype=np.float32)
83
+
84
+
85
+ def gen_trans_from_patch_cv(c_x, c_y, src_width, src_height, dst_width, dst_height, scale, rot, inv=False):
86
+ # augment size with scale
87
+ src_w = src_width * scale
88
+ src_h = src_height * scale
89
+ src_center = np.array([c_x, c_y], dtype=np.float32)
90
+
91
+ # augment rotation
92
+ rot_rad = np.pi * rot / 180
93
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
94
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
95
+
96
+ dst_w = dst_width
97
+ dst_h = dst_height
98
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
99
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
100
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
101
+
102
+ src = np.zeros((3, 2), dtype=np.float32)
103
+ src[0, :] = src_center
104
+ src[1, :] = src_center + src_downdir
105
+ src[2, :] = src_center + src_rightdir
106
+
107
+ dst = np.zeros((3, 2), dtype=np.float32)
108
+ dst[0, :] = dst_center
109
+ dst[1, :] = dst_center + dst_downdir
110
+ dst[2, :] = dst_center + dst_rightdir
111
+
112
+ if inv:
113
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
114
+ else:
115
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
116
+
117
+ trans = trans.astype(np.float32)
118
+ return trans
119
+
120
+
121
+ def generate_patch_image_contact(cvimg, bbox, scale, rot, do_flip, out_shape, tx=0.0, ty=0.0, bkg_color='black'):
122
+ img = cvimg.copy()
123
+ img_height, img_width, img_channels = img.shape
124
+
125
+ bb_c_x = float(bbox[0] + 0.5 * bbox[2])
126
+ bb_c_y = float(bbox[1] + 0.5 * bbox[3])
127
+ bb_width = float(bbox[2])
128
+ bb_height = float(bbox[3])
129
+
130
+ if bkg_color == 'white':
131
+ borderMode=cv2.BORDER_CONSTANT
132
+ borderValue=(255, 255, 255)
133
+ else:
134
+ borderMode=cv2.BORDER_CONSTANT
135
+ borderValue=(0, 0, 0)
136
+
137
+ if do_flip:
138
+ img = img[:, ::-1, :]
139
+ bb_c_x = img_width - bb_c_x - 1
140
+
141
+ # Add translation offset
142
+ bb_c_x += tx * img_width
143
+ bb_c_y += ty * img_height
144
+
145
+ trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height,
146
+ out_shape[1], out_shape[0], scale, rot)
147
+ img_patch = cv2.warpAffine(img, trans, (int(out_shape[1]), int(out_shape[0])), flags=cv2.INTER_LINEAR, borderMode=borderMode, borderValue=borderValue)
148
+ img_patch = img_patch.astype(np.float32)
149
+ inv_trans = gen_trans_from_patch_cv(bb_c_x, bb_c_y, bb_width, bb_height,
150
+ out_shape[1], out_shape[0], scale, rot, inv=True)
151
+
152
+ return img_patch, trans, inv_trans
153
+
154
+
155
+ def augmentation_contact(img, bbox, data_split, enforce_flip=None, bkg_color='black'):
156
+ if data_split == 'train':
157
+ aug_params = get_aug_config_contact()
158
+ else:
159
+ aug_params = {
160
+ 'scale': 1.0,
161
+ 'rot': 0.0,
162
+ 'color_scale': np.array([1, 1, 1]),
163
+ 'do_flip': False,
164
+ 'tx': 0.0,
165
+ 'ty': 0.0,
166
+ 'do_extreme_crop': False,
167
+ 'extreme_crop_lvl': 0.0,
168
+ 'noise_std': 0.0,
169
+ 'motion_blur_kernel_size': 0,
170
+ 'low_res_scale': 1.0 # No low-res in non-training mode
171
+ }
172
+
173
+ # Enforce flip if specified
174
+ if enforce_flip is not None:
175
+ aug_params['do_flip'] = enforce_flip
176
+
177
+ # Apply geometric augmentations (scaling, rotation, flipping)
178
+ img, trans, inv_trans = generate_patch_image_contact(
179
+ img, bbox, aug_params['scale'], aug_params['rot'],
180
+ aug_params['do_flip'], cfg.MODEL.input_img_shape,
181
+ aug_params['tx'], aug_params['ty'], bkg_color
182
+ )
183
+
184
+ # Apply low-resolution augmentation
185
+ if aug_params['low_res_scale'] < 1.0: # Only apply if scaling down
186
+ img = apply_low_res(img, aug_params['low_res_scale'])
187
+
188
+ # Apply color augmentation
189
+ img = np.clip(img * aug_params['color_scale'][None, None, :], 0, 255)
190
+
191
+ # Apply extreme cropping
192
+ if aug_params['do_extreme_crop']:
193
+ img = apply_extreme_crop(img, aug_params['extreme_crop_lvl'])
194
+
195
+ # Apply noise augmentation
196
+ if aug_params['noise_std'] > 0:
197
+ img = add_gaussian_noise(img, aug_params['noise_std'])
198
+
199
+ # Apply motion blur augmentation
200
+ if aug_params['motion_blur_kernel_size'] > 0:
201
+ img = apply_motion_blur(img, aug_params['motion_blur_kernel_size'])
202
+
203
+ return img, trans, inv_trans, aug_params['rot'], aug_params['do_flip'], aug_params['color_scale']
204
+
205
+
206
+ def apply_extreme_crop(img, crop_lvl):
207
+ """Extreme cropping: Aggressively crop the image."""
208
+ h, w = img.shape[:2]
209
+ crop_size = max(1, int(min(h, w) * (1 - crop_lvl))) # Prevent zero-size crops
210
+ start_x = random.randint(0, max(0, w - crop_size))
211
+ start_y = random.randint(0, max(0, h - crop_size))
212
+ cropped_img = img[start_y:start_y + crop_size, start_x:start_x + crop_size]
213
+
214
+ # Preserve aspect ratio during resizing
215
+ return cv2.resize(cropped_img, (w, h), interpolation=cv2.INTER_LINEAR)
216
+
217
+
218
+ def add_gaussian_noise(img, noise_std):
219
+ """Add Gaussian noise to the image with proper scaling for data type."""
220
+ noise = np.random.normal(0, noise_std, img.shape).astype(np.float32)
221
+
222
+ if img.dtype == np.uint8:
223
+ noisy_img = np.clip(img + noise * 255, 0, 255).astype(np.uint8)
224
+ elif img.dtype == np.float32:
225
+ noisy_img = np.clip(img + noise, 0.0, 1.0).astype(np.float32)
226
+ elif img.dtype == np.float64:
227
+ noisy_img = np.clip(img + noise, 0.0, 1.0).astype(np.float64)
228
+ else:
229
+ raise TypeError("Unsupported image dtype. Expected uint8 or float32.")
230
+
231
+ return noisy_img
232
+
233
+
234
+ def apply_motion_blur(img, kernel_size):
235
+ """Apply motion blur to the image with a random direction."""
236
+ kernel = np.zeros((kernel_size, kernel_size))
237
+ direction = random.choice(['horizontal', 'vertical', 'diagonal'])
238
+
239
+ if direction == 'horizontal':
240
+ kernel[(kernel_size - 1) // 2, :] = np.ones(kernel_size)
241
+ elif direction == 'vertical':
242
+ kernel[:, (kernel_size - 1) // 2] = np.ones(kernel_size)
243
+ elif direction == 'diagonal':
244
+ np.fill_diagonal(kernel, 1)
245
+
246
+ kernel /= kernel_size # Normalize the kernel
247
+ return cv2.filter2D(img, -1, kernel, borderType=cv2.BORDER_REFLECT)
248
+
249
+
250
+ def apply_low_res(img, scale_factor=0.25):
251
+ """Simulate low-resolution effect by downsampling and upsampling."""
252
+ if not (0 < scale_factor < 1):
253
+ raise ValueError("scale_factor should be between 0 and 1.")
254
+
255
+ h, w = img.shape[:2]
256
+
257
+ # Calculate target dimensions for downsampling
258
+ downsampled_size = (max(1, int(w * scale_factor)), max(1, int(h * scale_factor)))
259
+
260
+ # Downsample using INTER_AREA for better quality in aggressive downsampling
261
+ low_res_img = cv2.resize(img, downsampled_size, interpolation=cv2.INTER_AREA)
262
+
263
+ # Upsample using INTER_NEAREST for strong pixelation effect
264
+ return cv2.resize(low_res_img, (w, h), interpolation=cv2.INTER_NEAREST).astype(img.dtype)
265
+
266
+
267
+ def process_human_model_output_orig(human_model_param, cam_param):
268
+ pose, shape, trans = human_model_param['pose'], human_model_param['shape'], human_model_param['trans']
269
+ hand_type = human_model_param['hand_type']
270
+ trans = human_model_param['trans']
271
+ pose = torch.FloatTensor(pose).view(-1,3); shape = torch.FloatTensor(shape).view(1,-1); # mano parameters (pose: 48 dimension, shape: 10 dimension)
272
+ trans = torch.FloatTensor(trans).view(1,-1) # translation vector
273
+
274
+ # apply camera extrinsic (rotation)
275
+ # merge root pose and camera rotation
276
+ if 'R' in cam_param:
277
+ R = np.array(cam_param['R'], dtype=np.float32).reshape(3,3)
278
+ root_pose = pose[mano.orig_root_joint_idx,:].numpy()
279
+ root_pose, _ = cv2.Rodrigues(root_pose)
280
+ root_pose, _ = cv2.Rodrigues(np.dot(R,root_pose))
281
+ pose[mano.orig_root_joint_idx] = torch.from_numpy(root_pose).view(3)
282
+
283
+ # get root joint coordinate
284
+ root_pose = pose[mano.orig_root_joint_idx].view(1,3)
285
+ hand_pose = torch.cat((pose[:mano.orig_root_joint_idx,:], pose[mano.orig_root_joint_idx+1:,:])).view(1,-1)
286
+ with torch.no_grad():
287
+ output = mano.layer[hand_type](betas=shape, hand_pose=hand_pose, global_orient=root_pose, transl=trans)
288
+ mesh_coord = output.vertices[0].numpy()
289
+ joint_coord = np.dot(mano.joint_regressor, mesh_coord)
290
+
291
+ # apply camera exrinsic (translation)
292
+ # compenstate rotation (translation from origin to root joint was not cancled)
293
+ if 'R' in cam_param and 't' in cam_param:
294
+ R, t = np.array(cam_param['R'], dtype=np.float32).reshape(3,3), np.array(cam_param['t'], dtype=np.float32).reshape(1,3)
295
+ root_coord = joint_coord[mano.root_joint_idx,None,:]
296
+ joint_coord = joint_coord - root_coord + np.dot(R, root_coord.transpose(1,0)).transpose(1,0) + t
297
+ mesh_coord = mesh_coord - root_coord + np.dot(R, root_coord.transpose(1,0)).transpose(1,0) + t
298
+
299
+
300
+ joint_cam_orig = joint_coord.copy()
301
+ mesh_cam_orig = mesh_coord.copy()
302
+ pose_orig, shape_orig, trans_orig = torch.cat((root_pose, hand_pose), dim=-1)[0].detach().cpu().numpy(), shape[0].detach().cpu().numpy(), trans[0].detach().cpu().numpy()
303
+
304
+ return mesh_cam_orig, joint_cam_orig, pose_orig, shape_orig, trans_orig
305
+
306
+
307
+ def mask2bbox(mask, expansion_factor=1.0):
308
+ # Find non-zero elements (object pixels)
309
+ coords = np.argwhere(mask)
310
+
311
+ # Extract bounding box coordinates
312
+ y_min, x_min = coords.min(axis=0)
313
+ y_max, x_max = coords.max(axis=0)
314
+
315
+ # Compute width and height
316
+ width = x_max - x_min + 1
317
+ height = y_max - y_min + 1
318
+
319
+ # Expand bounding box
320
+ if expansion_factor > 0:
321
+ x_min = max(0, int(x_min - width * expansion_factor / 2))
322
+ y_min = max(0, int(y_min - height * expansion_factor / 2))
323
+ x_max = min(mask.shape[1] - 1, int(x_max + width * expansion_factor / 2))
324
+ y_max = min(mask.shape[0] - 1, int(y_max + height * expansion_factor / 2))
325
+
326
+ # Recalculate width and height after expansion
327
+ width = x_max - x_min + 1
328
+ height = y_max - y_min + 1
329
+
330
+ return (x_min, y_min, width, height)
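A cropping sketch for the pipeline above, assuming cfg.MODEL.input_img_shape is set (e.g. (256, 256)): the returned affine maps let predictions in crop space be projected back to the original image.

import numpy as np
img = np.zeros((480, 640, 3), dtype=np.float32)              # placeholder frame
bbox = np.array([200., 150., 128., 128.], dtype=np.float32)  # x, y, w, h
patch, trans, inv_trans, rot, do_flip, color = augmentation_contact(img, bbox, data_split='test')
# patch: crop of shape cfg.MODEL.input_img_shape; trans / inv_trans: 2x3 affines between image and crop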
lib/utils/smplx/LICENSE ADDED
@@ -0,0 +1,58 @@
1
+ License
2
+
3
+ Software Copyright License for non-commercial scientific research purposes
4
+ Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use the SMPL-X/SMPLify-X model, data and software, (the "Model & Software"), including 3D meshes, blend weights, blend shapes, textures, software, scripts, and animations. By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License
5
+
6
+ Ownership / Licensees
7
+ The Software and the associated materials has been developed at the
8
+
9
+ Max Planck Institute for Intelligent Systems (hereinafter "MPI").
10
+
11
+ Any copyright or patent right is owned by and proprietary material of the
12
+
13
+ Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”)
14
+
15
+ hereinafter the “Licensor”.
16
+
17
+ License Grant
18
+ Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
19
+
20
+ To install the Model & Software on computers owned, leased or otherwise controlled by you and/or your organization;
21
+ To use the Model & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
22
+ Any other use, in particular any use for commercial purposes, is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Model & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission.
23
+
24
+ The Model & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Model & Software to train methods/algorithms/neural networks/etc. for commercial use of any kind. By downloading the Model & Software, you agree not to reverse engineer it.
25
+
26
+ No Distribution
27
+ The Model & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
28
+
29
+ Disclaimer of Representations and Warranties
30
+ You expressly acknowledge and agree that the Model & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Model & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE MODEL & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Model & Software, (ii) that the use of the Model & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Model & Software will not cause any damage of any kind to you or a third party.
31
+
32
+ Limitation of Liability
33
+ Because this Model & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
34
+ Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded.
35
+ Patent claims generated through the usage of the Model & Software cannot be directed towards the copyright holders.
36
+ The Model & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Model & Software and is not responsible for any problems such modifications cause.
37
+
38
+ No Maintenance Services
39
+ You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Model & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Model & Software at any time.
40
+
41
+ Defects of the Model & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
42
+
43
+ Publications using the Model & Software
44
+ You acknowledge that the Model & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Model & Software.
45
+
46
+ Citation:
47
+
48
+
49
+ @inproceedings{SMPL-X:2019,
50
+ title = {Expressive Body Capture: 3D Hands, Face, and Body from a Single Image},
51
+ author = {Pavlakos, Georgios and Choutas, Vasileios and Ghorbani, Nima and Bolkart, Timo and Osman, Ahmed A. A. and Tzionas, Dimitrios and Black, Michael J.},
52
+ booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)},
53
+ year = {2019}
54
+ }
55
+ Commercial licensing opportunities
56
+ For commercial uses of the Software, please send email to ps-license@tue.mpg.de
57
+
58
+ This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention.
lib/utils/smplx/README.md ADDED
@@ -0,0 +1,186 @@
1
+ ## SMPL-X: A new joint 3D model of the human body, face and hands together
2
+
3
+ [[Paper Page](https://smpl-x.is.tue.mpg.de)] [[Paper](https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/497/SMPL-X.pdf)]
4
+ [[Supp. Mat.](https://ps.is.tuebingen.mpg.de/uploads_file/attachment/attachment/498/SMPL-X-supp.pdf)]
5
+
6
+ ![SMPL-X Examples](./images/teaser_fig.png)
7
+
8
+ ## Table of Contents
9
+ * [License](#license)
10
+ * [Description](#description)
11
+ * [Installation](#installation)
12
+ * [Downloading the model](#downloading-the-model)
13
+ * [Loading SMPL-X, SMPL+H and SMPL](#loading-smpl-x-smplh-and-smpl)
14
+ * [SMPL and SMPL+H setup](#smpl-and-smplh-setup)
15
+ * [Model loading](https://github.com/vchoutas/smplx#model-loading)
16
+ * [MANO and FLAME correspondences](#mano-and-flame-correspondences)
17
+ * [Example](#example)
18
+ * [Citation](#citation)
19
+ * [Acknowledgments](#acknowledgments)
20
+ * [Contact](#contact)
21
+
22
+ ## License
23
+
24
+ Software Copyright License for **non-commercial scientific research purposes**.
25
+ Please read carefully the [terms and conditions](https://github.com/vchoutas/smplx/blob/master/LICENSE) and any accompanying documentation before you download and/or use the SMPL-X/SMPLify-X model, data and software, (the "Model & Software"), including 3D meshes, blend weights, blend shapes, textures, software, scripts, and animations. By downloading and/or using the Model & Software (including downloading, cloning, installing, and any other use of this github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Model & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this [License](./LICENSE).
26
+
27
+ ## Disclaimer
28
+
29
+ The original images used for the figures 1 and 2 of the paper can be found in this link.
30
+ The images in the paper are used under license from gettyimages.com.
31
+ We have acquired the right to use them in the publication, but redistribution is not allowed.
32
+ Please follow the instructions on the given link to acquire right of usage.
33
+ Our results are obtained on the 483 × 724 pixels resolution of the original images.
34
+
35
+ ## Description
36
+
37
+ *SMPL-X* (SMPL eXpressive) is a unified body model with shape parameters trained jointly for the
38
+ face, hands and body. *SMPL-X* uses standard vertex based linear blend skinning with learned corrective blend
39
+ shapes, has N = 10,475 vertices and K = 54 joints,
40
+ which include joints for the neck, jaw, eyeballs and fingers.
41
+ SMPL-X is defined by a function M(θ, β, ψ), where θ is the pose parameters, β the shape parameters and
42
+ ψ the facial expression parameters.
43
+
44
+
45
+ ## Installation
46
+
47
+ To install the model please follow the next steps in the specified order:
48
+ 1. To install from PyPi simply run:
49
+ ```Shell
50
+ pip install smplx[all]
51
+ ```
52
+ 2. Clone this repository and install it using the *setup.py* script:
53
+ ```Shell
54
+ git clone https://github.com/vchoutas/smplx
55
+ python setup.py install
56
+ ```
57
+
58
+ ## Downloading the model
59
+
60
+ To download the *SMPL-X* model go to [this project website](https://smpl-x.is.tue.mpg.de) and register to get access to the downloads section.
61
+
62
+ To download the *SMPL+H* model go to [this project website](http://mano.is.tue.mpg.de) and register to get access to the downloads section.
63
+
64
+ To download the *SMPL* model go to [this](http://smpl.is.tue.mpg.de) (male and female models) and [this](http://smplify.is.tue.mpg.de) (gender neutral model) project website and register to get access to the downloads section.
65
+
66
+ ## Loading SMPL-X, SMPL+H and SMPL
67
+
68
+ ### SMPL and SMPL+H setup
69
+
70
+ The loader gives the option to use any of the SMPL-X, SMPL+H, SMPL, and MANO models. Depending on the model you want to use, please follow the respective download instructions. To switch between MANO, SMPL, SMPL+H and SMPL-X just change the *model_path* or *model_type* parameters. For more details please check the docs of the model classes.
71
+ Before using SMPL and SMPL+H you should follow the instructions in [tools/README.md](./tools/README.md) to remove the
72
+ Chumpy objects from both model pkls, as well as merge the MANO parameters with SMPL+H.
73
+
74
+ ### Model loading
75
+
76
+ You can either use the [create](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L54)
77
+ function from [body_models](./smplx/body_models.py) or directly call the constructor for the
78
+ [SMPL](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L106),
79
+ [SMPL+H](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L395) and
80
+ [SMPL-X](https://github.com/vchoutas/smplx/blob/c63c02b478c5c6f696491ed9167e3af6b08d89b1/smplx/body_models.py#L628) model. The path to the model can either be the path to the file with the parameters or a directory with the following structure:
81
+ ```bash
82
+ models
83
+ ├── smpl
84
+ │   ├── SMPL_FEMALE.pkl
85
+ │   ├── SMPL_MALE.pkl
86
+ │   └── SMPL_NEUTRAL.pkl
87
+ ├── smplh
88
+ │   ├── SMPLH_FEMALE.pkl
89
+ │   └── SMPLH_MALE.pkl
90
+ ├── mano
91
+ | ├── MANO_RIGHT.pkl
92
+ | └── MANO_LEFT.pkl
93
+ └── smplx
94
+ ├── SMPLX_FEMALE.npz
95
+ ├── SMPLX_FEMALE.pkl
96
+ ├── SMPLX_MALE.npz
97
+ ├── SMPLX_MALE.pkl
98
+ ├── SMPLX_NEUTRAL.npz
99
+ └── SMPLX_NEUTRAL.pkl
100
+ ```
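For reference (not part of the upstream README), a minimal loading sketch in Python, assuming the `models/` layout above and an installed `smplx` package:

```python
import torch
import smplx

# `models` is the directory shown above; the path may also point directly
# at a single model file such as models/smplx/SMPLX_NEUTRAL.npz.
model = smplx.create('models', model_type='smplx', gender='neutral', ext='npz')

# Zero shape and expression coefficients give the template mesh.
output = model(betas=torch.zeros([1, model.num_betas]),
               expression=torch.zeros([1, 10]),
               return_verts=True)
print(output.vertices.shape)  # torch.Size([1, 10475, 3])
```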
101
+
102
+
103
+ ## MANO and FLAME correspondences
104
+
105
+ The vertex correspondences between SMPL-X and MANO, FLAME can be downloaded
106
+ from [the project website](https://smpl-x.is.tue.mpg.de). If you have extracted
107
+ the correspondence data in the folder *correspondences*, then use the following
108
+ scripts to visualize them:
109
+
110
+ 1. To view MANO correspondences run the following command:
111
+
112
+ ```
113
+ python examples/vis_mano_vertices.py --model-folder $SMPLX_FOLDER --corr-fname correspondences/MANO_SMPLX_vertex_ids.pkl
114
+ ```
115
+
116
+ 2. To view FLAME correspondences run the following command:
117
+
118
+ ```
119
+ python examples/vis_flame_vertices.py --model-folder $SMPLX_FOLDER --corr-fname correspondences/SMPL-X__FLAME_vertex_ids.npy
120
+ ```
121
+
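As a side note (also not in the upstream README), the correspondence files can be read directly; a small sketch assuming they were extracted to `correspondences/`, using the same keys as `examples/vis_mano_vertices.py` and `examples/vis_flame_vertices.py`:

```python
import pickle
import numpy as np

# MANO_SMPLX_vertex_ids.pkl stores SMPL-X vertex indices for each hand
# under the keys 'left_hand' and 'right_hand'.
with open('correspondences/MANO_SMPLX_vertex_ids.pkl', 'rb') as f:
    idxs_data = pickle.load(f)

# SMPL-X__FLAME_vertex_ids.npy stores the SMPL-X vertex indices of the
# head region shared with FLAME.
head_idxs = np.load('correspondences/SMPL-X__FLAME_vertex_ids.npy')

print(len(idxs_data['right_hand']), 'right-hand vertices')
print(len(head_idxs), 'head vertices')
```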
122
+ ## Example
123
+
124
+ After installing the *smplx* package and downloading the model parameters you should be able to run the *demo.py*
125
+ script to visualize the results. For this step you have to install the [pyrender](https://pyrender.readthedocs.io/en/latest/index.html) and [trimesh](https://trimsh.org/) packages.
126
+
127
+ `python examples/demo.py --model-folder $SMPLX_FOLDER --plot-joints=True --gender="neutral"`
128
+
129
+ ![SMPL-X Examples](./images/example.png)
130
+
131
+ ## Citation
132
+
133
+ Depending on which model is loaded for your project, i.e. SMPL-X or SMPL+H or SMPL, please cite the most relevant work below, listed in the same order:
134
+
135
+ ```
136
+ @inproceedings{SMPL-X:2019,
137
+ title = {Expressive Body Capture: 3D Hands, Face, and Body from a Single Image},
138
+ author = {Pavlakos, Georgios and Choutas, Vasileios and Ghorbani, Nima and Bolkart, Timo and Osman, Ahmed A. A. and Tzionas, Dimitrios and Black, Michael J.},
139
+ booktitle = {Proceedings IEEE Conf. on Computer Vision and Pattern Recognition (CVPR)},
140
+ year = {2019}
141
+ }
142
+ ```
143
+
144
+ ```
145
+ @article{MANO:SIGGRAPHASIA:2017,
146
+ title = {Embodied Hands: Modeling and Capturing Hands and Bodies Together},
147
+ author = {Romero, Javier and Tzionas, Dimitrios and Black, Michael J.},
148
+ journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)},
149
+ volume = {36},
150
+ number = {6},
151
+ series = {245:1--245:17},
152
+ month = nov,
153
+ year = {2017},
154
+ month_numeric = {11}
155
+ }
156
+ ```
157
+
158
+ ```
159
+ @article{SMPL:2015,
160
+ author = {Loper, Matthew and Mahmood, Naureen and Romero, Javier and Pons-Moll, Gerard and Black, Michael J.},
161
+ title = {{SMPL}: A Skinned Multi-Person Linear Model},
162
+ journal = {ACM Transactions on Graphics, (Proc. SIGGRAPH Asia)},
163
+ month = oct,
164
+ number = {6},
165
+ pages = {248:1--248:16},
166
+ publisher = {ACM},
167
+ volume = {34},
168
+ year = {2015}
169
+ }
170
+ ```
171
+
172
+ This repository was originally developed for SMPL-X / SMPLify-X (CVPR 2019), you might be interested in having a look: [https://smpl-x.is.tue.mpg.de](https://smpl-x.is.tue.mpg.de).
173
+
174
+ ## Acknowledgments
175
+
176
+ ### Facial Contour
177
+
178
+ Special thanks to [Soubhik Sanyal](https://github.com/soubhiksanyal) for sharing the Tensorflow code used for the facial
179
+ landmarks.
180
+
181
+ ## Contact
182
+ The code of this repository was implemented by [Vassilis Choutas](vassilis.choutas@tuebingen.mpg.de).
183
+
184
+ For questions, please contact [smplx@tue.mpg.de](smplx@tue.mpg.de).
185
+
186
+ For commercial licensing (and all related questions for business applications), please contact [ps-licensing@tue.mpg.de](ps-licensing@tue.mpg.de).
lib/utils/smplx/examples/demo.py ADDED
@@ -0,0 +1,180 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ import os.path as osp
18
+ import argparse
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ import smplx
24
+
25
+
26
+ def main(model_folder,
27
+ model_type='smplx',
28
+ ext='npz',
29
+ gender='neutral',
30
+ plot_joints=False,
31
+ num_betas=10,
32
+ sample_shape=True,
33
+ sample_expression=True,
34
+ num_expression_coeffs=10,
35
+ plotting_module='pyrender',
36
+ use_face_contour=False):
37
+
38
+ model = smplx.create(model_folder, model_type=model_type,
39
+ gender=gender, use_face_contour=use_face_contour,
40
+ num_betas=num_betas,
41
+ num_expression_coeffs=num_expression_coeffs,
42
+ ext=ext)
43
+ print(model)
44
+
45
+ betas, expression = None, None
46
+ if sample_shape:
47
+ betas = torch.randn([1, model.num_betas], dtype=torch.float32)
48
+ if sample_expression:
49
+ expression = torch.randn(
50
+ [1, model.num_expression_coeffs], dtype=torch.float32)
51
+
52
+ output = model(betas=betas, expression=expression,
53
+ return_verts=True)
54
+ vertices = output.vertices.detach().cpu().numpy().squeeze()
55
+ joints = output.joints.detach().cpu().numpy().squeeze()
56
+
57
+ print('Vertices shape =', vertices.shape)
58
+ print('Joints shape =', joints.shape)
59
+
60
+ if plotting_module == 'pyrender':
61
+ import pyrender
62
+ import trimesh
63
+ vertex_colors = np.ones([vertices.shape[0], 4]) * [0.3, 0.3, 0.3, 0.8]
64
+ tri_mesh = trimesh.Trimesh(vertices, model.faces,
65
+ vertex_colors=vertex_colors)
66
+
67
+ mesh = pyrender.Mesh.from_trimesh(tri_mesh)
68
+
69
+ scene = pyrender.Scene()
70
+ scene.add(mesh)
71
+
72
+ if plot_joints:
73
+ sm = trimesh.creation.uv_sphere(radius=0.005)
74
+ sm.visual.vertex_colors = [0.9, 0.1, 0.1, 1.0]
75
+ tfs = np.tile(np.eye(4), (len(joints), 1, 1))
76
+ tfs[:, :3, 3] = joints
77
+ joints_pcl = pyrender.Mesh.from_trimesh(sm, poses=tfs)
78
+ scene.add(joints_pcl)
79
+
80
+ pyrender.Viewer(scene, use_raymond_lighting=True)
81
+ elif plotting_module == 'matplotlib':
82
+ from matplotlib import pyplot as plt
83
+ from mpl_toolkits.mplot3d import Axes3D
84
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
85
+
86
+ fig = plt.figure()
87
+ ax = fig.add_subplot(111, projection='3d')
88
+
89
+ mesh = Poly3DCollection(vertices[model.faces], alpha=0.1)
90
+ face_color = (1.0, 1.0, 0.9)
91
+ edge_color = (0, 0, 0)
92
+ mesh.set_edgecolor(edge_color)
93
+ mesh.set_facecolor(face_color)
94
+ ax.add_collection3d(mesh)
95
+ ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r')
96
+
97
+ if plot_joints:
98
+ ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], alpha=0.1)
99
+ plt.show()
100
+ elif plotting_module == 'open3d':
101
+ import open3d as o3d
102
+
103
+ mesh = o3d.geometry.TriangleMesh()
104
+ mesh.vertices = o3d.utility.Vector3dVector(
105
+ vertices)
106
+ mesh.triangles = o3d.utility.Vector3iVector(model.faces)
107
+ mesh.compute_vertex_normals()
108
+ mesh.paint_uniform_color([0.3, 0.3, 0.3])
109
+
110
+ geometry = [mesh]
111
+ if plot_joints:
112
+ joints_pcl = o3d.geometry.PointCloud()
113
+ joints_pcl.points = o3d.utility.Vector3dVector(joints)
114
+ joints_pcl.paint_uniform_color([0.7, 0.3, 0.3])
115
+ geometry.append(joints_pcl)
116
+
117
+ o3d.visualization.draw_geometries(geometry)
118
+ else:
119
+ raise ValueError('Unknown plotting_module: {}'.format(plotting_module))
120
+
121
+
122
+ if __name__ == '__main__':
123
+ parser = argparse.ArgumentParser(description='SMPL-X Demo')
124
+
125
+ parser.add_argument('--model-folder', required=True, type=str,
126
+ help='The path to the model folder')
127
+ parser.add_argument('--model-type', default='smplx', type=str,
128
+ choices=['smpl', 'smplh', 'smplx', 'mano', 'flame'],
129
+ help='The type of model to load')
130
+ parser.add_argument('--gender', type=str, default='neutral',
131
+ help='The gender of the model')
132
+ parser.add_argument('--num-betas', default=10, type=int,
133
+ dest='num_betas',
134
+ help='Number of shape coefficients.')
135
+ parser.add_argument('--num-expression-coeffs', default=10, type=int,
136
+ dest='num_expression_coeffs',
137
+ help='Number of expression coefficients.')
138
+ parser.add_argument('--plotting-module', type=str, default='pyrender',
139
+ dest='plotting_module',
140
+ choices=['pyrender', 'matplotlib', 'open3d'],
141
+ help='The module to use for plotting the result')
142
+ parser.add_argument('--ext', type=str, default='npz',
143
+ help='Which extension to use for loading')
144
+ parser.add_argument('--plot-joints', default=False,
145
+ type=lambda arg: arg.lower() in ['true', '1'],
146
+ help='The path to the model folder')
147
+ parser.add_argument('--sample-shape', default=True,
148
+ dest='sample_shape',
149
+ type=lambda arg: arg.lower() in ['true', '1'],
150
+ help='Sample a random shape')
151
+ parser.add_argument('--sample-expression', default=True,
152
+ dest='sample_expression',
153
+ type=lambda arg: arg.lower() in ['true', '1'],
154
+ help='Sample a random expression')
155
+ parser.add_argument('--use-face-contour', default=False,
156
+ type=lambda arg: arg.lower() in ['true', '1'],
157
+ help='Compute the contour of the face')
158
+
159
+ args = parser.parse_args()
160
+
161
+ model_folder = osp.expanduser(osp.expandvars(args.model_folder))
162
+ model_type = args.model_type
163
+ plot_joints = args.plot_joints
164
+ use_face_contour = args.use_face_contour
165
+ gender = args.gender
166
+ ext = args.ext
167
+ plotting_module = args.plotting_module
168
+ num_betas = args.num_betas
169
+ num_expression_coeffs = args.num_expression_coeffs
170
+ sample_shape = args.sample_shape
171
+ sample_expression = args.sample_expression
172
+
173
+ main(model_folder, model_type, ext=ext,
174
+ gender=gender, plot_joints=plot_joints,
175
+ num_betas=num_betas,
176
+ num_expression_coeffs=num_expression_coeffs,
177
+ sample_shape=sample_shape,
178
+ sample_expression=sample_expression,
179
+ plotting_module=plotting_module,
180
+ use_face_contour=use_face_contour)
lib/utils/smplx/examples/demo_layers.py ADDED
@@ -0,0 +1,181 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ import os.path as osp
18
+ import argparse
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ import smplx
24
+
25
+
26
+ def main(model_folder,
27
+ model_type='smplx',
28
+ ext='npz',
29
+ gender='neutral',
30
+ plot_joints=False,
31
+ num_betas=10,
32
+ sample_shape=True,
33
+ sample_expression=True,
34
+ num_expression_coeffs=10,
35
+ plotting_module='pyrender',
36
+ use_face_contour=False):
37
+
38
+ model = smplx.build_layer(
39
+ model_folder, model_type=model_type,
40
+ gender=gender, use_face_contour=use_face_contour,
41
+ num_betas=num_betas,
42
+ num_expression_coeffs=num_expression_coeffs,
43
+ ext=ext)
44
+ print(model)
45
+
46
+ betas, expression = None, None
47
+ if sample_shape:
48
+ betas = torch.randn([1, model.num_betas], dtype=torch.float32)
49
+ if sample_expression:
50
+ expression = torch.randn(
51
+ [1, model.num_expression_coeffs], dtype=torch.float32)
52
+
53
+ output = model(betas=betas, expression=expression,
54
+ return_verts=True)
55
+ vertices = output.vertices.detach().cpu().numpy().squeeze()
56
+ joints = output.joints.detach().cpu().numpy().squeeze()
57
+
58
+ print('Vertices shape =', vertices.shape)
59
+ print('Joints shape =', joints.shape)
60
+
61
+ if plotting_module == 'pyrender':
62
+ import pyrender
63
+ import trimesh
64
+ vertex_colors = np.ones([vertices.shape[0], 4]) * [0.3, 0.3, 0.3, 0.8]
65
+ tri_mesh = trimesh.Trimesh(vertices, model.faces,
66
+ vertex_colors=vertex_colors)
67
+
68
+ mesh = pyrender.Mesh.from_trimesh(tri_mesh)
69
+
70
+ scene = pyrender.Scene()
71
+ scene.add(mesh)
72
+
73
+ if plot_joints:
74
+ sm = trimesh.creation.uv_sphere(radius=0.005)
75
+ sm.visual.vertex_colors = [0.9, 0.1, 0.1, 1.0]
76
+ tfs = np.tile(np.eye(4), (len(joints), 1, 1))
77
+ tfs[:, :3, 3] = joints
78
+ joints_pcl = pyrender.Mesh.from_trimesh(sm, poses=tfs)
79
+ scene.add(joints_pcl)
80
+
81
+ pyrender.Viewer(scene, use_raymond_lighting=True)
82
+ elif plotting_module == 'matplotlib':
83
+ from matplotlib import pyplot as plt
84
+ from mpl_toolkits.mplot3d import Axes3D
85
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
86
+
87
+ fig = plt.figure()
88
+ ax = fig.add_subplot(111, projection='3d')
89
+
90
+ mesh = Poly3DCollection(vertices[model.faces], alpha=0.1)
91
+ face_color = (1.0, 1.0, 0.9)
92
+ edge_color = (0, 0, 0)
93
+ mesh.set_edgecolor(edge_color)
94
+ mesh.set_facecolor(face_color)
95
+ ax.add_collection3d(mesh)
96
+ ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r')
97
+
98
+ if plot_joints:
99
+ ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], alpha=0.1)
100
+ plt.show()
101
+ elif plotting_module == 'open3d':
102
+ import open3d as o3d
103
+
104
+ mesh = o3d.geometry.TriangleMesh()
105
+ mesh.vertices = o3d.utility.Vector3dVector(
106
+ vertices)
107
+ mesh.triangles = o3d.utility.Vector3iVector(model.faces)
108
+ mesh.compute_vertex_normals()
109
+ mesh.paint_uniform_color([0.3, 0.3, 0.3])
110
+
111
+ geometry = [mesh]
112
+ if plot_joints:
113
+ joints_pcl = o3d.geometry.PointCloud()
114
+ joints_pcl.points = o3d.utility.Vector3dVector(joints)
115
+ joints_pcl.paint_uniform_color([0.7, 0.3, 0.3])
116
+ geometry.append(joints_pcl)
117
+
118
+ o3d.visualization.draw_geometries(geometry)
119
+ else:
120
+ raise ValueError('Unknown plotting_module: {}'.format(plotting_module))
121
+
122
+
123
+ if __name__ == '__main__':
124
+ parser = argparse.ArgumentParser(description='SMPL-X Demo')
125
+
126
+ parser.add_argument('--model-folder', required=True, type=str,
127
+ help='The path to the model folder')
128
+ parser.add_argument('--model-type', default='smplx', type=str,
129
+ choices=['smpl', 'smplh', 'smplx', 'mano', 'flame'],
130
+ help='The type of model to load')
131
+ parser.add_argument('--gender', type=str, default='neutral',
132
+ help='The gender of the model')
133
+ parser.add_argument('--num-betas', default=10, type=int,
134
+ dest='num_betas',
135
+ help='Number of shape coefficients.')
136
+ parser.add_argument('--num-expression-coeffs', default=10, type=int,
137
+ dest='num_expression_coeffs',
138
+ help='Number of expression coefficients.')
139
+ parser.add_argument('--plotting-module', type=str, default='pyrender',
140
+ dest='plotting_module',
141
+ choices=['pyrender', 'matplotlib', 'open3d'],
142
+ help='The module to use for plotting the result')
143
+ parser.add_argument('--ext', type=str, default='npz',
144
+ help='Which extension to use for loading')
145
+ parser.add_argument('--plot-joints', default=False,
146
+ type=lambda arg: arg.lower() in ['true', '1'],
147
+ help='The path to the model folder')
148
+ parser.add_argument('--sample-shape', default=True,
149
+ dest='sample_shape',
150
+ type=lambda arg: arg.lower() in ['true', '1'],
151
+ help='Sample a random shape')
152
+ parser.add_argument('--sample-expression', default=True,
153
+ dest='sample_expression',
154
+ type=lambda arg: arg.lower() in ['true', '1'],
155
+ help='Sample a random expression')
156
+ parser.add_argument('--use-face-contour', default=False,
157
+ type=lambda arg: arg.lower() in ['true', '1'],
158
+ help='Compute the contour of the face')
159
+
160
+ args = parser.parse_args()
161
+
162
+ model_folder = osp.expanduser(osp.expandvars(args.model_folder))
163
+ model_type = args.model_type
164
+ plot_joints = args.plot_joints
165
+ use_face_contour = args.use_face_contour
166
+ gender = args.gender
167
+ ext = args.ext
168
+ plotting_module = args.plotting_module
169
+ num_betas = args.num_betas
170
+ num_expression_coeffs = args.num_expression_coeffs
171
+ sample_shape = args.sample_shape
172
+ sample_expression = args.sample_expression
173
+
174
+ main(model_folder, model_type, ext=ext,
175
+ gender=gender, plot_joints=plot_joints,
176
+ num_betas=num_betas,
177
+ num_expression_coeffs=num_expression_coeffs,
178
+ sample_shape=sample_shape,
179
+ sample_expression=sample_expression,
180
+ plotting_module=plotting_module,
181
+ use_face_contour=use_face_contour)
lib/utils/smplx/examples/vis_flame_vertices.py ADDED
@@ -0,0 +1,92 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ import os.path as osp
18
+ import argparse
19
+ import pickle
20
+
21
+ import numpy as np
22
+ import torch
23
+ import open3d as o3d
24
+
25
+ import smplx
26
+
27
+
28
+ def main(model_folder, corr_fname, ext='npz',
29
+ head_color=(0.3, 0.3, 0.6),
30
+ gender='neutral'):
31
+
32
+ head_idxs = np.load(corr_fname)
33
+
34
+ model = smplx.create(model_folder, model_type='smplx',
35
+ gender=gender,
36
+ ext=ext)
37
+ betas = torch.zeros([1, 10], dtype=torch.float32)
38
+ expression = torch.zeros([1, 10], dtype=torch.float32)
39
+
40
+ output = model(betas=betas, expression=expression,
41
+ return_verts=True)
42
+ vertices = output.vertices.detach().cpu().numpy().squeeze()
43
+ joints = output.joints.detach().cpu().numpy().squeeze()
44
+
45
+ print('Vertices shape =', vertices.shape)
46
+ print('Joints shape =', joints.shape)
47
+
48
+ mesh = o3d.geometry.TriangleMesh()
49
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
50
+ mesh.triangles = o3d.utility.Vector3iVector(model.faces)
51
+ mesh.compute_vertex_normals()
52
+
53
+ colors = np.ones_like(vertices) * [0.3, 0.3, 0.3]
54
+ colors[head_idxs] = head_color
55
+
56
+ mesh.vertex_colors = o3d.utility.Vector3dVector(colors)
57
+
58
+ o3d.visualization.draw_geometries([mesh])
59
+
60
+
61
+ if __name__ == '__main__':
62
+ parser = argparse.ArgumentParser(description='SMPL-X Demo')
63
+
64
+ parser.add_argument('--model-folder', required=True, type=str,
65
+ help='The path to the model folder')
66
+ parser.add_argument('--corr-fname', required=True, type=str,
67
+ dest='corr_fname',
68
+ help='Filename with the head correspondences')
69
+ parser.add_argument('--gender', type=str, default='neutral',
70
+ help='The gender of the model')
71
+ parser.add_argument('--ext', type=str, default='npz',
72
+ help='Which extension to use for loading')
73
+ parser.add_argument('--head', default='right',
74
+ choices=['right', 'left'],
75
+ type=str, help='Which head to plot')
76
+ parser.add_argument('--head-color', type=float, nargs=3, dest='head_color',
77
+ default=(0.3, 0.3, 0.6),
78
+ help='Color for the head vertices')
79
+
80
+ args = parser.parse_args()
81
+
82
+ model_folder = osp.expanduser(osp.expandvars(args.model_folder))
83
+ corr_fname = args.corr_fname
84
+ gender = args.gender
85
+ ext = args.ext
86
+ head = args.head
87
+ head_color = args.head_color
88
+
89
+ main(model_folder, corr_fname, ext=ext,
90
+ head_color=head_color,
91
+ gender=gender
92
+ )
lib/utils/smplx/examples/vis_mano_vertices.py ADDED
@@ -0,0 +1,99 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems. All rights reserved.
14
+ #
15
+ # Contact: ps-license@tuebingen.mpg.de
16
+
17
+ import os.path as osp
18
+ import argparse
19
+ import pickle
20
+
21
+ import numpy as np
22
+ import torch
23
+ import open3d as o3d
24
+
25
+ import smplx
26
+
27
+
28
+ def main(model_folder, corr_fname, ext='npz',
29
+ hand_color=(0.3, 0.3, 0.6),
30
+ gender='neutral', hand='right'):
31
+
32
+ with open(corr_fname, 'rb') as f:
33
+ idxs_data = pickle.load(f)
34
+ if hand == 'both':
35
+ hand_idxs = np.concatenate(
36
+ [idxs_data['left_hand'], idxs_data['right_hand']]
37
+ )
38
+ else:
39
+ hand_idxs = idxs_data[f'{hand}_hand']
40
+
41
+ model = smplx.create(model_folder, model_type='smplx',
42
+ gender=gender,
43
+ ext=ext)
44
+ betas = torch.zeros([1, 10], dtype=torch.float32)
45
+ expression = torch.zeros([1, 10], dtype=torch.float32)
46
+
47
+ output = model(betas=betas, expression=expression,
48
+ return_verts=True)
49
+ vertices = output.vertices.detach().cpu().numpy().squeeze()
50
+ joints = output.joints.detach().cpu().numpy().squeeze()
51
+
52
+ print('Vertices shape =', vertices.shape)
53
+ print('Joints shape =', joints.shape)
54
+
55
+ mesh = o3d.geometry.TriangleMesh()
56
+ mesh.vertices = o3d.utility.Vector3dVector(vertices)
57
+ mesh.triangles = o3d.utility.Vector3iVector(model.faces)
58
+ mesh.compute_vertex_normals()
59
+
60
+ colors = np.ones_like(vertices) * [0.3, 0.3, 0.3]
61
+ colors[hand_idxs] = hand_color
62
+
63
+ mesh.vertex_colors = o3d.utility.Vector3dVector(colors)
64
+
65
+ o3d.visualization.draw_geometries([mesh])
66
+
67
+
68
+ if __name__ == '__main__':
69
+ parser = argparse.ArgumentParser(description='SMPL-X Demo')
70
+
71
+ parser.add_argument('--model-folder', required=True, type=str,
72
+ help='The path to the model folder')
73
+ parser.add_argument('--corr-fname', required=True, type=str,
74
+ dest='corr_fname',
75
+ help='Filename with the hand correspondences')
76
+ parser.add_argument('--gender', type=str, default='neutral',
77
+ help='The gender of the model')
78
+ parser.add_argument('--ext', type=str, default='npz',
79
+ help='Which extension to use for loading')
80
+ parser.add_argument('--hand', default='right',
81
+ choices=['right', 'left', 'both'],
82
+ type=str, help='Which hand to plot')
83
+ parser.add_argument('--hand-color', type=float, nargs=3, dest='hand_color',
84
+ default=(0.3, 0.3, 0.6),
85
+ help='Color for the hand vertices')
86
+
87
+ args = parser.parse_args()
88
+
89
+ model_folder = osp.expanduser(osp.expandvars(args.model_folder))
90
+ corr_fname = args.corr_fname
91
+ gender = args.gender
92
+ ext = args.ext
93
+ hand = args.hand
94
+ hand_color = args.hand_color
95
+
96
+ main(model_folder, corr_fname, ext=ext,
97
+ hand_color=hand_color,
98
+ gender=gender, hand=hand
99
+ )
lib/utils/smplx/setup.py ADDED
@@ -0,0 +1,79 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems and the Max Planck Institute for Biological
14
+ # Cybernetics. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ import io
19
+ import os
20
+
21
+ from setuptools import setup
22
+
23
+ # Package meta-data.
24
+ NAME = 'smplx'
25
+ DESCRIPTION = 'PyTorch module for loading the SMPLX body model'
26
+ URL = 'http://smpl-x.is.tuebingen.mpg.de'
27
+ EMAIL = 'vassilis.choutas@tuebingen.mpg.de'
28
+ AUTHOR = 'Vassilis Choutas'
29
+ REQUIRES_PYTHON = '>=3.6.0'
30
+ VERSION = '0.1.21'
31
+
32
+ here = os.path.abspath(os.path.dirname(__file__))
33
+
34
+ try:
35
+ FileNotFoundError
36
+ except NameError:
37
+ FileNotFoundError = IOError
38
+
39
+ # Import the README and use it as the long-description.
40
+ # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
41
+ try:
42
+ with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
43
+ long_description = '\n' + f.read()
44
+ except FileNotFoundError:
45
+ long_description = DESCRIPTION
46
+
47
+ # Load the package's __version__.py module as a dictionary.
48
+ about = {}
49
+ if not VERSION:
50
+ with open(os.path.join(here, NAME, '__version__.py')) as f:
51
+ exec(f.read(), about)
52
+ else:
53
+ about['__version__'] = VERSION
54
+
55
+ pyrender_reqs = ['pyrender>=0.1.23', 'trimesh>=2.37.6', 'shapely']
56
+ matplotlib_reqs = ['matplotlib']
57
+ open3d_reqs = ['open3d-python']
58
+
59
+ setup(name=NAME,
60
+ version=about['__version__'],
61
+ description=DESCRIPTION,
62
+ long_description=long_description,
63
+ long_description_content_type='text/markdown',
64
+ author=AUTHOR,
65
+ author_email=EMAIL,
66
+ python_requires=REQUIRES_PYTHON,
67
+ url=URL,
68
+ install_requires=[
69
+ 'numpy>=1.16.2',
70
+ 'torch>=1.0.1.post2',
71
+ 'torchgeometry>=0.1.2'
72
+ ],
73
+ extras_require={
74
+ 'pyrender': pyrender_reqs,
75
+ 'open3d': open3d_reqs,
76
+ 'matplotlib': matplotlib_reqs,
77
+ 'all': pyrender_reqs + matplotlib_reqs + open3d_reqs
78
+ },
79
+ packages=['smplx', 'tools'])