Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
42a2bfa
1
Parent(s):
64a0f40
first commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- .gitignore +4 -0
- LICENSE +201 -0
- README.md +209 -14
- __init__.py +3 -0
- app.py +391 -0
- assets/dough.mp4 +3 -0
- assets/sign.mp4 +3 -0
- assets/teaser_test.json +20 -0
- assets/two_man.mp4 +3 -0
- assets/woman_ballon.mp4 +3 -0
- config/1.3b_lora_zero_stage2_config.json +24 -0
- config/14b_lora_zero2_bf16_config.json +24 -0
- config/wan2.1/wan_civitai.yaml +39 -0
- config/wan2.2/wan_civitai_5b.yaml +41 -0
- config/wan2.2/wan_civitai_i2v.yaml +43 -0
- config/wan2.2/wan_civitai_s2v.yaml +44 -0
- config/wan2.2/wan_civitai_t2v.yaml +43 -0
- config/zero_stage2_config.json +16 -0
- config/zero_stage3_config.json +27 -0
- config/zero_stage3_config_cpu_offload.json +28 -0
- inference.py +400 -0
- install.py +45 -0
- pyproject.toml +15 -0
- requirements.txt +26 -0
- scripts/local_style.sh +13 -0
- scripts/obj_add.sh +13 -0
- scripts/obj_rem.sh +13 -0
- scripts/parallel_infer.sh +12 -0
- videox_fun/__init__.py +0 -0
- videox_fun/api/api.py +226 -0
- videox_fun/api/api_multi_nodes.py +320 -0
- videox_fun/data/bucket_sampler.py +392 -0
- videox_fun/data/dataset_image.py +76 -0
- videox_fun/data/dataset_image_video.py +1939 -0
- videox_fun/data/dataset_video.py +262 -0
- videox_fun/dist/__init__.py +66 -0
- videox_fun/dist/cogvideox_xfuser.py +105 -0
- videox_fun/dist/flux_xfuser.py +168 -0
- videox_fun/dist/fsdp.py +44 -0
- videox_fun/dist/fuser.py +55 -0
- videox_fun/dist/qwen_xfuser.py +176 -0
- videox_fun/dist/wan_xfuser.py +180 -0
- videox_fun/pipeline/__init__.py +21 -0
- videox_fun/pipeline/pipeline_wan.py +799 -0
- videox_fun/pipeline/pipeline_wan2_2.py +591 -0
- videox_fun/ui/cogvideox_fun_ui.py +722 -0
- videox_fun/ui/controller.py +514 -0
- videox_fun/ui/ui.py +366 -0
- videox_fun/ui/wan2_2_fun_ui.py +803 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
samples/
|
| 2 |
+
models/
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.pyc
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,14 +1,209 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
<h1 style="margin: 0; font-size: 2.4em;">
|
| 4 |
+
Unified Video Editing with Temporal Reasoner
|
| 5 |
+
</h1>
|
| 6 |
+
|
| 7 |
+
<h4 style="margin: 15px 0; color: #2c3e50;">
|
| 8 |
+
👁️ See → 🧠 Reason → ✏️ Edit
|
| 9 |
+
</h4>
|
| 10 |
+
|
| 11 |
+
<h4 style="margin: 15px 0; color: #2c3e50;">
|
| 12 |
+
🚀 A Chain of Frames video editing method enbale temporal reasoning and 4x video length extrapolation with just 50k training pairs!
|
| 13 |
+
</h4>
|
| 14 |
+
|
| 15 |
+
[](https://huggingface.co/papers/2512.07469)
|
| 16 |
+
[](https://arxiv.org/abs/2512.07469)
|
| 17 |
+
[](https://videocof.github.io)
|
| 18 |
+
[](https://huggingface.co/XiangpengYang/VideoCoF)
|
| 19 |
+

|
| 20 |
+
|
| 21 |
+
</div>
|
| 22 |
+
|
| 23 |
+
<div align="center">
|
| 24 |
+
<b>
|
| 25 |
+
<a href="https://scholar.google.com/citations?user=reiIeYMAAAAJ">Xiangpeng Yang</a><sup>1</sup>,
|
| 26 |
+
<a href="https://horizonwind2004.github.io/">Ji Xie</a><sup>2</sup>,
|
| 27 |
+
<a href="https://scholar.google.com/citations?user=OvfI_HMAAAAJ">Yiyuan Yang</a><sup>1</sup>,
|
| 28 |
+
<a href="https://scholar.google.com/citations?user=zfeWd6gAAAAJ">Yan Huang</a><sup>1</sup>,
|
| 29 |
+
<a href="https://scholar.google.com/citations?user=sCuACdkAAAAJ">Min Xu</a><sup>1</sup>,
|
| 30 |
+
<a href="https://scholar.google.com/citations?user=sCuACdkAAAAJ">Qiang Wu</a><sup>1</sup>
|
| 31 |
+
</b>
|
| 32 |
+
<br>
|
| 33 |
+
<span style="font-size: 1em; color: #555;"><sup>1</sup>University of Technology Sydney, <sup>2</sup>Zhejiang University</span>
|
| 34 |
+
</div>
|
| 35 |
+
|
| 36 |
+
<br>
|
| 37 |
+
|
| 38 |
+
## 💿 Introduction
|
| 39 |
+
|
| 40 |
+
https://github.com/user-attachments/assets/26f7d347-3d6c-43cf-9645-6eb5906f6ad6
|
| 41 |
+
|
| 42 |
+
## 🔥 News
|
| 43 |
+
|
| 44 |
+
- **2025.12.09**: Paper available on arXiv.
|
| 45 |
+
- **2025.12.08**: Release the inference code and videocof-50k weight.
|
| 46 |
+
- **2025.12.06**: 🔥 Project Page and README updated!
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
## 📑 Table of Contents
|
| 50 |
+
|
| 51 |
+
- [🔧 Quick Start](#-quick-start)
|
| 52 |
+
- [🏆 Model Zoo](#-model-zoo)
|
| 53 |
+
- [🍭 Results](#-results)
|
| 54 |
+
- [🎨 Edit Comparison](#-edit-comparison)
|
| 55 |
+
- [🚧 TODO](#-todo)
|
| 56 |
+
- [🙏 Acknowledgments](#-acknowledgments)
|
| 57 |
+
- [📜 License](#-license)
|
| 58 |
+
- [📮 Contact](#-contact)
|
| 59 |
+
- [📄 Citation](#-citation)
|
| 60 |
+
|
| 61 |
+
## 🔧 Quick Start
|
| 62 |
+
|
| 63 |
+
1. **Clone the repository:**
|
| 64 |
+
|
| 65 |
+
```bash
|
| 66 |
+
git clone https://github.com/videocof/VideoCoF.git
|
| 67 |
+
cd VideoCoF
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
2. **Install dependencies:**
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
# 1. Create and activate a conda environment
|
| 74 |
+
conda create -n videocof python=3.10
|
| 75 |
+
conda activate videocof
|
| 76 |
+
|
| 77 |
+
# 2. Install PyTorch (Choose version compatible with your CUDA)
|
| 78 |
+
# For standard GPUs (CUDA 12.1):
|
| 79 |
+
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
|
| 80 |
+
|
| 81 |
+
# For Hopper GPUs (e.g., H100/H800) requiring fast inference:
|
| 82 |
+
# pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu128
|
| 83 |
+
|
| 84 |
+
# 3. Install other dependencies
|
| 85 |
+
pip install -r requirements.txt
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
**Note on Flash Attention:**
|
| 89 |
+
We recommend using **FlashAttention-3** (currently beta) for optimal performance, especially on NVIDIA H100/H800 GPUs.
|
| 90 |
+
If you are using these GPUs, please follow the [official FlashAttention-3 installation guide](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release) after installing the compatible PyTorch version (e.g., PyTorch 2.8 + CUDA 12.8).
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
3. **Download Models:**
|
| 94 |
+
|
| 95 |
+
**Wan-2.1-T2V-14B Pretrained Weights:**
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
git lfs install
|
| 99 |
+
git clone https://huggingface.co/Wan-AI/Wan2.1-T2V-14B
|
| 100 |
+
|
| 101 |
+
# Or using huggingface-cli:
|
| 102 |
+
# hf download Wan-AI/Wan2.1-T2V-14B --local-dir Wan2.1-T2V-14B
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
**VideoCoF Checkpoint:**
|
| 106 |
+
|
| 107 |
+
```bash
|
| 108 |
+
git lfs install
|
| 109 |
+
git clone https://huggingface.co/XiangpengYang/VideoCoF videocof_weight
|
| 110 |
+
|
| 111 |
+
# Or using huggingface-cli:
|
| 112 |
+
# hf download XiangpengYang/VideoCoF --local-dir videocof_weight
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
4. **Inference:**
|
| 116 |
+
|
| 117 |
+
For single inference tasks:
|
| 118 |
+
|
| 119 |
+
```bash
|
| 120 |
+
# Object Removal
|
| 121 |
+
sh scripts/obj_rem.sh
|
| 122 |
+
|
| 123 |
+
# Object Addition
|
| 124 |
+
sh scripts/obj_add.sh
|
| 125 |
+
|
| 126 |
+
# Local Style Transfer
|
| 127 |
+
sh scripts/local_style.sh
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
For parallel inference:
|
| 131 |
+
|
| 132 |
+
```bash
|
| 133 |
+
sh scripts/parallel_infer.sh
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
## 🏆 Model Zoo
|
| 137 |
+
|
| 138 |
+
Our models are available on Hugging Face:
|
| 139 |
+
|
| 140 |
+
| Model Name | Description | Link |
|
| 141 |
+
|------------|-------------|------|
|
| 142 |
+
| VideoCoF-Base | Base model trained on 50k video pairs | [Hugging Face](https://huggingface.co/XiangpengYang/VideoCoF) |
|
| 143 |
+
|
| 144 |
+
## 🍭 Results
|
| 145 |
+
|
| 146 |
+
### Why We Need Reasoning Before Editing?
|
| 147 |
+

|
| 148 |
+
|
| 149 |
+
Current video editing methods typically follow two paths:
|
| 150 |
+
1. **Expert models**: Rely on external masks for precision but sacrifice unification.
|
| 151 |
+
2. **Unified in-context learning models**: Mask-free but often struggle with spatial accuracy due to the lack of explicit cues.
|
| 152 |
+
|
| 153 |
+
**VideoCoF** bridges this gap by predicting reasoning tokens before generating the target video tokens.
|
| 154 |
+
|
| 155 |
+
### Key Capabilities
|
| 156 |
+
|
| 157 |
+
1. **Seeing, Reasoning, Editing**: VideoCoF adopts a "seeing, reasoning, editing" approach, ensuring edits are applied accurately to the intended targets.
|
| 158 |
+
2. **Length Extrapolation**: Trained on only **50k** data (33 frames), VideoCoF demonstrates robust multi-shot editing and length generalization (e.g., 4× length extrapolation).
|
| 159 |
+
3. **Diverse Editing Tasks**: Supports fine-grained (instance and part level, spatial aware) Object Removal, Object Addition, Object Swap, and Local Style Transfer.
|
| 160 |
+
|
| 161 |
+
### Gallery Highlights
|
| 162 |
+
|
| 163 |
+
> Please refer to our [Project Page](https://videocof.github.io) for the full gallery.
|
| 164 |
+
|
| 165 |
+
* **Object Removal**: Remove people or objects based on text prompts.
|
| 166 |
+
* **Object Addition**: Add elements like animals, objects, or people.
|
| 167 |
+
* **Object Swap**: Change specific attributes or objects.
|
| 168 |
+
* **Local Style Transfer**: Modify texture, materials or colors.
|
| 169 |
+
|
| 170 |
+
## 🚧 TODO
|
| 171 |
+
|
| 172 |
+
- [x] Release paper.
|
| 173 |
+
- [x] Release inference code and weights.
|
| 174 |
+
- [ ] Release training code.
|
| 175 |
+
- [ ] Release training data.
|
| 176 |
+
- [ ] Add Hugging Face demo.
|
| 177 |
+
|
| 178 |
+
## 🙏 Acknowledgments
|
| 179 |
+
|
| 180 |
+
We thank the authors of related works and the open-source community [VideoX-Fun](https://github.com/aigc-apps/VideoX-Fun) and [Wan](https://github.com/Wan-Video/Wan2.1) for their contributions.
|
| 181 |
+
|
| 182 |
+
## 📜 License
|
| 183 |
+
|
| 184 |
+
This project is licensed under the [Apache License 2.0](LICENSE).
|
| 185 |
+
|
| 186 |
+
## 📮 Contact
|
| 187 |
+
|
| 188 |
+
For any questions, please feel free to reach out to the author Xiangpeng Yang [@knightyxp](https://github.com/knightyxp), email: knightyxp@gmail.com/Xiangpeng.Yang@student.uts.edu.au
|
| 189 |
+
|
| 190 |
+
## 📄 Citation
|
| 191 |
+
|
| 192 |
+
If you find this work useful for your research, please consider citing:
|
| 193 |
+
|
| 194 |
+
```bibtex
|
| 195 |
+
@article{yang2025videocof,
|
| 196 |
+
title={Unified Video Editing with Temporal Reasoner},
|
| 197 |
+
author={Yang, Xiangpeng and Xie, Ji and Yang, Yiyuan and Huang, Yan and Xu, Min and Wu, Qiang},
|
| 198 |
+
journal={arXiv preprint arXiv:2512.07469},
|
| 199 |
+
year={2025}
|
| 200 |
+
}
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
<div align="center">
|
| 204 |
+
⭐ **If you find this project helpful, please consider giving it a star!** ⭐
|
| 205 |
+
</div>
|
| 206 |
+
|
| 207 |
+
## ⭐️ Star History
|
| 208 |
+
|
| 209 |
+
[](https://star-history.com/#knightyxp/VideoCoF&Date)
|
__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .comfyui.comfyui_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
| 2 |
+
|
| 3 |
+
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
|
app.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import torch
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import numpy as np
|
| 7 |
+
import imageio
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
# Add project root to path
|
| 11 |
+
# current_file_path = os.path.abspath(__file__)
|
| 12 |
+
# project_root = os.path.dirname(os.path.dirname(current_file_path))
|
| 13 |
+
# if project_root not in sys.path:
|
| 14 |
+
# sys.path.insert(0, project_root)
|
| 15 |
+
|
| 16 |
+
from videox_fun.ui.wan_ui import Wan_Controller, css
|
| 17 |
+
from videox_fun.ui.ui import (
|
| 18 |
+
create_model_type, create_model_checkpoints, create_finetune_models_checkpoints,
|
| 19 |
+
create_teacache_params, create_cfg_skip_params, create_cfg_riflex_k,
|
| 20 |
+
create_prompts, create_samplers, create_height_width,
|
| 21 |
+
create_generation_methods_and_video_length, create_generation_method,
|
| 22 |
+
create_cfg_and_seedbox, create_ui_outputs
|
| 23 |
+
)
|
| 24 |
+
from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction
|
| 25 |
+
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
|
| 26 |
+
from videox_fun.utils.utils import save_videos_grid, timer
|
| 27 |
+
|
| 28 |
+
# Redefine create_height_width to remove Chinese and specific defaults if needed,
|
| 29 |
+
# although we will mostly ignore sliders if we use input resolution.
|
| 30 |
+
# We will create a custom version here to avoid modifying the library file if possible,
|
| 31 |
+
# or we just rely on `create_height_width` and update labels.
|
| 32 |
+
# But `create_height_width` is imported. Let's override it or create a new one.
|
| 33 |
+
|
| 34 |
+
def create_height_width_english(default_height, default_width, maximum_height, maximum_width):
|
| 35 |
+
resize_method = gr.Radio(
|
| 36 |
+
["Generate by", "Resize according to Reference"],
|
| 37 |
+
value="Generate by",
|
| 38 |
+
show_label=False,
|
| 39 |
+
visible=False # Hide since we force input resolution
|
| 40 |
+
)
|
| 41 |
+
# We keep sliders visible but maybe we can update them dynamically or just ignore them?
|
| 42 |
+
# User requested "input is whatever resolution, inference is whatever resolution".
|
| 43 |
+
# So we can hide these or just label them as "Default / Override if no video".
|
| 44 |
+
# But better to hide them if we always use video resolution.
|
| 45 |
+
# However, if no video is provided (which shouldn't happen for VideoCoF), we might need them.
|
| 46 |
+
# Let's keep them but make them less prominent or explain.
|
| 47 |
+
# Actually user said "no default 480x832", implying don't force it.
|
| 48 |
+
|
| 49 |
+
width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16, visible=False)
|
| 50 |
+
height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16, visible=False)
|
| 51 |
+
base_resolution = gr.Radio(label="Base Resolution", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False)
|
| 52 |
+
|
| 53 |
+
return resize_method, width_slider, height_slider, base_resolution
|
| 54 |
+
|
| 55 |
+
def load_video_frames(video_path: str, source_frames: int):
|
| 56 |
+
assert source_frames is not None, "source_frames is required"
|
| 57 |
+
|
| 58 |
+
reader = imageio.get_reader(video_path)
|
| 59 |
+
try:
|
| 60 |
+
total_frames = reader.count_frames()
|
| 61 |
+
except Exception:
|
| 62 |
+
total_frames = sum(1 for _ in reader)
|
| 63 |
+
reader = imageio.get_reader(video_path)
|
| 64 |
+
|
| 65 |
+
stride = max(1, total_frames // source_frames)
|
| 66 |
+
# Using random start frame as in inference.py
|
| 67 |
+
start_frame = torch.randint(0, max(1, total_frames - stride * source_frames), (1,))[0].item()
|
| 68 |
+
|
| 69 |
+
frames = []
|
| 70 |
+
original_height, original_width = None, None
|
| 71 |
+
|
| 72 |
+
for i in range(source_frames):
|
| 73 |
+
idx = start_frame + i * stride
|
| 74 |
+
if idx >= total_frames:
|
| 75 |
+
break
|
| 76 |
+
try:
|
| 77 |
+
frame = reader.get_data(idx)
|
| 78 |
+
pil_frame = Image.fromarray(frame)
|
| 79 |
+
if original_height is None:
|
| 80 |
+
original_width, original_height = pil_frame.size
|
| 81 |
+
frames.append(pil_frame)
|
| 82 |
+
except IndexError:
|
| 83 |
+
break
|
| 84 |
+
|
| 85 |
+
reader.close()
|
| 86 |
+
|
| 87 |
+
while len(frames) < source_frames:
|
| 88 |
+
if frames:
|
| 89 |
+
frames.append(frames[-1].copy())
|
| 90 |
+
else:
|
| 91 |
+
w, h = (original_width, original_height) if original_width else (832, 480)
|
| 92 |
+
frames.append(Image.new('RGB', (w, h), (0, 0, 0)))
|
| 93 |
+
|
| 94 |
+
input_video = torch.from_numpy(np.array(frames))
|
| 95 |
+
input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
|
| 96 |
+
input_video = input_video * (2.0 / 255.0) - 1.0
|
| 97 |
+
|
| 98 |
+
return input_video, original_height, original_width
|
| 99 |
+
|
| 100 |
+
class VideoCoF_Controller(Wan_Controller):
|
| 101 |
+
@timer
|
| 102 |
+
def generate(
|
| 103 |
+
self,
|
| 104 |
+
diffusion_transformer_dropdown,
|
| 105 |
+
base_model_dropdown,
|
| 106 |
+
lora_model_dropdown,
|
| 107 |
+
lora_alpha_slider,
|
| 108 |
+
prompt_textbox,
|
| 109 |
+
negative_prompt_textbox,
|
| 110 |
+
sampler_dropdown,
|
| 111 |
+
sample_step_slider,
|
| 112 |
+
resize_method,
|
| 113 |
+
width_slider,
|
| 114 |
+
height_slider,
|
| 115 |
+
base_resolution,
|
| 116 |
+
generation_method,
|
| 117 |
+
length_slider,
|
| 118 |
+
overlap_video_length,
|
| 119 |
+
partial_video_length,
|
| 120 |
+
cfg_scale_slider,
|
| 121 |
+
start_image,
|
| 122 |
+
end_image,
|
| 123 |
+
validation_video,
|
| 124 |
+
validation_video_mask,
|
| 125 |
+
control_video,
|
| 126 |
+
denoise_strength,
|
| 127 |
+
seed_textbox,
|
| 128 |
+
ref_image=None,
|
| 129 |
+
enable_teacache=None,
|
| 130 |
+
teacache_threshold=None,
|
| 131 |
+
num_skip_start_steps=None,
|
| 132 |
+
teacache_offload=None,
|
| 133 |
+
cfg_skip_ratio=None,
|
| 134 |
+
enable_riflex=None,
|
| 135 |
+
riflex_k=None,
|
| 136 |
+
# Custom args
|
| 137 |
+
source_frames_slider=33,
|
| 138 |
+
reasoning_frames_slider=4,
|
| 139 |
+
repeat_rope_checkbox=True,
|
| 140 |
+
fps=10,
|
| 141 |
+
is_api=False,
|
| 142 |
+
):
|
| 143 |
+
self.clear_cache()
|
| 144 |
+
print(f"VideoCoF Generation started.")
|
| 145 |
+
|
| 146 |
+
if self.diffusion_transformer_dropdown != diffusion_transformer_dropdown:
|
| 147 |
+
self.update_diffusion_transformer(diffusion_transformer_dropdown)
|
| 148 |
+
|
| 149 |
+
if self.base_model_path != base_model_dropdown:
|
| 150 |
+
self.update_base_model(base_model_dropdown)
|
| 151 |
+
|
| 152 |
+
if self.lora_model_path != lora_model_dropdown:
|
| 153 |
+
self.update_lora_model(lora_model_dropdown)
|
| 154 |
+
|
| 155 |
+
# Scheduler setup
|
| 156 |
+
scheduler_config = self.pipeline.scheduler.config
|
| 157 |
+
if sampler_dropdown in ["Flow_Unipc", "Flow_DPM++"]:
|
| 158 |
+
scheduler_config['shift'] = 1
|
| 159 |
+
self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)
|
| 160 |
+
|
| 161 |
+
# LoRA merging
|
| 162 |
+
if self.lora_model_path != "none":
|
| 163 |
+
print(f"Merge Lora.")
|
| 164 |
+
self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
| 165 |
+
|
| 166 |
+
# Seed
|
| 167 |
+
if int(seed_textbox) != -1 and seed_textbox != "":
|
| 168 |
+
torch.manual_seed(int(seed_textbox))
|
| 169 |
+
else:
|
| 170 |
+
seed_textbox = np.random.randint(0, 1e10)
|
| 171 |
+
generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))
|
| 172 |
+
|
| 173 |
+
try:
|
| 174 |
+
# VideoCoF logic
|
| 175 |
+
# Use validation_video as source if provided (UI standard for Video-to-Video)
|
| 176 |
+
input_video_path = validation_video
|
| 177 |
+
|
| 178 |
+
if input_video_path is None:
|
| 179 |
+
# Fallback to control_video if set, but standard UI uses validation_video
|
| 180 |
+
input_video_path = control_video
|
| 181 |
+
|
| 182 |
+
if input_video_path is None:
|
| 183 |
+
raise ValueError("Please upload a video for VideoCoF generation.")
|
| 184 |
+
|
| 185 |
+
# CoT Prompt Construction
|
| 186 |
+
edit_text = prompt_textbox
|
| 187 |
+
ground_instr = derive_ground_object_from_instruction(edit_text)
|
| 188 |
+
prompt = (
|
| 189 |
+
"A video sequence showing three parts: first the original scene, "
|
| 190 |
+
f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
|
| 191 |
+
)
|
| 192 |
+
print(f"Constructed prompt: {prompt}")
|
| 193 |
+
|
| 194 |
+
# Load video frames
|
| 195 |
+
input_video_tensor, video_height, video_width = load_video_frames(
|
| 196 |
+
input_video_path,
|
| 197 |
+
source_frames=source_frames_slider
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Using loaded video dimensions
|
| 201 |
+
h, w = video_height, video_width
|
| 202 |
+
print(f"Input video dimensions: {w}x{h}")
|
| 203 |
+
|
| 204 |
+
print(f"Running pipeline with frames={length_slider}, source={source_frames_slider}, reasoning={reasoning_frames_slider}")
|
| 205 |
+
|
| 206 |
+
sample = self.pipeline(
|
| 207 |
+
video=input_video_tensor,
|
| 208 |
+
prompt=prompt,
|
| 209 |
+
num_frames=length_slider,
|
| 210 |
+
source_frames=source_frames_slider,
|
| 211 |
+
reasoning_frames=reasoning_frames_slider,
|
| 212 |
+
negative_prompt=negative_prompt_textbox,
|
| 213 |
+
height=h,
|
| 214 |
+
width=w,
|
| 215 |
+
generator=generator,
|
| 216 |
+
guidance_scale=cfg_scale_slider,
|
| 217 |
+
num_inference_steps=sample_step_slider,
|
| 218 |
+
repeat_rope=repeat_rope_checkbox,
|
| 219 |
+
cot=True,
|
| 220 |
+
).videos
|
| 221 |
+
|
| 222 |
+
final_video = sample
|
| 223 |
+
|
| 224 |
+
except Exception as e:
|
| 225 |
+
print(f"Error: {e}")
|
| 226 |
+
if self.lora_model_path != "none":
|
| 227 |
+
self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
| 228 |
+
return gr.update(), gr.update(), f"Error: {str(e)}"
|
| 229 |
+
|
| 230 |
+
# Unmerge LoRA
|
| 231 |
+
if self.lora_model_path != "none":
|
| 232 |
+
self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
| 233 |
+
|
| 234 |
+
# Save output
|
| 235 |
+
save_sample_path = self.save_outputs(
|
| 236 |
+
False, length_slider, final_video, fps=fps
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# Return input video to display it alongside output if needed?
|
| 240 |
+
# But generate returns [result_image, result_video, infer_progress].
|
| 241 |
+
# The user said "load original video didn't display".
|
| 242 |
+
# That usually refers to the input component not showing the video after upload or example selection.
|
| 243 |
+
# Grado handles that automatically if `value` is set or user uploads.
|
| 244 |
+
# Maybe they mean the `validation_video` component didn't show the example?
|
| 245 |
+
# Or do they mean they want to see the processed input frames?
|
| 246 |
+
# "load 原视频没有display 出来" -> "Loaded original video didn't display".
|
| 247 |
+
# Likely referring to the input UI component.
|
| 248 |
+
# If they mean they want to see it in the output area, we can't easily change the return signature without changing UI structure.
|
| 249 |
+
# But let's ensure the input component works.
|
| 250 |
+
|
| 251 |
+
return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
|
| 252 |
+
|
| 253 |
+
def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
|
| 254 |
+
controller = VideoCoF_Controller(
|
| 255 |
+
GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
|
| 256 |
+
config_path=config_path, compile_dit=compile_dit,
|
| 257 |
+
weight_dtype=weight_dtype
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
with gr.Blocks() as demo:
|
| 261 |
+
gr.Markdown("# VideoCoF Demo")
|
| 262 |
+
|
| 263 |
+
with gr.Column(variant="panel"):
|
| 264 |
+
# Hide model selection
|
| 265 |
+
diffusion_transformer_dropdown, _ = create_model_checkpoints(controller, visible=False, default_model="Wan-AI/Wan2.1-T2V-14B")
|
| 266 |
+
base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="XiangpengYang/VideoCoF")
|
| 267 |
+
|
| 268 |
+
# Set default LoRA alpha to 1.0 (matching inference.py)
|
| 269 |
+
lora_alpha_slider.value = 1.0
|
| 270 |
+
|
| 271 |
+
with gr.Row():
|
| 272 |
+
# Disable teacache by default
|
| 273 |
+
enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = create_teacache_params(False, 0.10, 5, False)
|
| 274 |
+
cfg_skip_ratio = create_cfg_skip_params(0)
|
| 275 |
+
enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)
|
| 276 |
+
|
| 277 |
+
with gr.Column(variant="panel"):
|
| 278 |
+
prompt_textbox, negative_prompt_textbox = create_prompts(prompt="Remove the young man with short black hair wearing black shirt on the left.")
|
| 279 |
+
|
| 280 |
+
with gr.Row():
|
| 281 |
+
with gr.Column():
|
| 282 |
+
sampler_dropdown, sample_step_slider = create_samplers(controller)
|
| 283 |
+
|
| 284 |
+
# Custom VideoCoF Params
|
| 285 |
+
with gr.Group():
|
| 286 |
+
gr.Markdown("### VideoCoF Parameters")
|
| 287 |
+
source_frames_slider = gr.Slider(label="Source Frames", minimum=1, maximum=100, value=33, step=1)
|
| 288 |
+
reasoning_frames_slider = gr.Slider(label="Reasoning Frames", minimum=1, maximum=20, value=4, step=1)
|
| 289 |
+
repeat_rope_checkbox = gr.Checkbox(label="Repeat RoPE", value=True)
|
| 290 |
+
|
| 291 |
+
# Use custom height/width creation to hide/customize
|
| 292 |
+
resize_method, width_slider, height_slider, base_resolution = create_height_width_english(
|
| 293 |
+
default_height=480, default_width=832, maximum_height=1344, maximum_width=1344
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# Default video length 65
|
| 297 |
+
generation_method, length_slider, overlap_video_length, partial_video_length = \
|
| 298 |
+
create_generation_methods_and_video_length(
|
| 299 |
+
["Video Generation"],
|
| 300 |
+
default_video_length=65,
|
| 301 |
+
maximum_video_length=161
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# Simplified input for VideoCoF - mainly Video to Video.
|
| 305 |
+
image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
|
| 306 |
+
["Video to Video"], prompt_textbox, support_end_image=False, default_video="assets/two_man.mp4",
|
| 307 |
+
video_examples=[
|
| 308 |
+
["assets/two_man.mp4", "Remove the young man with short black hair wearing black shirt on the left."],
|
| 309 |
+
["assets/sign.mp4", "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below."]
|
| 310 |
+
]
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# Ensure validation_video is visible and interactive
|
| 314 |
+
validation_video.visible = True
|
| 315 |
+
validation_video.interactive = True
|
| 316 |
+
|
| 317 |
+
# Set default seed to 0
|
| 318 |
+
cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(True)
|
| 319 |
+
seed_textbox.value = "0"
|
| 320 |
+
|
| 321 |
+
generate_button = gr.Button(value="Generate", variant='primary')
|
| 322 |
+
|
| 323 |
+
result_image, result_video, infer_progress = create_ui_outputs()
|
| 324 |
+
|
| 325 |
+
# Event handlers
|
| 326 |
+
generate_button.click(
|
| 327 |
+
fn=controller.generate,
|
| 328 |
+
inputs=[
|
| 329 |
+
diffusion_transformer_dropdown,
|
| 330 |
+
base_model_dropdown,
|
| 331 |
+
lora_model_dropdown,
|
| 332 |
+
lora_alpha_slider,
|
| 333 |
+
prompt_textbox,
|
| 334 |
+
negative_prompt_textbox,
|
| 335 |
+
sampler_dropdown,
|
| 336 |
+
sample_step_slider,
|
| 337 |
+
resize_method,
|
| 338 |
+
width_slider,
|
| 339 |
+
height_slider,
|
| 340 |
+
base_resolution,
|
| 341 |
+
generation_method,
|
| 342 |
+
length_slider,
|
| 343 |
+
overlap_video_length,
|
| 344 |
+
partial_video_length,
|
| 345 |
+
cfg_scale_slider,
|
| 346 |
+
start_image,
|
| 347 |
+
end_image,
|
| 348 |
+
validation_video,
|
| 349 |
+
validation_video_mask,
|
| 350 |
+
control_video,
|
| 351 |
+
denoise_strength,
|
| 352 |
+
seed_textbox,
|
| 353 |
+
ref_image,
|
| 354 |
+
enable_teacache,
|
| 355 |
+
teacache_threshold,
|
| 356 |
+
num_skip_start_steps,
|
| 357 |
+
teacache_offload,
|
| 358 |
+
cfg_skip_ratio,
|
| 359 |
+
enable_riflex,
|
| 360 |
+
riflex_k,
|
| 361 |
+
# New inputs
|
| 362 |
+
source_frames_slider,
|
| 363 |
+
reasoning_frames_slider,
|
| 364 |
+
repeat_rope_checkbox
|
| 365 |
+
],
|
| 366 |
+
outputs=[result_image, result_video, infer_progress]
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
return demo, controller
|
| 370 |
+
|
| 371 |
+
if __name__ == "__main__":
|
| 372 |
+
from videox_fun.ui.controller import flow_scheduler_dict
|
| 373 |
+
|
| 374 |
+
GPU_memory_mode = "sequential_cpu_offload"
|
| 375 |
+
compile_dit = False
|
| 376 |
+
weight_dtype = torch.bfloat16
|
| 377 |
+
server_name = "0.0.0.0"
|
| 378 |
+
server_port = 7860
|
| 379 |
+
config_path = "config/wan2.1/wan_civitai.yaml"
|
| 380 |
+
|
| 381 |
+
demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, config_path, compile_dit, weight_dtype)
|
| 382 |
+
|
| 383 |
+
demo.queue(status_update_rate=1).launch(
|
| 384 |
+
server_name=server_name,
|
| 385 |
+
server_port=server_port,
|
| 386 |
+
prevent_thread_lock=True,
|
| 387 |
+
share=False
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
while True:
|
| 391 |
+
time.sleep(5)
|
assets/dough.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5262cf58ffa08dcd79d6346abec46bc0234aebfc65905b5ea2ca4ab905ca9dac
|
| 3 |
+
size 185700
|
assets/sign.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4e94f03a7d5738a001ce2e1302a8ae65596431a647dbfed83cdb6876322175a7
|
| 3 |
+
size 100798
|
assets/teaser_test.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"task_type": "obj_add",
|
| 4 |
+
"sample_id": "001",
|
| 5 |
+
"source_video_path": "assets/woman_ballon.mp4",
|
| 6 |
+
"qwen_vl_72b_refined_instruction": "Add the woman in a floral dress pointing at the balloon on the left."
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"task_type": "obj_rem",
|
| 10 |
+
"sample_id": "001",
|
| 11 |
+
"source_video_path": "assets/two_man.mp4",
|
| 12 |
+
"qwen_vl_72b_refined_instruction": "Remove the young man with short black hair wearing black shirt on the left."
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"task_type": "local_style",
|
| 16 |
+
"sample_id": "001",
|
| 17 |
+
"source_video_path": "assets/sign.mp4",
|
| 18 |
+
"qwen_vl_72b_refined_instruction": "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below."
|
| 19 |
+
}
|
| 20 |
+
]
|
assets/two_man.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bd9c0f6523207bbcf0d5159beb7f7eaf37811e6e5b7a53585dda50491a573cd9
|
| 3 |
+
size 303233
|
assets/woman_ballon.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:575b37abda414161179bc00e0e7b6893f28feb967e875c8f9676275d2cc32572
|
| 3 |
+
size 89085
|
config/1.3b_lora_zero_stage2_config.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": true
|
| 4 |
+
},
|
| 5 |
+
"train_micro_batch_size_per_gpu": 4,
|
| 6 |
+
"train_batch_size": 64,
|
| 7 |
+
"gradient_accumulation_steps": 1,
|
| 8 |
+
"gradient_clipping": 0.05,
|
| 9 |
+
"zero_optimization": {
|
| 10 |
+
"stage": 2,
|
| 11 |
+
"offload_optimizer": {
|
| 12 |
+
"device": "none"
|
| 13 |
+
},
|
| 14 |
+
"overlap_comm": true,
|
| 15 |
+
"contiguous_gradients": true,
|
| 16 |
+
"sub_group_size": 1e9,
|
| 17 |
+
"reduce_bucket_size": 5e8,
|
| 18 |
+
"allgather_partitions": true,
|
| 19 |
+
"allgather_bucket_size": 2e8,
|
| 20 |
+
"reduce_scatter": true
|
| 21 |
+
},
|
| 22 |
+
"steps_per_print": 100,
|
| 23 |
+
"wall_clock_breakdown": false
|
| 24 |
+
}
|
config/14b_lora_zero2_bf16_config.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": true
|
| 4 |
+
},
|
| 5 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 6 |
+
"train_batch_size": "auto",
|
| 7 |
+
"gradient_accumulation_steps": 1,
|
| 8 |
+
"gradient_clipping": 0.05,
|
| 9 |
+
"zero_optimization": {
|
| 10 |
+
"stage": 2,
|
| 11 |
+
"offload_optimizer": {
|
| 12 |
+
"device": "none"
|
| 13 |
+
},
|
| 14 |
+
"overlap_comm": true,
|
| 15 |
+
"contiguous_gradients": true,
|
| 16 |
+
"sub_group_size": 1e9,
|
| 17 |
+
"reduce_bucket_size": 5e8,
|
| 18 |
+
"allgather_partitions": true,
|
| 19 |
+
"allgather_bucket_size": 2e8,
|
| 20 |
+
"reduce_scatter": true
|
| 21 |
+
},
|
| 22 |
+
"steps_per_print": 100,
|
| 23 |
+
"wall_clock_breakdown": false
|
| 24 |
+
}
|
config/wan2.1/wan_civitai.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
format: civitai
|
| 2 |
+
pipeline: Wan
|
| 3 |
+
transformer_additional_kwargs:
|
| 4 |
+
transformer_subpath: ./
|
| 5 |
+
dict_mapping:
|
| 6 |
+
in_dim: in_channels
|
| 7 |
+
dim: hidden_size
|
| 8 |
+
|
| 9 |
+
vae_kwargs:
|
| 10 |
+
vae_subpath: Wan2.1_VAE.pth
|
| 11 |
+
temporal_compression_ratio: 4
|
| 12 |
+
spatial_compression_ratio: 8
|
| 13 |
+
|
| 14 |
+
text_encoder_kwargs:
|
| 15 |
+
text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
|
| 16 |
+
tokenizer_subpath: google/umt5-xxl
|
| 17 |
+
text_length: 512
|
| 18 |
+
vocab: 256384
|
| 19 |
+
dim: 4096
|
| 20 |
+
dim_attn: 4096
|
| 21 |
+
dim_ffn: 10240
|
| 22 |
+
num_heads: 64
|
| 23 |
+
num_layers: 24
|
| 24 |
+
num_buckets: 32
|
| 25 |
+
shared_pos: False
|
| 26 |
+
dropout: 0.0
|
| 27 |
+
|
| 28 |
+
scheduler_kwargs:
|
| 29 |
+
scheduler_subpath: null
|
| 30 |
+
num_train_timesteps: 1000
|
| 31 |
+
shift: 5.0
|
| 32 |
+
use_dynamic_shifting: false
|
| 33 |
+
base_shift: 0.5
|
| 34 |
+
max_shift: 1.15
|
| 35 |
+
base_image_seq_len: 256
|
| 36 |
+
max_image_seq_len: 4096
|
| 37 |
+
|
| 38 |
+
image_encoder_kwargs:
|
| 39 |
+
image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
config/wan2.2/wan_civitai_5b.yaml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
format: civitai
|
| 2 |
+
pipeline: Wan
|
| 3 |
+
transformer_additional_kwargs:
|
| 4 |
+
transformer_low_noise_model_subpath: ./
|
| 5 |
+
transformer_combination_type: "single"
|
| 6 |
+
dict_mapping:
|
| 7 |
+
in_dim: in_channels
|
| 8 |
+
dim: hidden_size
|
| 9 |
+
|
| 10 |
+
vae_kwargs:
|
| 11 |
+
vae_type: "AutoencoderKLWan3_8"
|
| 12 |
+
vae_subpath: Wan2.2_VAE.pth
|
| 13 |
+
temporal_compression_ratio: 4
|
| 14 |
+
spatial_compression_ratio: 16
|
| 15 |
+
|
| 16 |
+
text_encoder_kwargs:
|
| 17 |
+
text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
|
| 18 |
+
tokenizer_subpath: google/umt5-xxl
|
| 19 |
+
text_length: 512
|
| 20 |
+
vocab: 256384
|
| 21 |
+
dim: 4096
|
| 22 |
+
dim_attn: 4096
|
| 23 |
+
dim_ffn: 10240
|
| 24 |
+
num_heads: 64
|
| 25 |
+
num_layers: 24
|
| 26 |
+
num_buckets: 32
|
| 27 |
+
shared_pos: False
|
| 28 |
+
dropout: 0.0
|
| 29 |
+
|
| 30 |
+
scheduler_kwargs:
|
| 31 |
+
scheduler_subpath: null
|
| 32 |
+
num_train_timesteps: 1000
|
| 33 |
+
shift: 5.0
|
| 34 |
+
use_dynamic_shifting: false
|
| 35 |
+
base_shift: 0.5
|
| 36 |
+
max_shift: 1.15
|
| 37 |
+
base_image_seq_len: 256
|
| 38 |
+
max_image_seq_len: 4096
|
| 39 |
+
|
| 40 |
+
image_encoder_kwargs:
|
| 41 |
+
image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
config/wan2.2/wan_civitai_i2v.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
format: civitai
|
| 2 |
+
pipeline: Wan
|
| 3 |
+
transformer_additional_kwargs:
|
| 4 |
+
transformer_low_noise_model_subpath: ./low_noise_model
|
| 5 |
+
transformer_high_noise_model_subpath: ./high_noise_model
|
| 6 |
+
transformer_combination_type: "moe"
|
| 7 |
+
boundary: 0.900
|
| 8 |
+
dict_mapping:
|
| 9 |
+
in_dim: in_channels
|
| 10 |
+
dim: hidden_size
|
| 11 |
+
|
| 12 |
+
vae_kwargs:
|
| 13 |
+
vae_type: "AutoencoderKLWan"
|
| 14 |
+
vae_subpath: Wan2.1_VAE.pth
|
| 15 |
+
temporal_compression_ratio: 4
|
| 16 |
+
spatial_compression_ratio: 8
|
| 17 |
+
|
| 18 |
+
text_encoder_kwargs:
|
| 19 |
+
text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
|
| 20 |
+
tokenizer_subpath: google/umt5-xxl
|
| 21 |
+
text_length: 512
|
| 22 |
+
vocab: 256384
|
| 23 |
+
dim: 4096
|
| 24 |
+
dim_attn: 4096
|
| 25 |
+
dim_ffn: 10240
|
| 26 |
+
num_heads: 64
|
| 27 |
+
num_layers: 24
|
| 28 |
+
num_buckets: 32
|
| 29 |
+
shared_pos: False
|
| 30 |
+
dropout: 0.0
|
| 31 |
+
|
| 32 |
+
scheduler_kwargs:
|
| 33 |
+
scheduler_subpath: null
|
| 34 |
+
num_train_timesteps: 1000
|
| 35 |
+
shift: 5.0
|
| 36 |
+
use_dynamic_shifting: false
|
| 37 |
+
base_shift: 0.5
|
| 38 |
+
max_shift: 1.15
|
| 39 |
+
base_image_seq_len: 256
|
| 40 |
+
max_image_seq_len: 4096
|
| 41 |
+
|
| 42 |
+
image_encoder_kwargs:
|
| 43 |
+
image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
config/wan2.2/wan_civitai_s2v.yaml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
format: civitai
|
| 2 |
+
pipeline: Wan
|
| 3 |
+
transformer_additional_kwargs:
|
| 4 |
+
transformer_low_noise_model_subpath: ./
|
| 5 |
+
transformer_combination_type: "single"
|
| 6 |
+
dict_mapping:
|
| 7 |
+
in_dim: in_channels
|
| 8 |
+
dim: hidden_size
|
| 9 |
+
|
| 10 |
+
vae_kwargs:
|
| 11 |
+
vae_type: "AutoencoderKLWan"
|
| 12 |
+
vae_subpath: Wan2.1_VAE.pth
|
| 13 |
+
temporal_compression_ratio: 4
|
| 14 |
+
spatial_compression_ratio: 8
|
| 15 |
+
|
| 16 |
+
text_encoder_kwargs:
|
| 17 |
+
text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
|
| 18 |
+
tokenizer_subpath: google/umt5-xxl
|
| 19 |
+
text_length: 512
|
| 20 |
+
vocab: 256384
|
| 21 |
+
dim: 4096
|
| 22 |
+
dim_attn: 4096
|
| 23 |
+
dim_ffn: 10240
|
| 24 |
+
num_heads: 64
|
| 25 |
+
num_layers: 24
|
| 26 |
+
num_buckets: 32
|
| 27 |
+
shared_pos: False
|
| 28 |
+
dropout: 0.0
|
| 29 |
+
|
| 30 |
+
audio_encoder_kwargs:
|
| 31 |
+
audio_encoder_subpath: wav2vec2-large-xlsr-53-english
|
| 32 |
+
|
| 33 |
+
scheduler_kwargs:
|
| 34 |
+
scheduler_subpath: null
|
| 35 |
+
num_train_timesteps: 1000
|
| 36 |
+
shift: 3.0
|
| 37 |
+
use_dynamic_shifting: false
|
| 38 |
+
base_shift: 0.5
|
| 39 |
+
max_shift: 1.15
|
| 40 |
+
base_image_seq_len: 256
|
| 41 |
+
max_image_seq_len: 4096
|
| 42 |
+
|
| 43 |
+
image_encoder_kwargs:
|
| 44 |
+
image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
config/wan2.2/wan_civitai_t2v.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
format: civitai
|
| 2 |
+
pipeline: Wan
|
| 3 |
+
transformer_additional_kwargs:
|
| 4 |
+
transformer_low_noise_model_subpath: ./low_noise_model
|
| 5 |
+
transformer_high_noise_model_subpath: ./high_noise_model
|
| 6 |
+
transformer_combination_type: "moe"
|
| 7 |
+
boundary: 0.875
|
| 8 |
+
dict_mapping:
|
| 9 |
+
in_dim: in_channels
|
| 10 |
+
dim: hidden_size
|
| 11 |
+
|
| 12 |
+
vae_kwargs:
|
| 13 |
+
vae_type: "AutoencoderKLWan"
|
| 14 |
+
vae_subpath: Wan2.1_VAE.pth
|
| 15 |
+
temporal_compression_ratio: 4
|
| 16 |
+
spatial_compression_ratio: 8
|
| 17 |
+
|
| 18 |
+
text_encoder_kwargs:
|
| 19 |
+
text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
|
| 20 |
+
tokenizer_subpath: google/umt5-xxl
|
| 21 |
+
text_length: 512
|
| 22 |
+
vocab: 256384
|
| 23 |
+
dim: 4096
|
| 24 |
+
dim_attn: 4096
|
| 25 |
+
dim_ffn: 10240
|
| 26 |
+
num_heads: 64
|
| 27 |
+
num_layers: 24
|
| 28 |
+
num_buckets: 32
|
| 29 |
+
shared_pos: False
|
| 30 |
+
dropout: 0.0
|
| 31 |
+
|
| 32 |
+
scheduler_kwargs:
|
| 33 |
+
scheduler_subpath: null
|
| 34 |
+
num_train_timesteps: 1000
|
| 35 |
+
shift: 12.0
|
| 36 |
+
use_dynamic_shifting: false
|
| 37 |
+
base_shift: 0.5
|
| 38 |
+
max_shift: 1.15
|
| 39 |
+
base_image_seq_len: 256
|
| 40 |
+
max_image_seq_len: 4096
|
| 41 |
+
|
| 42 |
+
image_encoder_kwargs:
|
| 43 |
+
image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
config/zero_stage2_config.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": true
|
| 4 |
+
},
|
| 5 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 6 |
+
"train_batch_size": "auto",
|
| 7 |
+
"gradient_accumulation_steps": "auto",
|
| 8 |
+
"dump_state": true,
|
| 9 |
+
"zero_optimization": {
|
| 10 |
+
"stage": 2,
|
| 11 |
+
"overlap_comm": true,
|
| 12 |
+
"contiguous_gradients": true,
|
| 13 |
+
"sub_group_size": 1e9,
|
| 14 |
+
"reduce_bucket_size": 5e8
|
| 15 |
+
}
|
| 16 |
+
}
|
config/zero_stage3_config.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": true
|
| 4 |
+
},
|
| 5 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 6 |
+
"train_batch_size": "auto",
|
| 7 |
+
"gradient_accumulation_steps": "auto",
|
| 8 |
+
"gradient_clipping": "auto",
|
| 9 |
+
"steps_per_print": 2000,
|
| 10 |
+
"wall_clock_breakdown": false,
|
| 11 |
+
"zero_optimization": {
|
| 12 |
+
"stage": 3,
|
| 13 |
+
"overlap_comm": true,
|
| 14 |
+
"contiguous_gradients": true,
|
| 15 |
+
"reduce_bucket_size": 5e8,
|
| 16 |
+
"sub_group_size": 1e9,
|
| 17 |
+
"stage3_max_live_parameters": 1e9,
|
| 18 |
+
"stage3_max_reuse_distance": 1e9,
|
| 19 |
+
"stage3_gather_16bit_weights_on_model_save": "auto",
|
| 20 |
+
"offload_optimizer": {
|
| 21 |
+
"device": "none"
|
| 22 |
+
},
|
| 23 |
+
"offload_param": {
|
| 24 |
+
"device": "none"
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
}
|
config/zero_stage3_config_cpu_offload.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": true
|
| 4 |
+
},
|
| 5 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 6 |
+
"train_batch_size": "auto",
|
| 7 |
+
"gradient_accumulation_steps": "auto",
|
| 8 |
+
"gradient_clipping": "auto",
|
| 9 |
+
"steps_per_print": 2000,
|
| 10 |
+
"wall_clock_breakdown": false,
|
| 11 |
+
"zero_optimization": {
|
| 12 |
+
"stage": 3,
|
| 13 |
+
"overlap_comm": true,
|
| 14 |
+
"contiguous_gradients": true,
|
| 15 |
+
"reduce_bucket_size": 5e8,
|
| 16 |
+
"sub_group_size": 1e9,
|
| 17 |
+
"stage3_max_live_parameters": 1e9,
|
| 18 |
+
"stage3_max_reuse_distance": 1e9,
|
| 19 |
+
"stage3_gather_16bit_weights_on_model_save": "auto",
|
| 20 |
+
"offload_optimizer": {
|
| 21 |
+
"device": "cpu"
|
| 22 |
+
},
|
| 23 |
+
"offload_param": {
|
| 24 |
+
"device": "cpu"
|
| 25 |
+
}
|
| 26 |
+
}
|
| 27 |
+
}
|
| 28 |
+
|
inference.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import imageio
|
| 13 |
+
|
| 14 |
+
current_file_path = os.path.abspath(__file__)
|
| 15 |
+
project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
|
| 16 |
+
for project_root in project_roots:
|
| 17 |
+
sys.path.insert(0, project_root) if project_root not in sys.path else None
|
| 18 |
+
|
| 19 |
+
from videox_fun.models import (AutoencoderKLWan, WanT5EncoderModel, AutoTokenizer,
|
| 20 |
+
WanTransformer3DModel)
|
| 21 |
+
from videox_fun.pipeline import WanPipeline
|
| 22 |
+
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, replace_parameters_by_name,
|
| 23 |
+
convert_weight_dtype_wrapper)
|
| 24 |
+
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
|
| 25 |
+
from videox_fun.utils.utils import (filter_kwargs, save_videos_grid)
|
| 26 |
+
from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction
|
| 27 |
+
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
|
| 28 |
+
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_video_frames(
|
| 32 |
+
video_path: str,
|
| 33 |
+
source_frames: int = None,
|
| 34 |
+
):
|
| 35 |
+
assert source_frames is not None, "请传入 source_frames"
|
| 36 |
+
|
| 37 |
+
reader = imageio.get_reader(video_path)
|
| 38 |
+
try:
|
| 39 |
+
total_frames = reader.count_frames()
|
| 40 |
+
except Exception:
|
| 41 |
+
total_frames = sum(1 for _ in reader)
|
| 42 |
+
reader = imageio.get_reader(video_path)
|
| 43 |
+
|
| 44 |
+
stride = max(1, total_frames // source_frames)
|
| 45 |
+
start_frame = torch.randint(0, max(1, total_frames - stride * source_frames), (1,))[0].item()
|
| 46 |
+
|
| 47 |
+
frames = []
|
| 48 |
+
original_height, original_width = None, None
|
| 49 |
+
|
| 50 |
+
for i in range(source_frames):
|
| 51 |
+
idx = start_frame + i * stride
|
| 52 |
+
if idx >= total_frames:
|
| 53 |
+
break
|
| 54 |
+
try:
|
| 55 |
+
frame = reader.get_data(idx)
|
| 56 |
+
pil_frame = Image.fromarray(frame)
|
| 57 |
+
if original_height is None:
|
| 58 |
+
original_width, original_height = pil_frame.size
|
| 59 |
+
print(f"Original video dimensions: {original_width}x{original_height}")
|
| 60 |
+
frames.append(pil_frame)
|
| 61 |
+
except IndexError:
|
| 62 |
+
break
|
| 63 |
+
|
| 64 |
+
reader.close()
|
| 65 |
+
|
| 66 |
+
while len(frames) < source_frames:
|
| 67 |
+
if frames:
|
| 68 |
+
frames.append(frames[-1].copy())
|
| 69 |
+
else:
|
| 70 |
+
w, h = (original_width, original_height) if original_width else (832, 480)
|
| 71 |
+
frames.append(Image.new('RGB', (w, h), (0, 0, 0)))
|
| 72 |
+
|
| 73 |
+
assert len(frames) == source_frames
|
| 74 |
+
print(f"Loaded {source_frames} source frames")
|
| 75 |
+
|
| 76 |
+
input_video = torch.from_numpy(np.array(frames))
|
| 77 |
+
input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
|
| 78 |
+
input_video = input_video * (2.0 / 255.0) - 1.0
|
| 79 |
+
|
| 80 |
+
return input_video, original_height, original_width
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def parse_args():
|
| 84 |
+
parser = argparse.ArgumentParser(description="Video-to-video CoT reasoning generation from JSON task list with parallel inference")
|
| 85 |
+
parser.add_argument("--test_json", type=str, default=None, help="Path to test JSON file for batch inference")
|
| 86 |
+
parser.add_argument("--prompt", type=str, default=None, help="Text prompt for editing (single mode)")
|
| 87 |
+
parser.add_argument("--video_path", type=str, default=None, help="Path to input video (single mode)")
|
| 88 |
+
parser.add_argument("--model_name", type=str, default="/scratch3/yan204/models/Wan2.1-T2V-14B", help="Model checkpoint path")
|
| 89 |
+
parser.add_argument("--output_dir", type=str, required=True, help="Output directory for generated videos")
|
| 90 |
+
parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducible generation")
|
| 91 |
+
parser.add_argument("--videocof_path", type=str, default=None, help="Path to videocof weight checkpoint")
|
| 92 |
+
parser.add_argument("--num_frames", type=int, default=65, help="Total number of frames (input + generated)")
|
| 93 |
+
parser.add_argument("--source_frames", type=int, default=33, help="Number of source frames; default 33")
|
| 94 |
+
parser.add_argument("--reasoning_frames", type=int, default=4, help="Grounding frames in the middle segment (pixel-space)")
|
| 95 |
+
parser.add_argument("--repeat_rope", action="store_true", help="Enable repeat temporal RoPE for src and tgt segments")
|
| 96 |
+
return parser.parse_args()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Defaults aligned with predict_v2v_json_new.py
|
| 100 |
+
GPU_memory_mode = "sequential_cpu_offload"
|
| 101 |
+
ulysses_degree = 1
|
| 102 |
+
ring_degree = 1
|
| 103 |
+
fsdp_dit = False
|
| 104 |
+
fsdp_text_encoder = True
|
| 105 |
+
compile_dit = False
|
| 106 |
+
enable_teacache = True
|
| 107 |
+
teacache_threshold = 0.10
|
| 108 |
+
num_skip_start_steps = 5
|
| 109 |
+
teacache_offload = False
|
| 110 |
+
cfg_skip_ratio = 0
|
| 111 |
+
enable_riflex = False
|
| 112 |
+
riflex_k = 6
|
| 113 |
+
|
| 114 |
+
config_path = "config/wan2.1/wan_civitai.yaml"
|
| 115 |
+
model_name = "/scratch3/yan204/models/Wan2.1-T2V-14B"
|
| 116 |
+
sampler_name = "Flow_Unipc"
|
| 117 |
+
shift = 3
|
| 118 |
+
transformer_path = None
|
| 119 |
+
vae_path = None
|
| 120 |
+
|
| 121 |
+
fps = 10
|
| 122 |
+
weight_dtype = torch.bfloat16
|
| 123 |
+
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
| 124 |
+
guidance_scale = 5.0
|
| 125 |
+
num_inference_steps = 50
|
| 126 |
+
lora_weight = 1.0
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def save_results(tensor: torch.Tensor, file_path: str, fps_out: int = 16):
|
| 130 |
+
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
| 131 |
+
B, C, T, H, W = tensor.shape
|
| 132 |
+
arr = tensor[0].cpu().numpy()
|
| 133 |
+
if T == 1:
|
| 134 |
+
img = arr[:, 0].transpose(1, 2, 0)
|
| 135 |
+
img = (img * 255).astype(np.uint8)
|
| 136 |
+
Image.fromarray(img).save(file_path)
|
| 137 |
+
else:
|
| 138 |
+
save_videos_grid(tensor, file_path, fps=fps_out)
|
| 139 |
+
print(f"Saved video → {file_path}")
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _normalize_to_01(video: torch.Tensor) -> torch.Tensor:
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
vmin = float(video.min())
|
| 145 |
+
vmax = float(video.max())
|
| 146 |
+
if vmin < 0.0 or vmax > 1.0:
|
| 147 |
+
video = (video + 1.0) / 2.0
|
| 148 |
+
return video.clamp(0.0, 1.0)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def save_side_by_side(input_tensor: torch.Tensor, sample_tensor: torch.Tensor, file_path: str, fps_out: int = 16):
|
| 152 |
+
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
| 153 |
+
a = _normalize_to_01(input_tensor.detach().cpu())
|
| 154 |
+
b = _normalize_to_01(sample_tensor.detach().cpu())
|
| 155 |
+
|
| 156 |
+
# Align dimensions by cropping to the minimum across T/H/W
|
| 157 |
+
T = min(a.shape[2], b.shape[2])
|
| 158 |
+
H = min(a.shape[3], b.shape[3])
|
| 159 |
+
W = min(a.shape[4], b.shape[4])
|
| 160 |
+
a = a[:, :, :T, :H, :W]
|
| 161 |
+
b = b[:, :, :T, :H, :W]
|
| 162 |
+
|
| 163 |
+
combined = torch.cat([a, b], dim=4)
|
| 164 |
+
save_videos_grid(combined, file_path, fps=fps_out)
|
| 165 |
+
print(f"Saved side-by-side video → {file_path}")
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def derive_ground_instruction(edit_instruction_text: str) -> str:
|
| 169 |
+
# Keep wrapper for backward compatibility; reuse the same rule as training dataset
|
| 170 |
+
return derive_ground_object_from_instruction(edit_instruction_text)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def main():
|
| 174 |
+
args = parse_args()
|
| 175 |
+
|
| 176 |
+
# Initialize DDP
|
| 177 |
+
dist.init_process_group(backend="nccl")
|
| 178 |
+
rank = dist.get_rank()
|
| 179 |
+
world_size = dist.get_world_size()
|
| 180 |
+
local_rank = int(os.environ.get("LOCAL_RANK", rank % max(1, torch.cuda.device_count())))
|
| 181 |
+
torch.cuda.set_device(local_rank)
|
| 182 |
+
|
| 183 |
+
if rank == 0:
|
| 184 |
+
print(f"Running parallel CoT inference with {world_size} GPUs")
|
| 185 |
+
print(f"Using seed: {args.seed}")
|
| 186 |
+
|
| 187 |
+
model_name = args.model_name
|
| 188 |
+
|
| 189 |
+
# Load tasks
|
| 190 |
+
if args.test_json:
|
| 191 |
+
if rank == 0:
|
| 192 |
+
print(f"Loading tasks from JSON: {args.test_json}")
|
| 193 |
+
with open(args.test_json, 'r', encoding='utf-8') as f:
|
| 194 |
+
eval_prompts_list = json.load(f)
|
| 195 |
+
|
| 196 |
+
eval_prompts = {}
|
| 197 |
+
for item in eval_prompts_list:
|
| 198 |
+
# Assume item has structure compatible or use fallback logic
|
| 199 |
+
# Here we expect task_type/sample_id to uniquely identify, or we create a name
|
| 200 |
+
if 'task_type' in item and 'sample_id' in item:
|
| 201 |
+
fname = f"{item['task_type']}_{item['sample_id']}.mp4"
|
| 202 |
+
else:
|
| 203 |
+
# Fallback naming if JSON structure is different
|
| 204 |
+
fname = f"sample_{len(eval_prompts)}.mp4"
|
| 205 |
+
eval_prompts[fname] = item
|
| 206 |
+
items = list(eval_prompts.items())
|
| 207 |
+
|
| 208 |
+
elif args.video_path and args.prompt:
|
| 209 |
+
if rank == 0:
|
| 210 |
+
print(f"Running in single video mode: {args.video_path}")
|
| 211 |
+
fname = os.path.basename(args.video_path)
|
| 212 |
+
item = {
|
| 213 |
+
"source_video_path": args.video_path,
|
| 214 |
+
"edit_instruction": args.prompt
|
| 215 |
+
}
|
| 216 |
+
items = [(fname, item)]
|
| 217 |
+
else:
|
| 218 |
+
raise ValueError("Must provide either --test_json or both --video_path and --prompt")
|
| 219 |
+
|
| 220 |
+
# Filter done
|
| 221 |
+
pending_items = []
|
| 222 |
+
for fname, item in items:
|
| 223 |
+
base = os.path.splitext(fname)[0]
|
| 224 |
+
output_video_path = os.path.join(args.output_dir, f"gen_{base}.mp4")
|
| 225 |
+
if not os.path.exists(output_video_path):
|
| 226 |
+
pending_items.append((fname, item))
|
| 227 |
+
|
| 228 |
+
if rank == 0:
|
| 229 |
+
print(f"Total items: {len(items)}, already generated: {len(items) - len(pending_items)}, pending: {len(pending_items)}")
|
| 230 |
+
|
| 231 |
+
# Shard across GPUs
|
| 232 |
+
subset_items = pending_items[rank::world_size] if world_size > 0 else pending_items
|
| 233 |
+
|
| 234 |
+
print(f"[GPU {rank} | local {local_rank}] Processing {len(subset_items)} items")
|
| 235 |
+
|
| 236 |
+
device = torch.device(f"cuda:{local_rank}")
|
| 237 |
+
|
| 238 |
+
# Load config and models
|
| 239 |
+
config = OmegaConf.load(config_path)
|
| 240 |
+
|
| 241 |
+
transformer = WanTransformer3DModel.from_pretrained(
|
| 242 |
+
os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
|
| 243 |
+
transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
|
| 244 |
+
low_cpu_mem_usage=True,
|
| 245 |
+
torch_dtype=weight_dtype,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
if transformer_path is not None:
|
| 249 |
+
print(f"[GPU {rank}] Loading transformer from checkpoint: {transformer_path}")
|
| 250 |
+
if transformer_path.endswith("safetensors"):
|
| 251 |
+
from safetensors.torch import load_file
|
| 252 |
+
state_dict = load_file(transformer_path)
|
| 253 |
+
else:
|
| 254 |
+
state_dict = torch.load(transformer_path, map_location="cpu")
|
| 255 |
+
state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
|
| 256 |
+
m, u = transformer.load_state_dict(state_dict, strict=False)
|
| 257 |
+
print(f"[GPU {rank}] Missing keys: {len(m)}, unexpected keys: {len(u)}")
|
| 258 |
+
|
| 259 |
+
vae = AutoencoderKLWan.from_pretrained(
|
| 260 |
+
os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')),
|
| 261 |
+
additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
|
| 262 |
+
).to(weight_dtype)
|
| 263 |
+
|
| 264 |
+
if vae_path is not None:
|
| 265 |
+
print(f"[GPU {rank}] Loading VAE from checkpoint: {vae_path}")
|
| 266 |
+
if vae_path.endswith("safetensors"):
|
| 267 |
+
from safetensors.torch import load_file
|
| 268 |
+
state_dict = load_file(vae_path)
|
| 269 |
+
else:
|
| 270 |
+
state_dict = torch.load(vae_path, map_location="cpu")
|
| 271 |
+
state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
|
| 272 |
+
m, u = vae.load_state_dict(state_dict, strict=False)
|
| 273 |
+
print(f"[GPU {rank}] Missing keys: {len(m)}, unexpected keys: {len(u)}")
|
| 274 |
+
|
| 275 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 276 |
+
os.path.join(model_name, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
text_encoder = WanT5EncoderModel.from_pretrained(
|
| 280 |
+
os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
|
| 281 |
+
additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
|
| 282 |
+
low_cpu_mem_usage=True,
|
| 283 |
+
torch_dtype=weight_dtype,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
Choosen_Scheduler = {
|
| 287 |
+
"Flow": FlowMatchEulerDiscreteScheduler,
|
| 288 |
+
"Flow_Unipc": FlowUniPCMultistepScheduler,
|
| 289 |
+
"Flow_DPM++": FlowDPMSolverMultistepScheduler,
|
| 290 |
+
}[sampler_name]
|
| 291 |
+
if sampler_name in ["Flow_Unipc", "Flow_DPM++"]:
|
| 292 |
+
config['scheduler_kwargs']['shift'] = 1
|
| 293 |
+
scheduler = Choosen_Scheduler(
|
| 294 |
+
**filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
pipeline = WanPipeline(
|
| 298 |
+
transformer=transformer,
|
| 299 |
+
vae=vae,
|
| 300 |
+
tokenizer=tokenizer,
|
| 301 |
+
text_encoder=text_encoder,
|
| 302 |
+
scheduler=scheduler,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# Memory mode
|
| 306 |
+
if GPU_memory_mode == "sequential_cpu_offload":
|
| 307 |
+
replace_parameters_by_name(transformer, ["modulation",], device=device)
|
| 308 |
+
transformer.freqs = transformer.freqs.to(device=device)
|
| 309 |
+
pipeline.enable_sequential_cpu_offload(device=device)
|
| 310 |
+
elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
|
| 311 |
+
convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
|
| 312 |
+
convert_weight_dtype_wrapper(transformer, weight_dtype)
|
| 313 |
+
pipeline.enable_model_cpu_offload(device=device)
|
| 314 |
+
elif GPU_memory_mode == "model_cpu_offload":
|
| 315 |
+
pipeline.enable_model_cpu_offload(device=device)
|
| 316 |
+
elif GPU_memory_mode == "model_full_load_and_qfloat8":
|
| 317 |
+
convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
|
| 318 |
+
convert_weight_dtype_wrapper(transformer, weight_dtype)
|
| 319 |
+
pipeline.to(device=device)
|
| 320 |
+
else:
|
| 321 |
+
pipeline.to(device=device)
|
| 322 |
+
|
| 323 |
+
# LoRA
|
| 324 |
+
if args.videocof_path is not None:
|
| 325 |
+
pipeline = merge_lora(pipeline, args.videocof_path, lora_weight, device=device)
|
| 326 |
+
print(f"[GPU {rank}] Loaded LoRA from {args.videocof_path}")
|
| 327 |
+
|
| 328 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 329 |
+
|
| 330 |
+
generator = torch.Generator(device=device).manual_seed(args.seed + rank)
|
| 331 |
+
|
| 332 |
+
# Grounding indices are now handled inside the pipeline; no forward override needed.
|
| 333 |
+
|
| 334 |
+
for fname, item in subset_items:
|
| 335 |
+
base = os.path.splitext(fname)[0]
|
| 336 |
+
output_video_path = os.path.join(args.output_dir, f"gen_{base}.mp4")
|
| 337 |
+
info_path = os.path.join(args.output_dir, f"gen_{base}_info.txt")
|
| 338 |
+
|
| 339 |
+
print(f"[GPU {rank}] Processing {fname}...")
|
| 340 |
+
|
| 341 |
+
video_path = item["source_video_path"]
|
| 342 |
+
|
| 343 |
+
# Match training dataset (ImageVideoCoTDataset) prompt formatting
|
| 344 |
+
edit_text = item.get('text', item.get('qwen_vl_72b_refined_instruction', item.get('edit_instruction', '')))
|
| 345 |
+
ground_instr = derive_ground_instruction(edit_text)
|
| 346 |
+
prompt = (
|
| 347 |
+
"A video sequence showing three parts: first the original scene, "
|
| 348 |
+
f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
input_video, video_height, video_width = load_video_frames(
|
| 353 |
+
video_path,
|
| 354 |
+
source_frames=args.source_frames,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
with torch.no_grad():
|
| 358 |
+
sample = pipeline(
|
| 359 |
+
video=input_video,
|
| 360 |
+
prompt=prompt,
|
| 361 |
+
num_frames=args.num_frames,
|
| 362 |
+
source_frames=args.source_frames,
|
| 363 |
+
reasoning_frames=args.reasoning_frames,
|
| 364 |
+
negative_prompt=negative_prompt,
|
| 365 |
+
height=video_height,
|
| 366 |
+
width=video_width,
|
| 367 |
+
generator=generator,
|
| 368 |
+
guidance_scale=guidance_scale,
|
| 369 |
+
num_inference_steps=num_inference_steps,
|
| 370 |
+
shift=shift,
|
| 371 |
+
repeat_rope=args.repeat_rope,
|
| 372 |
+
cot=True,
|
| 373 |
+
).videos
|
| 374 |
+
|
| 375 |
+
reason_edit_path = os.path.join(args.output_dir, f"gen_{base}_reason_edit.mp4")
|
| 376 |
+
save_results(sample, reason_edit_path, fps)
|
| 377 |
+
print(f"[GPU {rank}] Saved reason+edit video shape: {sample.shape}")
|
| 378 |
+
|
| 379 |
+
edit_video = sample[:, :, -args.source_frames:, :, :]
|
| 380 |
+
save_results(edit_video, output_video_path, fps)
|
| 381 |
+
print(f"[GPU {rank}] Edit video shape: {edit_video.shape}")
|
| 382 |
+
|
| 383 |
+
compare_path = os.path.join(args.output_dir, f"gen_{base}_compare.mp4")
|
| 384 |
+
save_side_by_side(input_video, edit_video, compare_path, fps)
|
| 385 |
+
|
| 386 |
+
with open(info_path, "w", encoding="utf-8") as info_f:
|
| 387 |
+
info_f.write(prompt)
|
| 388 |
+
|
| 389 |
+
print(f"[GPU {rank}] Completed {fname}")
|
| 390 |
+
|
| 391 |
+
if args.videocof_path is not None:
|
| 392 |
+
pipeline = unmerge_lora(pipeline, args.videocof_path, lora_weight, device=device)
|
| 393 |
+
|
| 394 |
+
print(f"[GPU {rank}] Finished processing all assigned items")
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
if __name__ == "__main__":
|
| 398 |
+
main()
|
| 399 |
+
|
| 400 |
+
|
install.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import subprocess
|
| 3 |
+
import locale
|
| 4 |
+
import threading
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
def handle_stream(stream, prefix):
|
| 8 |
+
stream.reconfigure(encoding=locale.getpreferredencoding(), errors='replace')
|
| 9 |
+
for msg in stream:
|
| 10 |
+
if prefix == '[!]' and ('it/s]' in msg or 's/it]' in msg) and ('%|' in msg or 'it [' in msg):
|
| 11 |
+
if msg.startswith('100%'):
|
| 12 |
+
print('\r' + msg, end="", file=sys.stderr),
|
| 13 |
+
else:
|
| 14 |
+
print('\r' + msg[:-1], end="", file=sys.stderr),
|
| 15 |
+
else:
|
| 16 |
+
if prefix == '[!]':
|
| 17 |
+
print(prefix, msg, end="", file=sys.stderr)
|
| 18 |
+
else:
|
| 19 |
+
print(prefix, msg, end="")
|
| 20 |
+
|
| 21 |
+
def process_wrap(cmd_str, cwd_path, handler=None):
|
| 22 |
+
process = subprocess.Popen(cmd_str, cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1)
|
| 23 |
+
|
| 24 |
+
if handler is None:
|
| 25 |
+
handler = handle_stream
|
| 26 |
+
|
| 27 |
+
stdout_thread = threading.Thread(target=handler, args=(process.stdout, ""))
|
| 28 |
+
stderr_thread = threading.Thread(target=handler, args=(process.stderr, "[!]"))
|
| 29 |
+
|
| 30 |
+
stdout_thread.start()
|
| 31 |
+
stderr_thread.start()
|
| 32 |
+
|
| 33 |
+
stdout_thread.join()
|
| 34 |
+
stderr_thread.join()
|
| 35 |
+
|
| 36 |
+
return process.wait()
|
| 37 |
+
|
| 38 |
+
assert process_wrap([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd_path=os.path.dirname(os.path.realpath(__file__))) == 0, "ERROR: Failed to install requirements.txt. Please install them manually, and restart ComfyUI."
|
| 39 |
+
|
| 40 |
+
nodep_packages = [
|
| 41 |
+
"kornia>=0.6.9",
|
| 42 |
+
"xformers>=0.0.20",
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
assert process_wrap([sys.executable, "-m", "pip", "install", "--no-deps", *nodep_packages], cwd_path=os.path.dirname(os.path.realpath(__file__))) == 0, "ERROR: Failed to install last set of packages. Please install them manually, and restart ComfyUI."
|
pyproject.toml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "videox-fun"
|
| 3 |
+
description = "VideoX-Fun is a video generation pipeline that can be used to generate AI images and videos, as well as to train baseline and Lora models for Diffusion Transformer. We support direct prediction from pre-trained baseline models to generate videos with different resolutions, durations, and FPS. Additionally, we also support users in training their own baseline and Lora models to perform specific style transformations."
|
| 4 |
+
version = "1.0.0"
|
| 5 |
+
license = {file = "LICENSE"}
|
| 6 |
+
dependencies = ["Pillow", "einops", "safetensors", "timm", "tomesd", "torch>=2.1.2", "torchdiffeq", "torchsde", "decord", "datasets", "numpy", "scikit-image", "opencv-python", "omegaconf", "SentencePiece", "albumentations", "imageio[ffmpeg]", "imageio[pyav]", "tensorboard", "beautifulsoup4", "ftfy", "func_timeout", "accelerate>=0.25.0", "gradio>=3.41.2,<=3.48.0", "diffusers>=0.30.1,<=0.31.0", "transformers>=4.46.2"]
|
| 7 |
+
|
| 8 |
+
[project.urls]
|
| 9 |
+
Repository = "https://github.com/aigc-apps/VideoX-Fun"
|
| 10 |
+
# Used by Comfy Registry https://comfyregistry.org
|
| 11 |
+
|
| 12 |
+
[tool.comfy]
|
| 13 |
+
PublisherId = "bubbliiiing"
|
| 14 |
+
DisplayName = "VideoX-Fun"
|
| 15 |
+
Icon = ""
|
requirements.txt
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Pillow
|
| 2 |
+
einops
|
| 3 |
+
safetensors
|
| 4 |
+
timm
|
| 5 |
+
tomesd
|
| 6 |
+
torchdiffeq
|
| 7 |
+
torchsde
|
| 8 |
+
decord
|
| 9 |
+
datasets
|
| 10 |
+
numpy
|
| 11 |
+
scikit-image
|
| 12 |
+
opencv-python
|
| 13 |
+
omegaconf
|
| 14 |
+
SentencePiece
|
| 15 |
+
albumentations
|
| 16 |
+
imageio[ffmpeg]
|
| 17 |
+
imageio[pyav]
|
| 18 |
+
tensorboard
|
| 19 |
+
beautifulsoup4
|
| 20 |
+
ftfy
|
| 21 |
+
func_timeout
|
| 22 |
+
onnxruntime
|
| 23 |
+
accelerate>=0.25.0
|
| 24 |
+
gradio>=3.41.2
|
| 25 |
+
diffusers>=0.30.1
|
| 26 |
+
transformers>=4.46.2
|
scripts/local_style.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export CUDA_VISIBLE_DEVICES=2
|
| 2 |
+
|
| 3 |
+
torchrun --nproc_per_node=1 inference.py \
|
| 4 |
+
--video_path assets/sign.mp4 \
|
| 5 |
+
--prompt "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below." \
|
| 6 |
+
--output_dir results/local_style \
|
| 7 |
+
--model_name /scratch3/yan204/models/Wan2.1-T2V-14B \
|
| 8 |
+
--seed 0 \
|
| 9 |
+
--num_frames 33 \
|
| 10 |
+
--source_frames 33 \
|
| 11 |
+
--reasoning_frames 4 \
|
| 12 |
+
--repeat_rope \
|
| 13 |
+
--videocof_path videocof_weight/videocof.safetensors
|
scripts/obj_add.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export CUDA_VISIBLE_DEVICES=0
|
| 2 |
+
|
| 3 |
+
torchrun --nproc_per_node=1 inference.py \
|
| 4 |
+
--video_path assets/woman_ballon.mp4 \
|
| 5 |
+
--prompt "Add the woman in a floral dress pointing at the balloon on the left." \
|
| 6 |
+
--output_dir results/obj_add \
|
| 7 |
+
--model_name /scratch3/yan204/models/Wan2.1-T2V-14B \
|
| 8 |
+
--seed 0 \
|
| 9 |
+
--num_frames 33 \
|
| 10 |
+
--source_frames 33 \
|
| 11 |
+
--reasoning_frames 4 \
|
| 12 |
+
--repeat_rope \
|
| 13 |
+
--videocof_path videocof_weight/videocof.safetensors
|
scripts/obj_rem.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export CUDA_VISIBLE_DEVICES=1
|
| 2 |
+
|
| 3 |
+
torchrun --nproc_per_node=1 inference.py \
|
| 4 |
+
--video_path assets/two_man.mp4 \
|
| 5 |
+
--prompt "Remove the young man with short black hair wearing black shirt on the left." \
|
| 6 |
+
--output_dir results/obj_rem \
|
| 7 |
+
--model_name /scratch3/yan204/models/Wan2.1-T2V-14B \
|
| 8 |
+
--seed 0 \
|
| 9 |
+
--num_frames 33 \
|
| 10 |
+
--source_frames 33 \
|
| 11 |
+
--reasoning_frames 4 \
|
| 12 |
+
--repeat_rope \
|
| 13 |
+
--videocof_path videocof_weight/videocof.safetensors
|
scripts/parallel_infer.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
| 2 |
+
|
| 3 |
+
torchrun --nproc_per_node=4 inference.py \
|
| 4 |
+
--test_json assets/teaser_test.json \
|
| 5 |
+
--output_dir results/torch_2.5.1 \
|
| 6 |
+
--model_name /scratch3/yan204/models/Wan2.1-T2V-14B \
|
| 7 |
+
--seed 0 \
|
| 8 |
+
--num_frames 33 \
|
| 9 |
+
--source_frames 33 \
|
| 10 |
+
--reasoning_frames 4 \
|
| 11 |
+
--repeat_rope \
|
| 12 |
+
--videocof_path videocof_weight/videocof.safetensors
|
videox_fun/__init__.py
ADDED
|
File without changes
|
videox_fun/api/api.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import gc
|
| 3 |
+
import hashlib
|
| 4 |
+
import io
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
|
| 9 |
+
import gradio as gr
|
| 10 |
+
import requests
|
| 11 |
+
import torch
|
| 12 |
+
from fastapi import FastAPI
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Function to encode a file to Base64
|
| 17 |
+
def encode_file_to_base64(file_path):
|
| 18 |
+
with open(file_path, "rb") as file:
|
| 19 |
+
# Encode the data to Base64
|
| 20 |
+
file_base64 = base64.b64encode(file.read())
|
| 21 |
+
return file_base64
|
| 22 |
+
|
| 23 |
+
def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
|
| 24 |
+
@app.post("/videox_fun/update_diffusion_transformer")
|
| 25 |
+
def _update_diffusion_transformer_api(
|
| 26 |
+
datas: dict,
|
| 27 |
+
):
|
| 28 |
+
diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
controller.update_diffusion_transformer(
|
| 32 |
+
diffusion_transformer_path
|
| 33 |
+
)
|
| 34 |
+
comment = "Success"
|
| 35 |
+
except Exception as e:
|
| 36 |
+
torch.cuda.empty_cache()
|
| 37 |
+
comment = f"Error. error information is {str(e)}"
|
| 38 |
+
|
| 39 |
+
return {"message": comment}
|
| 40 |
+
|
| 41 |
+
def download_from_url(url, timeout=10):
|
| 42 |
+
try:
|
| 43 |
+
response = requests.get(url, timeout=timeout)
|
| 44 |
+
response.raise_for_status() # 检查请求是否成功
|
| 45 |
+
return response.content
|
| 46 |
+
except requests.exceptions.RequestException as e:
|
| 47 |
+
print(f"Error downloading from {url}: {e}")
|
| 48 |
+
return None
|
| 49 |
+
|
| 50 |
+
def save_base64_video(base64_string):
|
| 51 |
+
video_data = base64.b64decode(base64_string)
|
| 52 |
+
|
| 53 |
+
md5_hash = hashlib.md5(video_data).hexdigest()
|
| 54 |
+
filename = f"{md5_hash}.mp4"
|
| 55 |
+
|
| 56 |
+
temp_dir = tempfile.gettempdir()
|
| 57 |
+
file_path = os.path.join(temp_dir, filename)
|
| 58 |
+
|
| 59 |
+
with open(file_path, 'wb') as video_file:
|
| 60 |
+
video_file.write(video_data)
|
| 61 |
+
|
| 62 |
+
return file_path
|
| 63 |
+
|
| 64 |
+
def save_base64_image(base64_string):
|
| 65 |
+
video_data = base64.b64decode(base64_string)
|
| 66 |
+
|
| 67 |
+
md5_hash = hashlib.md5(video_data).hexdigest()
|
| 68 |
+
filename = f"{md5_hash}.jpg"
|
| 69 |
+
|
| 70 |
+
temp_dir = tempfile.gettempdir()
|
| 71 |
+
file_path = os.path.join(temp_dir, filename)
|
| 72 |
+
|
| 73 |
+
with open(file_path, 'wb') as video_file:
|
| 74 |
+
video_file.write(video_data)
|
| 75 |
+
|
| 76 |
+
return file_path
|
| 77 |
+
|
| 78 |
+
def save_url_video(url):
|
| 79 |
+
video_data = download_from_url(url)
|
| 80 |
+
if video_data:
|
| 81 |
+
return save_base64_video(base64.b64encode(video_data))
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
def save_url_image(url):
|
| 85 |
+
image_data = download_from_url(url)
|
| 86 |
+
if image_data:
|
| 87 |
+
return save_base64_image(base64.b64encode(image_data))
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
|
| 91 |
+
@app.post("/videox_fun/infer_forward")
|
| 92 |
+
def _infer_forward_api(
|
| 93 |
+
datas: dict,
|
| 94 |
+
):
|
| 95 |
+
base_model_path = datas.get('base_model_path', 'none')
|
| 96 |
+
base_model_2_path = datas.get('base_model_2_path', 'none')
|
| 97 |
+
lora_model_path = datas.get('lora_model_path', 'none')
|
| 98 |
+
lora_model_2_path = datas.get('lora_model_2_path', 'none')
|
| 99 |
+
lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
|
| 100 |
+
prompt_textbox = datas.get('prompt_textbox', None)
|
| 101 |
+
negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
|
| 102 |
+
sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
|
| 103 |
+
sample_step_slider = datas.get('sample_step_slider', 30)
|
| 104 |
+
resize_method = datas.get('resize_method', "Generate by")
|
| 105 |
+
width_slider = datas.get('width_slider', 672)
|
| 106 |
+
height_slider = datas.get('height_slider', 384)
|
| 107 |
+
base_resolution = datas.get('base_resolution', 512)
|
| 108 |
+
is_image = datas.get('is_image', False)
|
| 109 |
+
generation_method = datas.get('generation_method', False)
|
| 110 |
+
length_slider = datas.get('length_slider', 49)
|
| 111 |
+
overlap_video_length = datas.get('overlap_video_length', 4)
|
| 112 |
+
partial_video_length = datas.get('partial_video_length', 72)
|
| 113 |
+
cfg_scale_slider = datas.get('cfg_scale_slider', 6)
|
| 114 |
+
start_image = datas.get('start_image', None)
|
| 115 |
+
end_image = datas.get('end_image', None)
|
| 116 |
+
validation_video = datas.get('validation_video', None)
|
| 117 |
+
validation_video_mask = datas.get('validation_video_mask', None)
|
| 118 |
+
control_video = datas.get('control_video', None)
|
| 119 |
+
denoise_strength = datas.get('denoise_strength', 0.70)
|
| 120 |
+
seed_textbox = datas.get("seed_textbox", 43)
|
| 121 |
+
|
| 122 |
+
ref_image = datas.get('ref_image', None)
|
| 123 |
+
enable_teacache = datas.get('enable_teacache', True)
|
| 124 |
+
teacache_threshold = datas.get('teacache_threshold', 0.10)
|
| 125 |
+
num_skip_start_steps = datas.get('num_skip_start_steps', 1)
|
| 126 |
+
teacache_offload = datas.get('teacache_offload', False)
|
| 127 |
+
cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
|
| 128 |
+
enable_riflex = datas.get('enable_riflex', False)
|
| 129 |
+
riflex_k = datas.get('riflex_k', 6)
|
| 130 |
+
fps = datas.get('fps', None)
|
| 131 |
+
|
| 132 |
+
generation_method = "Image Generation" if is_image else generation_method
|
| 133 |
+
|
| 134 |
+
if start_image is not None:
|
| 135 |
+
if start_image.startswith('http'):
|
| 136 |
+
start_image = save_url_image(start_image)
|
| 137 |
+
start_image = [Image.open(start_image).convert("RGB")]
|
| 138 |
+
else:
|
| 139 |
+
start_image = base64.b64decode(start_image)
|
| 140 |
+
start_image = [Image.open(BytesIO(start_image)).convert("RGB")]
|
| 141 |
+
|
| 142 |
+
if end_image is not None:
|
| 143 |
+
if end_image.startswith('http'):
|
| 144 |
+
end_image = save_url_image(end_image)
|
| 145 |
+
end_image = [Image.open(end_image).convert("RGB")]
|
| 146 |
+
else:
|
| 147 |
+
end_image = base64.b64decode(end_image)
|
| 148 |
+
end_image = [Image.open(BytesIO(end_image)).convert("RGB")]
|
| 149 |
+
|
| 150 |
+
if validation_video is not None:
|
| 151 |
+
if validation_video.startswith('http'):
|
| 152 |
+
validation_video = save_url_video(validation_video)
|
| 153 |
+
else:
|
| 154 |
+
validation_video = save_base64_video(validation_video)
|
| 155 |
+
|
| 156 |
+
if validation_video_mask is not None:
|
| 157 |
+
if validation_video_mask.startswith('http'):
|
| 158 |
+
validation_video_mask = save_url_image(validation_video_mask)
|
| 159 |
+
else:
|
| 160 |
+
validation_video_mask = save_base64_image(validation_video_mask)
|
| 161 |
+
|
| 162 |
+
if control_video is not None:
|
| 163 |
+
if control_video.startswith('http'):
|
| 164 |
+
control_video = save_url_video(control_video)
|
| 165 |
+
else:
|
| 166 |
+
control_video = save_base64_video(control_video)
|
| 167 |
+
|
| 168 |
+
if ref_image is not None:
|
| 169 |
+
if ref_image.startswith('http'):
|
| 170 |
+
ref_image = save_url_image(ref_image)
|
| 171 |
+
ref_image = [Image.open(ref_image).convert("RGB")]
|
| 172 |
+
else:
|
| 173 |
+
ref_image = base64.b64decode(ref_image)
|
| 174 |
+
ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]
|
| 175 |
+
|
| 176 |
+
try:
|
| 177 |
+
save_sample_path, comment = controller.generate(
|
| 178 |
+
"",
|
| 179 |
+
base_model_path,
|
| 180 |
+
lora_model_path,
|
| 181 |
+
lora_alpha_slider,
|
| 182 |
+
prompt_textbox,
|
| 183 |
+
negative_prompt_textbox,
|
| 184 |
+
sampler_dropdown,
|
| 185 |
+
sample_step_slider,
|
| 186 |
+
resize_method,
|
| 187 |
+
width_slider,
|
| 188 |
+
height_slider,
|
| 189 |
+
base_resolution,
|
| 190 |
+
generation_method,
|
| 191 |
+
length_slider,
|
| 192 |
+
overlap_video_length,
|
| 193 |
+
partial_video_length,
|
| 194 |
+
cfg_scale_slider,
|
| 195 |
+
start_image,
|
| 196 |
+
end_image,
|
| 197 |
+
validation_video,
|
| 198 |
+
validation_video_mask,
|
| 199 |
+
control_video,
|
| 200 |
+
denoise_strength,
|
| 201 |
+
seed_textbox,
|
| 202 |
+
ref_image = ref_image,
|
| 203 |
+
enable_teacache = enable_teacache,
|
| 204 |
+
teacache_threshold = teacache_threshold,
|
| 205 |
+
num_skip_start_steps = num_skip_start_steps,
|
| 206 |
+
teacache_offload = teacache_offload,
|
| 207 |
+
cfg_skip_ratio = cfg_skip_ratio,
|
| 208 |
+
enable_riflex = enable_riflex,
|
| 209 |
+
riflex_k = riflex_k,
|
| 210 |
+
base_model_2_dropdown = base_model_2_path,
|
| 211 |
+
lora_model_2_dropdown = lora_model_2_path,
|
| 212 |
+
fps = fps,
|
| 213 |
+
is_api = True,
|
| 214 |
+
)
|
| 215 |
+
except Exception as e:
|
| 216 |
+
gc.collect()
|
| 217 |
+
torch.cuda.empty_cache()
|
| 218 |
+
torch.cuda.ipc_collect()
|
| 219 |
+
save_sample_path = ""
|
| 220 |
+
comment = f"Error. error information is {str(e)}"
|
| 221 |
+
return {"message": comment, "save_sample_path": None, "base64_encoding": None}
|
| 222 |
+
|
| 223 |
+
if save_sample_path != "":
|
| 224 |
+
return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
|
| 225 |
+
else:
|
| 226 |
+
return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": None}
|
videox_fun/api/api_multi_nodes.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py
|
| 2 |
+
import base64
|
| 3 |
+
import gc
|
| 4 |
+
import hashlib
|
| 5 |
+
import io
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import requests
|
| 12 |
+
import torch
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
from fastapi import FastAPI, HTTPException
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
from .api import download_from_url, encode_file_to_base64
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import ray
|
| 21 |
+
except:
|
| 22 |
+
print("Ray is not installed. If you want to use multi gpus api. Please install it by running 'pip install ray'.")
|
| 23 |
+
ray = None
|
| 24 |
+
|
| 25 |
+
def save_base64_video_dist(base64_string):
|
| 26 |
+
video_data = base64.b64decode(base64_string)
|
| 27 |
+
|
| 28 |
+
md5_hash = hashlib.md5(video_data).hexdigest()
|
| 29 |
+
filename = f"{md5_hash}.mp4"
|
| 30 |
+
|
| 31 |
+
temp_dir = tempfile.gettempdir()
|
| 32 |
+
file_path = os.path.join(temp_dir, filename)
|
| 33 |
+
|
| 34 |
+
if dist.is_initialized():
|
| 35 |
+
if dist.get_rank() == 0:
|
| 36 |
+
with open(file_path, 'wb') as video_file:
|
| 37 |
+
video_file.write(video_data)
|
| 38 |
+
dist.barrier()
|
| 39 |
+
else:
|
| 40 |
+
with open(file_path, 'wb') as video_file:
|
| 41 |
+
video_file.write(video_data)
|
| 42 |
+
return file_path
|
| 43 |
+
|
| 44 |
+
def save_base64_image_dist(base64_string):
|
| 45 |
+
video_data = base64.b64decode(base64_string)
|
| 46 |
+
|
| 47 |
+
md5_hash = hashlib.md5(video_data).hexdigest()
|
| 48 |
+
filename = f"{md5_hash}.jpg"
|
| 49 |
+
|
| 50 |
+
temp_dir = tempfile.gettempdir()
|
| 51 |
+
file_path = os.path.join(temp_dir, filename)
|
| 52 |
+
|
| 53 |
+
if dist.is_initialized():
|
| 54 |
+
if dist.get_rank() == 0:
|
| 55 |
+
with open(file_path, 'wb') as video_file:
|
| 56 |
+
video_file.write(video_data)
|
| 57 |
+
dist.barrier()
|
| 58 |
+
else:
|
| 59 |
+
with open(file_path, 'wb') as video_file:
|
| 60 |
+
video_file.write(video_data)
|
| 61 |
+
return file_path
|
| 62 |
+
|
| 63 |
+
def save_url_video_dist(url):
|
| 64 |
+
video_data = download_from_url(url)
|
| 65 |
+
if video_data:
|
| 66 |
+
return save_base64_video_dist(base64.b64encode(video_data))
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
def save_url_image_dist(url):
|
| 70 |
+
image_data = download_from_url(url)
|
| 71 |
+
if image_data:
|
| 72 |
+
return save_base64_image_dist(base64.b64encode(image_data))
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
if ray is not None:
|
| 76 |
+
@ray.remote(num_gpus=1)
|
| 77 |
+
class MultiNodesGenerator:
|
| 78 |
+
def __init__(
|
| 79 |
+
self, rank: int, world_size: int, Controller,
|
| 80 |
+
GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
|
| 81 |
+
config_path=None, ulysses_degree=1, ring_degree=1,
|
| 82 |
+
fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
|
| 83 |
+
weight_dtype=None, savedir_sample=None,
|
| 84 |
+
):
|
| 85 |
+
# Set PyTorch distributed environment variables
|
| 86 |
+
os.environ["RANK"] = str(rank)
|
| 87 |
+
os.environ["WORLD_SIZE"] = str(world_size)
|
| 88 |
+
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
| 89 |
+
os.environ["MASTER_PORT"] = "29500"
|
| 90 |
+
|
| 91 |
+
self.rank = rank
|
| 92 |
+
self.controller = Controller(
|
| 93 |
+
GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
|
| 94 |
+
ulysses_degree=ulysses_degree, ring_degree=ring_degree,
|
| 95 |
+
fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
|
| 96 |
+
weight_dtype=weight_dtype, savedir_sample=savedir_sample,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def generate(self, datas):
|
| 100 |
+
try:
|
| 101 |
+
base_model_path = datas.get('base_model_path', 'none')
|
| 102 |
+
base_model_2_path = datas.get('base_model_2_path', 'none')
|
| 103 |
+
lora_model_path = datas.get('lora_model_path', 'none')
|
| 104 |
+
lora_model_2_path = datas.get('lora_model_2_path', 'none')
|
| 105 |
+
lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
|
| 106 |
+
prompt_textbox = datas.get('prompt_textbox', None)
|
| 107 |
+
negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
|
| 108 |
+
sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
|
| 109 |
+
sample_step_slider = datas.get('sample_step_slider', 30)
|
| 110 |
+
resize_method = datas.get('resize_method', "Generate by")
|
| 111 |
+
width_slider = datas.get('width_slider', 672)
|
| 112 |
+
height_slider = datas.get('height_slider', 384)
|
| 113 |
+
base_resolution = datas.get('base_resolution', 512)
|
| 114 |
+
is_image = datas.get('is_image', False)
|
| 115 |
+
generation_method = datas.get('generation_method', False)
|
| 116 |
+
length_slider = datas.get('length_slider', 49)
|
| 117 |
+
overlap_video_length = datas.get('overlap_video_length', 4)
|
| 118 |
+
partial_video_length = datas.get('partial_video_length', 72)
|
| 119 |
+
cfg_scale_slider = datas.get('cfg_scale_slider', 6)
|
| 120 |
+
start_image = datas.get('start_image', None)
|
| 121 |
+
end_image = datas.get('end_image', None)
|
| 122 |
+
validation_video = datas.get('validation_video', None)
|
| 123 |
+
validation_video_mask = datas.get('validation_video_mask', None)
|
| 124 |
+
control_video = datas.get('control_video', None)
|
| 125 |
+
denoise_strength = datas.get('denoise_strength', 0.70)
|
| 126 |
+
seed_textbox = datas.get("seed_textbox", 43)
|
| 127 |
+
|
| 128 |
+
ref_image = datas.get('ref_image', None)
|
| 129 |
+
enable_teacache = datas.get('enable_teacache', True)
|
| 130 |
+
teacache_threshold = datas.get('teacache_threshold', 0.10)
|
| 131 |
+
num_skip_start_steps = datas.get('num_skip_start_steps', 1)
|
| 132 |
+
teacache_offload = datas.get('teacache_offload', False)
|
| 133 |
+
cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
|
| 134 |
+
enable_riflex = datas.get('enable_riflex', False)
|
| 135 |
+
riflex_k = datas.get('riflex_k', 6)
|
| 136 |
+
fps = datas.get('fps', None)
|
| 137 |
+
|
| 138 |
+
generation_method = "Image Generation" if is_image else generation_method
|
| 139 |
+
|
| 140 |
+
if start_image is not None:
|
| 141 |
+
if start_image.startswith('http'):
|
| 142 |
+
start_image = save_url_image_dist(start_image)
|
| 143 |
+
start_image = [Image.open(start_image).convert("RGB")]
|
| 144 |
+
else:
|
| 145 |
+
start_image = base64.b64decode(start_image)
|
| 146 |
+
start_image = [Image.open(BytesIO(start_image)).convert("RGB")]
|
| 147 |
+
|
| 148 |
+
if end_image is not None:
|
| 149 |
+
if end_image.startswith('http'):
|
| 150 |
+
end_image = save_url_image_dist(end_image)
|
| 151 |
+
end_image = [Image.open(end_image).convert("RGB")]
|
| 152 |
+
else:
|
| 153 |
+
end_image = base64.b64decode(end_image)
|
| 154 |
+
end_image = [Image.open(BytesIO(end_image)).convert("RGB")]
|
| 155 |
+
|
| 156 |
+
if validation_video is not None:
|
| 157 |
+
if validation_video.startswith('http'):
|
| 158 |
+
validation_video = save_url_video_dist(validation_video)
|
| 159 |
+
else:
|
| 160 |
+
validation_video = save_base64_video_dist(validation_video)
|
| 161 |
+
|
| 162 |
+
if validation_video_mask is not None:
|
| 163 |
+
if validation_video_mask.startswith('http'):
|
| 164 |
+
validation_video_mask = save_url_image_dist(validation_video_mask)
|
| 165 |
+
else:
|
| 166 |
+
validation_video_mask = save_base64_image_dist(validation_video_mask)
|
| 167 |
+
|
| 168 |
+
if control_video is not None:
|
| 169 |
+
if control_video.startswith('http'):
|
| 170 |
+
control_video = save_url_video_dist(control_video)
|
| 171 |
+
else:
|
| 172 |
+
control_video = save_base64_video_dist(control_video)
|
| 173 |
+
|
| 174 |
+
if ref_image is not None:
|
| 175 |
+
if ref_image.startswith('http'):
|
| 176 |
+
ref_image = save_url_image_dist(ref_image)
|
| 177 |
+
ref_image = [Image.open(ref_image).convert("RGB")]
|
| 178 |
+
else:
|
| 179 |
+
ref_image = base64.b64decode(ref_image)
|
| 180 |
+
ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]
|
| 181 |
+
|
| 182 |
+
try:
|
| 183 |
+
save_sample_path, comment = self.controller.generate(
|
| 184 |
+
"",
|
| 185 |
+
base_model_path,
|
| 186 |
+
lora_model_path,
|
| 187 |
+
lora_alpha_slider,
|
| 188 |
+
prompt_textbox,
|
| 189 |
+
negative_prompt_textbox,
|
| 190 |
+
sampler_dropdown,
|
| 191 |
+
sample_step_slider,
|
| 192 |
+
resize_method,
|
| 193 |
+
width_slider,
|
| 194 |
+
height_slider,
|
| 195 |
+
base_resolution,
|
| 196 |
+
generation_method,
|
| 197 |
+
length_slider,
|
| 198 |
+
overlap_video_length,
|
| 199 |
+
partial_video_length,
|
| 200 |
+
cfg_scale_slider,
|
| 201 |
+
start_image,
|
| 202 |
+
end_image,
|
| 203 |
+
validation_video,
|
| 204 |
+
validation_video_mask,
|
| 205 |
+
control_video,
|
| 206 |
+
denoise_strength,
|
| 207 |
+
seed_textbox,
|
| 208 |
+
ref_image = ref_image,
|
| 209 |
+
enable_teacache = enable_teacache,
|
| 210 |
+
teacache_threshold = teacache_threshold,
|
| 211 |
+
num_skip_start_steps = num_skip_start_steps,
|
| 212 |
+
teacache_offload = teacache_offload,
|
| 213 |
+
cfg_skip_ratio = cfg_skip_ratio,
|
| 214 |
+
enable_riflex = enable_riflex,
|
| 215 |
+
riflex_k = riflex_k,
|
| 216 |
+
base_model_2_dropdown = base_model_2_path,
|
| 217 |
+
lora_model_2_dropdown = lora_model_2_path,
|
| 218 |
+
fps = fps,
|
| 219 |
+
is_api = True,
|
| 220 |
+
)
|
| 221 |
+
except Exception as e:
|
| 222 |
+
gc.collect()
|
| 223 |
+
torch.cuda.empty_cache()
|
| 224 |
+
torch.cuda.ipc_collect()
|
| 225 |
+
save_sample_path = ""
|
| 226 |
+
comment = f"Error. error information is {str(e)}"
|
| 227 |
+
if dist.is_initialized():
|
| 228 |
+
if dist.get_rank() == 0:
|
| 229 |
+
return {"message": comment, "save_sample_path": None, "base64_encoding": None}
|
| 230 |
+
else:
|
| 231 |
+
return None
|
| 232 |
+
else:
|
| 233 |
+
return {"message": comment, "save_sample_path": None, "base64_encoding": None}
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
if dist.is_initialized():
|
| 237 |
+
if dist.get_rank() == 0:
|
| 238 |
+
if save_sample_path != "":
|
| 239 |
+
return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
|
| 240 |
+
else:
|
| 241 |
+
return {"message": comment, "save_sample_path": None, "base64_encoding": None}
|
| 242 |
+
else:
|
| 243 |
+
return None
|
| 244 |
+
else:
|
| 245 |
+
if save_sample_path != "":
|
| 246 |
+
return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
|
| 247 |
+
else:
|
| 248 |
+
return {"message": comment, "save_sample_path": None, "base64_encoding": None}
|
| 249 |
+
|
| 250 |
+
except Exception as e:
|
| 251 |
+
print(f"Error generating: {str(e)}")
|
| 252 |
+
comment = f"Error generating: {str(e)}"
|
| 253 |
+
if dist.is_initialized():
|
| 254 |
+
if dist.get_rank() == 0:
|
| 255 |
+
return {"message": comment, "save_sample_path": None, "base64_encoding": None}
|
| 256 |
+
else:
|
| 257 |
+
return None
|
| 258 |
+
else:
|
| 259 |
+
return {"message": comment, "save_sample_path": None, "base64_encoding": None}
|
| 260 |
+
|
| 261 |
+
class MultiNodesEngine:
|
| 262 |
+
def __init__(
|
| 263 |
+
self,
|
| 264 |
+
world_size,
|
| 265 |
+
Controller,
|
| 266 |
+
GPU_memory_mode,
|
| 267 |
+
scheduler_dict,
|
| 268 |
+
model_name,
|
| 269 |
+
model_type,
|
| 270 |
+
config_path,
|
| 271 |
+
ulysses_degree=1,
|
| 272 |
+
ring_degree=1,
|
| 273 |
+
fsdp_dit=False,
|
| 274 |
+
fsdp_text_encoder=False,
|
| 275 |
+
compile_dit=False,
|
| 276 |
+
weight_dtype=torch.bfloat16,
|
| 277 |
+
savedir_sample="samples"
|
| 278 |
+
):
|
| 279 |
+
# Ensure Ray is initialized
|
| 280 |
+
if not ray.is_initialized():
|
| 281 |
+
ray.init()
|
| 282 |
+
|
| 283 |
+
num_workers = world_size
|
| 284 |
+
self.workers = [
|
| 285 |
+
MultiNodesGenerator.remote(
|
| 286 |
+
rank, world_size, Controller,
|
| 287 |
+
GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
|
| 288 |
+
ulysses_degree=ulysses_degree, ring_degree=ring_degree,
|
| 289 |
+
fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
|
| 290 |
+
weight_dtype=weight_dtype, savedir_sample=savedir_sample,
|
| 291 |
+
)
|
| 292 |
+
for rank in range(num_workers)
|
| 293 |
+
]
|
| 294 |
+
print("Update workers done")
|
| 295 |
+
|
| 296 |
+
async def generate(self, data):
|
| 297 |
+
results = ray.get([
|
| 298 |
+
worker.generate.remote(data)
|
| 299 |
+
for worker in self.workers
|
| 300 |
+
])
|
| 301 |
+
|
| 302 |
+
return next(path for path in results if path is not None)
|
| 303 |
+
|
| 304 |
+
def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine):
|
| 305 |
+
|
| 306 |
+
@app.post("/videox_fun/infer_forward")
|
| 307 |
+
async def _multi_nodes_infer_forward_api(
|
| 308 |
+
datas: dict,
|
| 309 |
+
):
|
| 310 |
+
try:
|
| 311 |
+
result = await engine.generate(datas)
|
| 312 |
+
return result
|
| 313 |
+
except Exception as e:
|
| 314 |
+
if isinstance(e, HTTPException):
|
| 315 |
+
raise e
|
| 316 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 317 |
+
else:
|
| 318 |
+
MultiNodesEngine = None
|
| 319 |
+
MultiNodesGenerator = None
|
| 320 |
+
multi_nodes_infer_forward_api = None
|
videox_fun/data/bucket_sampler.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import os
|
| 3 |
+
from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
|
| 4 |
+
Sized, TypeVar, Union)
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from torch.utils.data import BatchSampler, Dataset, Sampler
|
| 11 |
+
|
| 12 |
+
# Original presets (commented out per request):
|
| 13 |
+
CUSTOM_ASPECT_RATIOS = {
|
| 14 |
+
"0.5676": [336, 592], # count=133984
|
| 15 |
+
"1.7619": [592, 336], # count=78813
|
| 16 |
+
"0.5682": [400, 704], # count=4421
|
| 17 |
+
"0.5556": [320, 576], # count=2481
|
| 18 |
+
"1.7600": [704, 400], # count=1682
|
| 19 |
+
"0.5319": [400, 752], # count=1235
|
| 20 |
+
"1.8000": [576, 320], # count=924
|
| 21 |
+
"0.5128": [320, 624], # count=711
|
| 22 |
+
"1.8800": [752, 400], # count=400
|
| 23 |
+
"1.9000": [608, 320], # count=226
|
| 24 |
+
"0.4237": [400, 944], # count=29
|
| 25 |
+
}
|
| 26 |
+
# CUSTOM_ASPECT_RATIOS = {
|
| 27 |
+
# "0.5676": [336, 592], # 336x592 (h x w)
|
| 28 |
+
# "1.7619": [592, 336], # 592x336
|
| 29 |
+
# "0.5682": [400, 704], # 400x704
|
| 30 |
+
# "1.7600": [704, 400], # 704x400
|
| 31 |
+
# "0.5319": [400, 752], # 400x752
|
| 32 |
+
# "1.8800": [752, 400], # 752x400
|
| 33 |
+
# "0.4237": [400, 944], # 400x944
|
| 34 |
+
# }
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
ASPECT_RATIO_512 = {
|
| 38 |
+
'0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
|
| 39 |
+
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
|
| 40 |
+
'0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
|
| 41 |
+
'0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
|
| 42 |
+
'0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
|
| 43 |
+
'1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
|
| 44 |
+
'1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
|
| 45 |
+
'1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
|
| 46 |
+
'2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
|
| 47 |
+
'3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
|
| 48 |
+
}
|
| 49 |
+
ASPECT_RATIO_RANDOM_CROP_512 = {
|
| 50 |
+
'0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
|
| 51 |
+
'0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
|
| 52 |
+
'0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
|
| 53 |
+
'1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
|
| 54 |
+
'2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
|
| 55 |
+
}
|
| 56 |
+
ASPECT_RATIO_RANDOM_CROP_PROB = [
|
| 57 |
+
1, 2,
|
| 58 |
+
4, 4, 4, 4,
|
| 59 |
+
8, 8, 8,
|
| 60 |
+
4, 4, 4, 4,
|
| 61 |
+
2, 1
|
| 62 |
+
]
|
| 63 |
+
ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
|
| 64 |
+
|
| 65 |
+
def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
|
| 66 |
+
aspect_ratio = height / width
|
| 67 |
+
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
|
| 68 |
+
return ratios[closest_ratio], float(closest_ratio)
|
| 69 |
+
|
| 70 |
+
def get_image_size_without_loading(path):
|
| 71 |
+
with Image.open(path) as img:
|
| 72 |
+
return img.size # (width, height)
|
| 73 |
+
|
| 74 |
+
class RandomSampler(Sampler[int]):
|
| 75 |
+
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
|
| 76 |
+
|
| 77 |
+
If with replacement, then user can specify :attr:`num_samples` to draw.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
data_source (Dataset): dataset to sample from
|
| 81 |
+
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
|
| 82 |
+
num_samples (int): number of samples to draw, default=`len(dataset)`.
|
| 83 |
+
generator (Generator): Generator used in sampling.
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
data_source: Sized
|
| 87 |
+
replacement: bool
|
| 88 |
+
|
| 89 |
+
def __init__(self, data_source: Sized, replacement: bool = False,
|
| 90 |
+
num_samples: Optional[int] = None, generator=None) -> None:
|
| 91 |
+
self.data_source = data_source
|
| 92 |
+
self.replacement = replacement
|
| 93 |
+
self._num_samples = num_samples
|
| 94 |
+
self.generator = generator
|
| 95 |
+
self._pos_start = 0
|
| 96 |
+
|
| 97 |
+
if not isinstance(self.replacement, bool):
|
| 98 |
+
raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
|
| 99 |
+
|
| 100 |
+
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
|
| 101 |
+
raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def num_samples(self) -> int:
|
| 105 |
+
# dataset size might change at runtime
|
| 106 |
+
if self._num_samples is None:
|
| 107 |
+
return len(self.data_source)
|
| 108 |
+
return self._num_samples
|
| 109 |
+
|
| 110 |
+
def __iter__(self) -> Iterator[int]:
|
| 111 |
+
n = len(self.data_source)
|
| 112 |
+
if self.generator is None:
|
| 113 |
+
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
| 114 |
+
generator = torch.Generator()
|
| 115 |
+
generator.manual_seed(seed)
|
| 116 |
+
else:
|
| 117 |
+
generator = self.generator
|
| 118 |
+
|
| 119 |
+
if self.replacement:
|
| 120 |
+
for _ in range(self.num_samples // 32):
|
| 121 |
+
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
|
| 122 |
+
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
|
| 123 |
+
else:
|
| 124 |
+
for _ in range(self.num_samples // n):
|
| 125 |
+
xx = torch.randperm(n, generator=generator).tolist()
|
| 126 |
+
if self._pos_start >= n:
|
| 127 |
+
self._pos_start = 0
|
| 128 |
+
for idx in range(self._pos_start, n):
|
| 129 |
+
yield xx[idx]
|
| 130 |
+
self._pos_start = (self._pos_start + 1) % n
|
| 131 |
+
self._pos_start = 0
|
| 132 |
+
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
|
| 133 |
+
|
| 134 |
+
def __len__(self) -> int:
|
| 135 |
+
return self.num_samples
|
| 136 |
+
|
| 137 |
+
class AspectRatioBatchImageSampler(BatchSampler):
|
| 138 |
+
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
sampler (Sampler): Base sampler.
|
| 142 |
+
dataset (Dataset): Dataset providing data information.
|
| 143 |
+
batch_size (int): Size of mini-batch.
|
| 144 |
+
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
| 145 |
+
its size would be less than ``batch_size``.
|
| 146 |
+
aspect_ratios (dict): The predefined aspect ratios.
|
| 147 |
+
"""
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
sampler: Sampler,
|
| 151 |
+
dataset: Dataset,
|
| 152 |
+
batch_size: int,
|
| 153 |
+
train_folder: str = None,
|
| 154 |
+
aspect_ratios: dict = ASPECT_RATIO_512,
|
| 155 |
+
drop_last: bool = False,
|
| 156 |
+
config=None,
|
| 157 |
+
**kwargs
|
| 158 |
+
) -> None:
|
| 159 |
+
if not isinstance(sampler, Sampler):
|
| 160 |
+
raise TypeError('sampler should be an instance of ``Sampler``, '
|
| 161 |
+
f'but got {sampler}')
|
| 162 |
+
if not isinstance(batch_size, int) or batch_size <= 0:
|
| 163 |
+
raise ValueError('batch_size should be a positive integer value, '
|
| 164 |
+
f'but got batch_size={batch_size}')
|
| 165 |
+
self.sampler = sampler
|
| 166 |
+
self.dataset = dataset
|
| 167 |
+
self.train_folder = train_folder
|
| 168 |
+
self.batch_size = batch_size
|
| 169 |
+
self.aspect_ratios = aspect_ratios
|
| 170 |
+
self.drop_last = drop_last
|
| 171 |
+
self.config = config
|
| 172 |
+
# buckets for each aspect ratio
|
| 173 |
+
self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
|
| 174 |
+
# [str(k) for k, v in aspect_ratios]
|
| 175 |
+
self.current_available_bucket_keys = list(aspect_ratios.keys())
|
| 176 |
+
|
| 177 |
+
def __iter__(self):
|
| 178 |
+
for idx in self.sampler:
|
| 179 |
+
try:
|
| 180 |
+
image_dict = self.dataset[idx]
|
| 181 |
+
|
| 182 |
+
width, height = image_dict.get("width", None), image_dict.get("height", None)
|
| 183 |
+
if width is None or height is None:
|
| 184 |
+
image_id, name = image_dict['file_path'], image_dict['text']
|
| 185 |
+
if self.train_folder is None:
|
| 186 |
+
image_dir = image_id
|
| 187 |
+
else:
|
| 188 |
+
image_dir = os.path.join(self.train_folder, image_id)
|
| 189 |
+
|
| 190 |
+
width, height = get_image_size_without_loading(image_dir)
|
| 191 |
+
|
| 192 |
+
ratio = height / width # self.dataset[idx]
|
| 193 |
+
else:
|
| 194 |
+
height = int(height)
|
| 195 |
+
width = int(width)
|
| 196 |
+
ratio = height / width # self.dataset[idx]
|
| 197 |
+
except Exception as e:
|
| 198 |
+
print(e)
|
| 199 |
+
continue
|
| 200 |
+
# find the closest aspect ratio
|
| 201 |
+
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
|
| 202 |
+
if closest_ratio not in self.current_available_bucket_keys:
|
| 203 |
+
continue
|
| 204 |
+
bucket = self._aspect_ratio_buckets[closest_ratio]
|
| 205 |
+
bucket.append(idx)
|
| 206 |
+
# yield a batch of indices in the same aspect ratio group
|
| 207 |
+
if len(bucket) == self.batch_size:
|
| 208 |
+
yield bucket[:]
|
| 209 |
+
del bucket[:]
|
| 210 |
+
|
| 211 |
+
class AspectRatioBatchSampler(BatchSampler):
|
| 212 |
+
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
sampler (Sampler): Base sampler.
|
| 216 |
+
dataset (Dataset): Dataset providing data information.
|
| 217 |
+
batch_size (int): Size of mini-batch.
|
| 218 |
+
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
| 219 |
+
its size would be less than ``batch_size``.
|
| 220 |
+
aspect_ratios (dict): The predefined aspect ratios.
|
| 221 |
+
"""
|
| 222 |
+
def __init__(
|
| 223 |
+
self,
|
| 224 |
+
sampler: Sampler,
|
| 225 |
+
dataset: Dataset,
|
| 226 |
+
batch_size: int,
|
| 227 |
+
video_folder: str = None,
|
| 228 |
+
train_data_format: str = "webvid",
|
| 229 |
+
aspect_ratios: dict = ASPECT_RATIO_512,
|
| 230 |
+
drop_last: bool = False,
|
| 231 |
+
config=None,
|
| 232 |
+
**kwargs
|
| 233 |
+
) -> None:
|
| 234 |
+
if not isinstance(sampler, Sampler):
|
| 235 |
+
raise TypeError('sampler should be an instance of ``Sampler``, '
|
| 236 |
+
f'but got {sampler}')
|
| 237 |
+
if not isinstance(batch_size, int) or batch_size <= 0:
|
| 238 |
+
raise ValueError('batch_size should be a positive integer value, '
|
| 239 |
+
f'but got batch_size={batch_size}')
|
| 240 |
+
self.sampler = sampler
|
| 241 |
+
self.dataset = dataset
|
| 242 |
+
self.video_folder = video_folder
|
| 243 |
+
self.train_data_format = train_data_format
|
| 244 |
+
self.batch_size = batch_size
|
| 245 |
+
self.aspect_ratios = aspect_ratios
|
| 246 |
+
self.drop_last = drop_last
|
| 247 |
+
self.config = config
|
| 248 |
+
# buckets for each aspect ratio
|
| 249 |
+
self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
|
| 250 |
+
# [str(k) for k, v in aspect_ratios]
|
| 251 |
+
self.current_available_bucket_keys = list(aspect_ratios.keys())
|
| 252 |
+
|
| 253 |
+
def __iter__(self):
|
| 254 |
+
for idx in self.sampler:
|
| 255 |
+
try:
|
| 256 |
+
video_dict = self.dataset[idx]
|
| 257 |
+
width, more = video_dict.get("width", None), video_dict.get("height", None)
|
| 258 |
+
|
| 259 |
+
if width is None or height is None:
|
| 260 |
+
if self.train_data_format == "normal":
|
| 261 |
+
video_id, name = video_dict['file_path'], video_dict['text']
|
| 262 |
+
if self.video_folder is None:
|
| 263 |
+
video_dir = video_id
|
| 264 |
+
else:
|
| 265 |
+
video_dir = os.path.join(self.video_folder, video_id)
|
| 266 |
+
else:
|
| 267 |
+
videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
|
| 268 |
+
video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
|
| 269 |
+
cap = cv2.VideoCapture(video_dir)
|
| 270 |
+
|
| 271 |
+
# 获取视频尺寸
|
| 272 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
|
| 273 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
|
| 274 |
+
|
| 275 |
+
ratio = height / width # self.dataset[idx]
|
| 276 |
+
else:
|
| 277 |
+
height = int(height)
|
| 278 |
+
width = int(width)
|
| 279 |
+
ratio = height / width # self.dataset[idx]
|
| 280 |
+
except Exception as e:
|
| 281 |
+
print(e, self.dataset[idx], "This item is error, please check it.")
|
| 282 |
+
continue
|
| 283 |
+
# find the closest aspect ratio
|
| 284 |
+
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
|
| 285 |
+
if closest_ratio not in self.current_available_bucket_keys:
|
| 286 |
+
continue
|
| 287 |
+
bucket = self._aspect_ratio_buckets[closest_ratio]
|
| 288 |
+
bucket.append(idx)
|
| 289 |
+
# yield a batch of indices in the same aspect ratio group
|
| 290 |
+
if len(bucket) == self.batch_size:
|
| 291 |
+
yield bucket[:]
|
| 292 |
+
del bucket[:]
|
| 293 |
+
|
| 294 |
+
class AspectRatioBatchImageVideoSampler(BatchSampler):
|
| 295 |
+
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
sampler (Sampler): Base sampler.
|
| 299 |
+
dataset (Dataset): Dataset providing data information.
|
| 300 |
+
batch_size (int): Size of mini-batch.
|
| 301 |
+
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
| 302 |
+
its size would be less than ``batch_size``.
|
| 303 |
+
aspect_ratios (dict): The predefined aspect ratios.
|
| 304 |
+
"""
|
| 305 |
+
|
| 306 |
+
def __init__(self,
|
| 307 |
+
sampler: Sampler,
|
| 308 |
+
dataset: Dataset,
|
| 309 |
+
batch_size: int,
|
| 310 |
+
train_folder: str = None,
|
| 311 |
+
aspect_ratios: dict = ASPECT_RATIO_512,
|
| 312 |
+
drop_last: bool = False
|
| 313 |
+
) -> None:
|
| 314 |
+
if not isinstance(sampler, Sampler):
|
| 315 |
+
raise TypeError('sampler should be an instance of ``Sampler``, '
|
| 316 |
+
f'but got {sampler}')
|
| 317 |
+
if not isinstance(batch_size, int) or batch_size <= 0:
|
| 318 |
+
raise ValueError('batch_size should be a positive integer value, '
|
| 319 |
+
f'but got batch_size={batch_size}')
|
| 320 |
+
self.sampler = sampler
|
| 321 |
+
self.dataset = dataset
|
| 322 |
+
self.train_folder = train_folder
|
| 323 |
+
self.batch_size = batch_size
|
| 324 |
+
self.aspect_ratios = aspect_ratios
|
| 325 |
+
self.drop_last = drop_last
|
| 326 |
+
|
| 327 |
+
# buckets for each aspect ratio
|
| 328 |
+
self.current_available_bucket_keys = list(aspect_ratios.keys())
|
| 329 |
+
self.bucket = {
|
| 330 |
+
'image':{ratio: [] for ratio in aspect_ratios},
|
| 331 |
+
'video':{ratio: [] for ratio in aspect_ratios}
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
def __iter__(self):
|
| 335 |
+
for idx in self.sampler:
|
| 336 |
+
content_type = self.dataset[idx].get('type', 'video') # Default to video for video edit datasets
|
| 337 |
+
|
| 338 |
+
try:
|
| 339 |
+
data_dict = self.dataset[idx]
|
| 340 |
+
width, height = data_dict.get("width", None), data_dict.get("height", None)
|
| 341 |
+
|
| 342 |
+
if width is None or height is None:
|
| 343 |
+
if content_type == 'image':
|
| 344 |
+
# Image branch
|
| 345 |
+
image_id = data_dict.get('file_path', '')
|
| 346 |
+
if self.train_folder is None:
|
| 347 |
+
image_dir = image_id
|
| 348 |
+
else:
|
| 349 |
+
image_dir = os.path.join(self.train_folder, image_id)
|
| 350 |
+
width, height = get_image_size_without_loading(image_dir)
|
| 351 |
+
else:
|
| 352 |
+
# Video branch - prefer original_video -> edited_video -> file_path
|
| 353 |
+
video_id = (
|
| 354 |
+
data_dict.get('original_video')
|
| 355 |
+
or data_dict.get('edited_video')
|
| 356 |
+
or data_dict.get('file_path')
|
| 357 |
+
)
|
| 358 |
+
if video_id is None:
|
| 359 |
+
raise ValueError(f"No valid video path found in dataset item: {data_dict}")
|
| 360 |
+
if self.train_folder is None:
|
| 361 |
+
video_dir = video_id
|
| 362 |
+
else:
|
| 363 |
+
video_dir = os.path.join(self.train_folder, video_id)
|
| 364 |
+
cap = cv2.VideoCapture(video_dir)
|
| 365 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
| 366 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 367 |
+
cap.release()
|
| 368 |
+
if width == 0 or height == 0:
|
| 369 |
+
raise ValueError(f"Invalid video size for {video_dir}: {width}x{height}")
|
| 370 |
+
else:
|
| 371 |
+
height = int(height)
|
| 372 |
+
width = int(width)
|
| 373 |
+
|
| 374 |
+
ratio = height / width
|
| 375 |
+
|
| 376 |
+
except Exception as e:
|
| 377 |
+
print(e, self.dataset[idx], "This item is error, please check it.")
|
| 378 |
+
continue
|
| 379 |
+
|
| 380 |
+
# Find the closest aspect ratio
|
| 381 |
+
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
|
| 382 |
+
if closest_ratio not in self.current_available_bucket_keys:
|
| 383 |
+
continue
|
| 384 |
+
|
| 385 |
+
# Add to appropriate bucket (image or video)
|
| 386 |
+
bucket = self.bucket[content_type][closest_ratio]
|
| 387 |
+
bucket.append(idx)
|
| 388 |
+
|
| 389 |
+
# Yield a batch when bucket is full (ensures all items are same type)
|
| 390 |
+
if len(bucket) == self.batch_size:
|
| 391 |
+
yield bucket[:]
|
| 392 |
+
del bucket[:]
|
videox_fun/data/dataset_image.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torchvision.transforms as transforms
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torch.utils.data.dataset import Dataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CC15M(Dataset):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
json_path,
|
| 16 |
+
video_folder=None,
|
| 17 |
+
resolution=512,
|
| 18 |
+
enable_bucket=False,
|
| 19 |
+
):
|
| 20 |
+
print(f"loading annotations from {json_path} ...")
|
| 21 |
+
self.dataset = json.load(open(json_path, 'r'))
|
| 22 |
+
self.length = len(self.dataset)
|
| 23 |
+
print(f"data scale: {self.length}")
|
| 24 |
+
|
| 25 |
+
self.enable_bucket = enable_bucket
|
| 26 |
+
self.video_folder = video_folder
|
| 27 |
+
|
| 28 |
+
resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
|
| 29 |
+
self.pixel_transforms = transforms.Compose([
|
| 30 |
+
transforms.Resize(resolution[0]),
|
| 31 |
+
transforms.CenterCrop(resolution),
|
| 32 |
+
transforms.ToTensor(),
|
| 33 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 34 |
+
])
|
| 35 |
+
|
| 36 |
+
def get_batch(self, idx):
|
| 37 |
+
video_dict = self.dataset[idx]
|
| 38 |
+
video_id, name = video_dict['file_path'], video_dict['text']
|
| 39 |
+
|
| 40 |
+
if self.video_folder is None:
|
| 41 |
+
video_dir = video_id
|
| 42 |
+
else:
|
| 43 |
+
video_dir = os.path.join(self.video_folder, video_id)
|
| 44 |
+
|
| 45 |
+
pixel_values = Image.open(video_dir).convert("RGB")
|
| 46 |
+
return pixel_values, name
|
| 47 |
+
|
| 48 |
+
def __len__(self):
|
| 49 |
+
return self.length
|
| 50 |
+
|
| 51 |
+
def __getitem__(self, idx):
|
| 52 |
+
while True:
|
| 53 |
+
try:
|
| 54 |
+
pixel_values, name = self.get_batch(idx)
|
| 55 |
+
break
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(e)
|
| 58 |
+
idx = random.randint(0, self.length-1)
|
| 59 |
+
|
| 60 |
+
if not self.enable_bucket:
|
| 61 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
| 62 |
+
else:
|
| 63 |
+
pixel_values = np.array(pixel_values)
|
| 64 |
+
|
| 65 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
| 66 |
+
return sample
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
|
| 69 |
+
dataset = CC15M(
|
| 70 |
+
csv_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
|
| 71 |
+
resolution=512,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
|
| 75 |
+
for idx, batch in enumerate(dataloader):
|
| 76 |
+
print(batch["pixel_values"].shape, len(batch["text"]))
|
videox_fun/data/dataset_image_video.py
ADDED
|
@@ -0,0 +1,1939 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import gc
|
| 3 |
+
import io
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import re
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
from random import shuffle
|
| 11 |
+
from threading import Thread
|
| 12 |
+
|
| 13 |
+
import albumentations
|
| 14 |
+
import cv2
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import torchvision.transforms as transforms
|
| 19 |
+
from decord import VideoReader
|
| 20 |
+
from einops import rearrange
|
| 21 |
+
from func_timeout import FunctionTimedOut, func_timeout
|
| 22 |
+
from packaging import version as pver
|
| 23 |
+
from PIL import Image
|
| 24 |
+
from torch.utils.data import BatchSampler, Sampler
|
| 25 |
+
from torch.utils.data.dataset import Dataset
|
| 26 |
+
|
| 27 |
+
VIDEO_READER_TIMEOUT = 20
|
| 28 |
+
|
| 29 |
+
def get_random_mask(shape, image_start_only=False):
|
| 30 |
+
f, c, h, w = shape
|
| 31 |
+
mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
|
| 32 |
+
|
| 33 |
+
if not image_start_only:
|
| 34 |
+
if f != 1:
|
| 35 |
+
mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
|
| 36 |
+
else:
|
| 37 |
+
mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
|
| 38 |
+
if mask_index == 0:
|
| 39 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 40 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 41 |
+
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
|
| 42 |
+
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
|
| 43 |
+
|
| 44 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
| 45 |
+
end_x = min(center_x + block_size_x // 2, w)
|
| 46 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
| 47 |
+
end_y = min(center_y + block_size_y // 2, h)
|
| 48 |
+
mask[:, :, start_y:end_y, start_x:end_x] = 1
|
| 49 |
+
elif mask_index == 1:
|
| 50 |
+
mask[:, :, :, :] = 1
|
| 51 |
+
elif mask_index == 2:
|
| 52 |
+
mask_frame_index = np.random.randint(1, 5)
|
| 53 |
+
mask[mask_frame_index:, :, :, :] = 1
|
| 54 |
+
elif mask_index == 3:
|
| 55 |
+
mask_frame_index = np.random.randint(1, 5)
|
| 56 |
+
mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
|
| 57 |
+
elif mask_index == 4:
|
| 58 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 59 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 60 |
+
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
|
| 61 |
+
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
|
| 62 |
+
|
| 63 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
| 64 |
+
end_x = min(center_x + block_size_x // 2, w)
|
| 65 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
| 66 |
+
end_y = min(center_y + block_size_y // 2, h)
|
| 67 |
+
|
| 68 |
+
mask_frame_before = np.random.randint(0, f // 2)
|
| 69 |
+
mask_frame_after = np.random.randint(f // 2, f)
|
| 70 |
+
mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
|
| 71 |
+
elif mask_index == 5:
|
| 72 |
+
mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
|
| 73 |
+
elif mask_index == 6:
|
| 74 |
+
num_frames_to_mask = random.randint(1, max(f // 2, 1))
|
| 75 |
+
frames_to_mask = random.sample(range(f), num_frames_to_mask)
|
| 76 |
+
|
| 77 |
+
for i in frames_to_mask:
|
| 78 |
+
block_height = random.randint(1, h // 4)
|
| 79 |
+
block_width = random.randint(1, w // 4)
|
| 80 |
+
top_left_y = random.randint(0, h - block_height)
|
| 81 |
+
top_left_x = random.randint(0, w - block_width)
|
| 82 |
+
mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
|
| 83 |
+
elif mask_index == 7:
|
| 84 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 85 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 86 |
+
a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # 长半轴
|
| 87 |
+
b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # 短半轴
|
| 88 |
+
|
| 89 |
+
for i in range(h):
|
| 90 |
+
for j in range(w):
|
| 91 |
+
if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
|
| 92 |
+
mask[:, :, i, j] = 1
|
| 93 |
+
elif mask_index == 8:
|
| 94 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 95 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 96 |
+
radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
|
| 97 |
+
for i in range(h):
|
| 98 |
+
for j in range(w):
|
| 99 |
+
if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
|
| 100 |
+
mask[:, :, i, j] = 1
|
| 101 |
+
elif mask_index == 9:
|
| 102 |
+
for idx in range(f):
|
| 103 |
+
if np.random.rand() > 0.5:
|
| 104 |
+
mask[idx, :, :, :] = 1
|
| 105 |
+
else:
|
| 106 |
+
raise ValueError(f"The mask_index {mask_index} is not define")
|
| 107 |
+
else:
|
| 108 |
+
if f != 1:
|
| 109 |
+
mask[1:, :, :, :] = 1
|
| 110 |
+
else:
|
| 111 |
+
mask[:, :, :, :] = 1
|
| 112 |
+
return mask
|
| 113 |
+
|
| 114 |
+
class Camera(object):
|
| 115 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 116 |
+
"""
|
| 117 |
+
def __init__(self, entry):
|
| 118 |
+
fx, fy, cx, cy = entry[1:5]
|
| 119 |
+
self.fx = fx
|
| 120 |
+
self.fy = fy
|
| 121 |
+
self.cx = cx
|
| 122 |
+
self.cy = cy
|
| 123 |
+
w2c_mat = np.array(entry[7:]).reshape(3, 4)
|
| 124 |
+
w2c_mat_4x4 = np.eye(4)
|
| 125 |
+
w2c_mat_4x4[:3, :] = w2c_mat
|
| 126 |
+
self.w2c_mat = w2c_mat_4x4
|
| 127 |
+
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
|
| 128 |
+
|
| 129 |
+
def custom_meshgrid(*args):
|
| 130 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 131 |
+
"""
|
| 132 |
+
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
| 133 |
+
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
| 134 |
+
return torch.meshgrid(*args)
|
| 135 |
+
else:
|
| 136 |
+
return torch.meshgrid(*args, indexing='ij')
|
| 137 |
+
|
| 138 |
+
def get_relative_pose(cam_params):
|
| 139 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 140 |
+
"""
|
| 141 |
+
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
| 142 |
+
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
| 143 |
+
cam_to_origin = 0
|
| 144 |
+
target_cam_c2w = np.array([
|
| 145 |
+
[1, 0, 0, 0],
|
| 146 |
+
[0, 1, 0, -cam_to_origin],
|
| 147 |
+
[0, 0, 1, 0],
|
| 148 |
+
[0, 0, 0, 1]
|
| 149 |
+
])
|
| 150 |
+
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
| 151 |
+
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
| 152 |
+
ret_poses = np.array(ret_poses, dtype=np.float32)
|
| 153 |
+
return ret_poses
|
| 154 |
+
|
| 155 |
+
def ray_condition(K, c2w, H, W, device):
|
| 156 |
+
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 157 |
+
"""
|
| 158 |
+
# c2w: B, V, 4, 4
|
| 159 |
+
# K: B, V, 4
|
| 160 |
+
|
| 161 |
+
B = K.shape[0]
|
| 162 |
+
|
| 163 |
+
j, i = custom_meshgrid(
|
| 164 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
| 165 |
+
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
| 166 |
+
)
|
| 167 |
+
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 168 |
+
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 169 |
+
|
| 170 |
+
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
| 171 |
+
|
| 172 |
+
zs = torch.ones_like(i) # [B, HxW]
|
| 173 |
+
xs = (i - cx) / fx * zs
|
| 174 |
+
ys = (j - cy) / fy * zs
|
| 175 |
+
zs = zs.expand_as(ys)
|
| 176 |
+
|
| 177 |
+
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
| 178 |
+
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
| 179 |
+
|
| 180 |
+
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
| 181 |
+
rays_o = c2w[..., :3, 3] # B, V, 3
|
| 182 |
+
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
| 183 |
+
# c2w @ dirctions
|
| 184 |
+
rays_dxo = torch.cross(rays_o, rays_d)
|
| 185 |
+
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
| 186 |
+
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
| 187 |
+
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
| 188 |
+
return plucker
|
| 189 |
+
|
| 190 |
+
def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
|
| 191 |
+
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 192 |
+
"""
|
| 193 |
+
with open(pose_file_path, 'r') as f:
|
| 194 |
+
poses = f.readlines()
|
| 195 |
+
|
| 196 |
+
poses = [pose.strip().split(' ') for pose in poses[1:]]
|
| 197 |
+
cam_params = [[float(x) for x in pose] for pose in poses]
|
| 198 |
+
if return_poses:
|
| 199 |
+
return cam_params
|
| 200 |
+
else:
|
| 201 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
| 202 |
+
|
| 203 |
+
sample_wh_ratio = width / height
|
| 204 |
+
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
| 205 |
+
|
| 206 |
+
if pose_wh_ratio > sample_wh_ratio:
|
| 207 |
+
resized_ori_w = height * pose_wh_ratio
|
| 208 |
+
for cam_param in cam_params:
|
| 209 |
+
cam_param.fx = resized_ori_w * cam_param.fx / width
|
| 210 |
+
else:
|
| 211 |
+
resized_ori_h = width / pose_wh_ratio
|
| 212 |
+
for cam_param in cam_params:
|
| 213 |
+
cam_param.fy = resized_ori_h * cam_param.fy / height
|
| 214 |
+
|
| 215 |
+
intrinsic = np.asarray([[cam_param.fx * width,
|
| 216 |
+
cam_param.fy * height,
|
| 217 |
+
cam_param.cx * width,
|
| 218 |
+
cam_param.cy * height]
|
| 219 |
+
for cam_param in cam_params], dtype=np.float32)
|
| 220 |
+
|
| 221 |
+
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
| 222 |
+
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
| 223 |
+
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
| 224 |
+
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
| 225 |
+
plucker_embedding = plucker_embedding[None]
|
| 226 |
+
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
| 227 |
+
return plucker_embedding
|
| 228 |
+
|
| 229 |
+
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
|
| 230 |
+
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
| 231 |
+
"""
|
| 232 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
| 233 |
+
|
| 234 |
+
sample_wh_ratio = width / height
|
| 235 |
+
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
|
| 236 |
+
|
| 237 |
+
if pose_wh_ratio > sample_wh_ratio:
|
| 238 |
+
resized_ori_w = height * pose_wh_ratio
|
| 239 |
+
for cam_param in cam_params:
|
| 240 |
+
cam_param.fx = resized_ori_w * cam_param.fx / width
|
| 241 |
+
else:
|
| 242 |
+
resized_ori_h = width / pose_wh_ratio
|
| 243 |
+
for cam_param in cam_params:
|
| 244 |
+
cam_param.fy = resized_ori_h * cam_param.fy / height
|
| 245 |
+
|
| 246 |
+
intrinsic = np.asarray([[cam_param.fx * width,
|
| 247 |
+
cam_param.fy * height,
|
| 248 |
+
cam_param.cx * width,
|
| 249 |
+
cam_param.cy * height]
|
| 250 |
+
for cam_param in cam_params], dtype=np.float32)
|
| 251 |
+
|
| 252 |
+
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
| 253 |
+
c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
|
| 254 |
+
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
| 255 |
+
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
|
| 256 |
+
plucker_embedding = plucker_embedding[None]
|
| 257 |
+
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
| 258 |
+
return plucker_embedding
|
| 259 |
+
|
| 260 |
+
def derive_ground_object_from_instruction(instruction: str) -> str:
|
| 261 |
+
s = (instruction or '').strip()
|
| 262 |
+
if not s:
|
| 263 |
+
return 'the target area'
|
| 264 |
+
s = s.rstrip('.').strip()
|
| 265 |
+
|
| 266 |
+
# swap/replace: capture phrase between "replace/swap" and "with/by"
|
| 267 |
+
swap_patterns = [
|
| 268 |
+
r"\breplace\s+(.*?)\s+(?:with|by)\b",
|
| 269 |
+
r"\bswap\s+(.*?)\s+with\b",
|
| 270 |
+
]
|
| 271 |
+
for pat in swap_patterns:
|
| 272 |
+
m = re.search(pat, s, flags=re.IGNORECASE)
|
| 273 |
+
if m:
|
| 274 |
+
phrase = m.group(1).strip(' .,:;')
|
| 275 |
+
if phrase:
|
| 276 |
+
return phrase
|
| 277 |
+
|
| 278 |
+
# removal: capture object after remove/delete/erase/eliminate up to a preposition or punctuation
|
| 279 |
+
m = re.search(r"\b(?:remove|delete|erase|eliminate)\s+(.*?)(?:\s+(?:from|in|at|on|over|under|near|by)\b|[.,;]|$)", s, flags=re.IGNORECASE)
|
| 280 |
+
if m:
|
| 281 |
+
phrase = m.group(1).strip(' .,:;')
|
| 282 |
+
if phrase:
|
| 283 |
+
return phrase
|
| 284 |
+
|
| 285 |
+
# add/insert: generic target area
|
| 286 |
+
if re.search(r"^\s*(?:add|insert)\b", s, flags=re.IGNORECASE):
|
| 287 |
+
return 'the target area'
|
| 288 |
+
|
| 289 |
+
# local style (change/make ...): take the immediate noun after determiner
|
| 290 |
+
m = re.search(r"\b(?:change|make)\s+(?:(the|a|an)\s+)?([A-Za-z][A-Za-z0-9\-]*)", s, flags=re.IGNORECASE)
|
| 291 |
+
if m:
|
| 292 |
+
det = m.group(1) or ''
|
| 293 |
+
noun = m.group(2)
|
| 294 |
+
phrase = (det + ' ' + noun).strip()
|
| 295 |
+
return phrase
|
| 296 |
+
|
| 297 |
+
return 'the target area'
|
| 298 |
+
|
| 299 |
+
class ImageVideoSampler(BatchSampler):
|
| 300 |
+
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
sampler (Sampler): Base sampler.
|
| 304 |
+
dataset (Dataset): Dataset providing data information.
|
| 305 |
+
batch_size (int): Size of mini-batch.
|
| 306 |
+
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
| 307 |
+
its size would be less than ``batch_size``.
|
| 308 |
+
aspect_ratios (dict): The predefined aspect ratios.
|
| 309 |
+
"""
|
| 310 |
+
|
| 311 |
+
def __init__(self,
|
| 312 |
+
sampler: Sampler,
|
| 313 |
+
dataset: Dataset,
|
| 314 |
+
batch_size: int,
|
| 315 |
+
drop_last: bool = False
|
| 316 |
+
) -> None:
|
| 317 |
+
if not isinstance(sampler, Sampler):
|
| 318 |
+
raise TypeError('sampler should be an instance of ``Sampler``, '
|
| 319 |
+
f'but got {sampler}')
|
| 320 |
+
if not isinstance(batch_size, int) or batch_size <= 0:
|
| 321 |
+
raise ValueError('batch_size should be a positive integer value, '
|
| 322 |
+
f'but got batch_size={batch_size}')
|
| 323 |
+
self.sampler = sampler
|
| 324 |
+
self.dataset = dataset
|
| 325 |
+
self.batch_size = batch_size
|
| 326 |
+
self.drop_last = drop_last
|
| 327 |
+
|
| 328 |
+
# buckets for each aspect ratio
|
| 329 |
+
self.bucket = {'image':[], 'video':[]}
|
| 330 |
+
|
| 331 |
+
def __iter__(self):
|
| 332 |
+
for idx in self.sampler:
|
| 333 |
+
content_type = self.dataset.dataset[idx].get('type', 'image')
|
| 334 |
+
self.bucket[content_type].append(idx)
|
| 335 |
+
|
| 336 |
+
# yield a batch of indices in the same aspect ratio group
|
| 337 |
+
if len(self.bucket['video']) == self.batch_size:
|
| 338 |
+
bucket = self.bucket['video']
|
| 339 |
+
yield bucket[:]
|
| 340 |
+
del bucket[:]
|
| 341 |
+
elif len(self.bucket['image']) == self.batch_size:
|
| 342 |
+
bucket = self.bucket['image']
|
| 343 |
+
yield bucket[:]
|
| 344 |
+
del bucket[:]
|
| 345 |
+
|
| 346 |
+
@contextmanager
|
| 347 |
+
def VideoReader_contextmanager(*args, **kwargs):
|
| 348 |
+
vr = VideoReader(*args, **kwargs)
|
| 349 |
+
try:
|
| 350 |
+
yield vr
|
| 351 |
+
finally:
|
| 352 |
+
del vr
|
| 353 |
+
gc.collect()
|
| 354 |
+
|
| 355 |
+
def get_video_reader_batch(video_reader, batch_index):
|
| 356 |
+
frames = video_reader.get_batch(batch_index).asnumpy()
|
| 357 |
+
return frames
|
| 358 |
+
|
| 359 |
+
def resize_frame(frame, target_short_side):
|
| 360 |
+
h, w, _ = frame.shape
|
| 361 |
+
if h < w:
|
| 362 |
+
if target_short_side > h:
|
| 363 |
+
return frame
|
| 364 |
+
new_h = target_short_side
|
| 365 |
+
new_w = int(target_short_side * w / h)
|
| 366 |
+
else:
|
| 367 |
+
if target_short_side > w:
|
| 368 |
+
return frame
|
| 369 |
+
new_w = target_short_side
|
| 370 |
+
new_h = int(target_short_side * h / w)
|
| 371 |
+
|
| 372 |
+
resized_frame = cv2.resize(frame, (new_w, new_h))
|
| 373 |
+
return resized_frame
|
| 374 |
+
|
| 375 |
+
class VideoEditDataset(Dataset):
|
| 376 |
+
def __init__(
|
| 377 |
+
self,
|
| 378 |
+
ann_path,
|
| 379 |
+
data_root=None,
|
| 380 |
+
video_sample_height: int = None, # 改为None以支持动态分辨率
|
| 381 |
+
video_sample_width: int = None,
|
| 382 |
+
video_sample_stride=1,
|
| 383 |
+
video_sample_n_frames=65, # 9+8=17 for your case
|
| 384 |
+
source_frames=33,
|
| 385 |
+
edit_frames=32,
|
| 386 |
+
text_drop_ratio=0.1,
|
| 387 |
+
enable_bucket=False,
|
| 388 |
+
enable_inpaint=False,
|
| 389 |
+
instruction_template="A video sequence showing two parts: the first half shows the original scene, and the second half shows the same scene but {edit_instruction}",
|
| 390 |
+
):
|
| 391 |
+
dataset = json.load(open(ann_path))
|
| 392 |
+
if isinstance(dataset, dict):
|
| 393 |
+
new_dataset = []
|
| 394 |
+
for vid_id, info in dataset.items():
|
| 395 |
+
text_content = info["edit_instruction"]
|
| 396 |
+
new_dataset.append({
|
| 397 |
+
"original_video": info["original_video"],
|
| 398 |
+
"edited_video": info["edited_video"],
|
| 399 |
+
"text": text_content,
|
| 400 |
+
"type": info.get("type", "video"),
|
| 401 |
+
# 添加分辨率信息到metadata
|
| 402 |
+
"resolution": info.get("resolution", None)
|
| 403 |
+
})
|
| 404 |
+
dataset = new_dataset
|
| 405 |
+
|
| 406 |
+
self.data_root = data_root
|
| 407 |
+
self.dataset = dataset
|
| 408 |
+
self.length = len(self.dataset)
|
| 409 |
+
|
| 410 |
+
self.source_frames = source_frames
|
| 411 |
+
self.edit_frames = edit_frames
|
| 412 |
+
self.video_sample_n_frames = video_sample_n_frames
|
| 413 |
+
|
| 414 |
+
self.instruction_template = instruction_template
|
| 415 |
+
self.enable_bucket = enable_bucket
|
| 416 |
+
self.text_drop_ratio = text_drop_ratio
|
| 417 |
+
self.enable_inpaint = enable_inpaint
|
| 418 |
+
self.video_sample_stride = video_sample_stride
|
| 419 |
+
|
| 420 |
+
# 如果启用bucket,不固定分辨率
|
| 421 |
+
if enable_bucket:
|
| 422 |
+
self.video_sample_height = None
|
| 423 |
+
self.video_sample_width = None
|
| 424 |
+
else:
|
| 425 |
+
self.video_sample_height = video_sample_height
|
| 426 |
+
self.video_sample_width = video_sample_width
|
| 427 |
+
|
| 428 |
+
def load_video_pair(self, original_path, edited_path):
|
| 429 |
+
"""加载视频对,保持原始分辨率用于bucket training"""
|
| 430 |
+
if self.data_root is not None:
|
| 431 |
+
original_path = os.path.join(self.data_root, original_path)
|
| 432 |
+
edited_path = os.path.join(self.data_root, edited_path)
|
| 433 |
+
|
| 434 |
+
with VideoReader_contextmanager(original_path, num_threads=2) as orig_reader, \
|
| 435 |
+
VideoReader_contextmanager(edited_path, num_threads=2) as edit_reader:
|
| 436 |
+
|
| 437 |
+
# 获取视频信息
|
| 438 |
+
orig_length = len(orig_reader)
|
| 439 |
+
edit_length = len(edit_reader)
|
| 440 |
+
min_length = min(orig_length, edit_length)
|
| 441 |
+
|
| 442 |
+
# 统一采样策略
|
| 443 |
+
start_idx = 0 # 从头开始
|
| 444 |
+
|
| 445 |
+
orig_indices = np.linspace(
|
| 446 |
+
start_idx,
|
| 447 |
+
min(start_idx + (self.source_frames - 1) * self.video_sample_stride, orig_length - 1),
|
| 448 |
+
self.source_frames,
|
| 449 |
+
dtype=int
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
edit_indices = np.linspace(
|
| 453 |
+
start_idx,
|
| 454 |
+
min(start_idx + (self.edit_frames - 1) * self.video_sample_stride, edit_length - 1),
|
| 455 |
+
self.edit_frames,
|
| 456 |
+
dtype=int
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
# 加载帧
|
| 460 |
+
orig_frames = get_video_reader_batch(orig_reader, orig_indices)
|
| 461 |
+
edit_frames = get_video_reader_batch(edit_reader, edit_indices)
|
| 462 |
+
|
| 463 |
+
# 在拼接前对齐两段视频到相同 HxW(缩放后中心裁剪到 min(H1,H2) x min(W1,W2))
|
| 464 |
+
def resize_and_center_crop_batch(frames_np, target_h, target_w):
|
| 465 |
+
resized = []
|
| 466 |
+
for i in range(frames_np.shape[0]):
|
| 467 |
+
frame = frames_np[i]
|
| 468 |
+
h, w = frame.shape[0], frame.shape[1]
|
| 469 |
+
scale = max(target_h / h, target_w / w)
|
| 470 |
+
new_h = int(round(h * scale))
|
| 471 |
+
new_w = int(round(w * scale))
|
| 472 |
+
frame_resized = cv2.resize(frame, (new_w, new_h))
|
| 473 |
+
y0 = max((new_h - target_h) // 2, 0)
|
| 474 |
+
x0 = max((new_w - target_w) // 2, 0)
|
| 475 |
+
frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
|
| 476 |
+
resized.append(frame_cropped)
|
| 477 |
+
return np.stack(resized, axis=0)
|
| 478 |
+
|
| 479 |
+
oh, ow = orig_frames.shape[1], orig_frames.shape[2]
|
| 480 |
+
eh, ew = edit_frames.shape[1], edit_frames.shape[2]
|
| 481 |
+
target_h = min(oh, eh)
|
| 482 |
+
target_w = min(ow, ew)
|
| 483 |
+
if (oh != target_h or ow != target_w):
|
| 484 |
+
orig_frames = resize_and_center_crop_batch(orig_frames, target_h, target_w)
|
| 485 |
+
if (eh != target_h or ew != target_w):
|
| 486 |
+
edit_frames = resize_and_center_crop_batch(edit_frames, target_h, target_w)
|
| 487 |
+
|
| 488 |
+
# 如果启用bucket,返回numpy数组
|
| 489 |
+
if self.enable_bucket:
|
| 490 |
+
return np.concatenate([orig_frames, edit_frames], axis=0)
|
| 491 |
+
else:
|
| 492 |
+
# 转换为tensor并归一化
|
| 493 |
+
orig_frames = torch.from_numpy(orig_frames).permute(0, 3, 1, 2).contiguous() / 255.
|
| 494 |
+
edit_frames = torch.from_numpy(edit_frames).permute(0, 3, 1, 2).contiguous() / 255.
|
| 495 |
+
return torch.cat([orig_frames, edit_frames], dim=0)
|
| 496 |
+
|
| 497 |
+
def __len__(self):
|
| 498 |
+
return self.length
|
| 499 |
+
|
| 500 |
+
def __getitem__(self, idx):
|
| 501 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 502 |
+
|
| 503 |
+
while True:
|
| 504 |
+
try:
|
| 505 |
+
# 加载视频对
|
| 506 |
+
pixel_values = self.load_video_pair(
|
| 507 |
+
data_info['original_video'],
|
| 508 |
+
data_info['edited_video']
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# 准备文本
|
| 512 |
+
text = data_info['text']
|
| 513 |
+
if self.instruction_template and "{edit_instruction}" in self.instruction_template:
|
| 514 |
+
text = self.instruction_template.format(edit_instruction=text)
|
| 515 |
+
|
| 516 |
+
if random.random() < self.text_drop_ratio:
|
| 517 |
+
text = ''
|
| 518 |
+
|
| 519 |
+
sample = {
|
| 520 |
+
"pixel_values": pixel_values,
|
| 521 |
+
"text": text,
|
| 522 |
+
"data_type": "video",
|
| 523 |
+
"idx": idx,
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
# 如果需要inpainting
|
| 527 |
+
if self.enable_inpaint and not self.enable_bucket:
|
| 528 |
+
# 这里添加inpaint逻辑
|
| 529 |
+
pass
|
| 530 |
+
|
| 531 |
+
return sample
|
| 532 |
+
|
| 533 |
+
except Exception as e:
|
| 534 |
+
try:
|
| 535 |
+
print(
|
| 536 |
+
f"Error loading video pair: {e}\n"
|
| 537 |
+
f" original={os.path.join(self.data_root, data_info.get('original_video','')) if self.data_root else data_info.get('original_video','')}\n"
|
| 538 |
+
f" edited ={os.path.join(self.data_root, data_info.get('edited_video','')) if self.data_root else data_info.get('edited_video','')}"
|
| 539 |
+
)
|
| 540 |
+
except Exception:
|
| 541 |
+
print(f"Error loading video pair: {e}")
|
| 542 |
+
idx = random.randint(0, self.length-1)
|
| 543 |
+
|
| 544 |
+
class VideoEditReasoningDataset(Dataset):
|
| 545 |
+
def __init__(
|
| 546 |
+
self,
|
| 547 |
+
ann_path,
|
| 548 |
+
data_root=None,
|
| 549 |
+
video_sample_height: int = None,
|
| 550 |
+
video_sample_width: int = None,
|
| 551 |
+
video_sample_stride=1,
|
| 552 |
+
video_sample_n_frames=65,
|
| 553 |
+
source_frames=33,
|
| 554 |
+
reasoning_frames=4,
|
| 555 |
+
edit_frames=32,
|
| 556 |
+
text_drop_ratio=0.1,
|
| 557 |
+
enable_bucket=False,
|
| 558 |
+
enable_inpaint=False,
|
| 559 |
+
instruction_template="A video sequence showing three parts: first the original scene, then grounded {ground_instrction}, and finally the same scene but {edit_instruction}",
|
| 560 |
+
):
|
| 561 |
+
dataset = json.load(open(ann_path))
|
| 562 |
+
if isinstance(dataset, dict):
|
| 563 |
+
new_dataset = []
|
| 564 |
+
for vid_id, info in dataset.items():
|
| 565 |
+
text_content = info.get("edit_instruction", info.get("text", ""))
|
| 566 |
+
# support both 'grounded_video' and 'ground_video'
|
| 567 |
+
grounded_key = "grounded_video" if "grounded_video" in info else "ground_video"
|
| 568 |
+
new_dataset.append({
|
| 569 |
+
"original_video": info["original_video"],
|
| 570 |
+
"grounded_video": info[grounded_key],
|
| 571 |
+
"edited_video": info["edited_video"],
|
| 572 |
+
"text": text_content,
|
| 573 |
+
"edit_instruction": text_content,
|
| 574 |
+
"type": info.get("type", "video"),
|
| 575 |
+
"resolution": info.get("resolution", None),
|
| 576 |
+
})
|
| 577 |
+
dataset = new_dataset
|
| 578 |
+
|
| 579 |
+
self.data_root = data_root
|
| 580 |
+
self.dataset = dataset
|
| 581 |
+
self.length = len(self.dataset)
|
| 582 |
+
|
| 583 |
+
self.source_frames = source_frames
|
| 584 |
+
self.reasoning_frames = reasoning_frames
|
| 585 |
+
self.edit_frames = edit_frames
|
| 586 |
+
self.video_sample_n_frames = video_sample_n_frames
|
| 587 |
+
|
| 588 |
+
self.instruction_template = instruction_template
|
| 589 |
+
self.enable_bucket = enable_bucket
|
| 590 |
+
self.text_drop_ratio = text_drop_ratio
|
| 591 |
+
self.enable_inpaint = enable_inpaint
|
| 592 |
+
self.video_sample_stride = video_sample_stride
|
| 593 |
+
|
| 594 |
+
if enable_bucket:
|
| 595 |
+
self.video_sample_height = None
|
| 596 |
+
self.video_sample_width = None
|
| 597 |
+
else:
|
| 598 |
+
self.video_sample_height = video_sample_height
|
| 599 |
+
self.video_sample_width = video_sample_width
|
| 600 |
+
|
| 601 |
+
def load_video_pair(self, original_path, grounded_path, edited_path):
|
| 602 |
+
if self.data_root is not None:
|
| 603 |
+
original_path = os.path.join(self.data_root, original_path)
|
| 604 |
+
grounded_path = os.path.join(self.data_root, grounded_path)
|
| 605 |
+
edited_path = os.path.join(self.data_root, edited_path)
|
| 606 |
+
|
| 607 |
+
with VideoReader_contextmanager(original_path, num_threads=2) as orig_reader, \
|
| 608 |
+
VideoReader_contextmanager(grounded_path, num_threads=2) as ground_reader, \
|
| 609 |
+
VideoReader_contextmanager(edited_path, num_threads=2) as edit_reader:
|
| 610 |
+
|
| 611 |
+
orig_length = len(orig_reader)
|
| 612 |
+
ground_length = len(ground_reader)
|
| 613 |
+
edit_length = len(edit_reader)
|
| 614 |
+
|
| 615 |
+
start_idx = 0
|
| 616 |
+
|
| 617 |
+
orig_indices = np.linspace(
|
| 618 |
+
start_idx,
|
| 619 |
+
min(start_idx + (self.source_frames - 1) * self.video_sample_stride, max(orig_length - 1, 0)),
|
| 620 |
+
self.source_frames,
|
| 621 |
+
dtype=int
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
# reasoning/grounded indices at 8-frame interval (example: 0,7,14,21, ...)
|
| 625 |
+
interval = 8
|
| 626 |
+
ground_indices_full = np.arange(0, max(ground_length, 1), interval, dtype=int)
|
| 627 |
+
if len(ground_indices_full) == 0:
|
| 628 |
+
ground_indices = np.array([0] * self.reasoning_frames, dtype=int)
|
| 629 |
+
else:
|
| 630 |
+
ground_indices = ground_indices_full[: self.reasoning_frames]
|
| 631 |
+
if len(ground_indices) < self.reasoning_frames:
|
| 632 |
+
pad_value = ground_indices[-1] if len(ground_indices) > 0 else 0
|
| 633 |
+
ground_indices = np.pad(
|
| 634 |
+
ground_indices, (0, self.reasoning_frames - len(ground_indices)), constant_values=pad_value
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
edit_indices = np.linspace(
|
| 638 |
+
start_idx,
|
| 639 |
+
min(start_idx + (self.edit_frames - 1) * self.video_sample_stride, max(edit_length - 1, 0)),
|
| 640 |
+
self.edit_frames,
|
| 641 |
+
dtype=int
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
orig_frames = get_video_reader_batch(orig_reader, orig_indices)
|
| 645 |
+
ground_frames = get_video_reader_batch(ground_reader, ground_indices)
|
| 646 |
+
edit_frames = get_video_reader_batch(edit_reader, edit_indices)
|
| 647 |
+
|
| 648 |
+
def resize_and_center_crop_batch(frames_np, target_h, target_w):
|
| 649 |
+
resized = []
|
| 650 |
+
for i in range(frames_np.shape[0]):
|
| 651 |
+
frame = frames_np[i]
|
| 652 |
+
h, w = frame.shape[0], frame.shape[1]
|
| 653 |
+
scale = max(target_h / h, target_w / w)
|
| 654 |
+
new_h = int(round(h * scale))
|
| 655 |
+
new_w = int(round(w * scale))
|
| 656 |
+
frame_resized = cv2.resize(frame, (new_w, new_h))
|
| 657 |
+
y0 = max((new_h - target_h) // 2, 0)
|
| 658 |
+
x0 = max((new_w - target_w) // 2, 0)
|
| 659 |
+
frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
|
| 660 |
+
resized.append(frame_cropped)
|
| 661 |
+
return np.stack(resized, axis=0)
|
| 662 |
+
|
| 663 |
+
oh, ow = orig_frames.shape[1], orig_frames.shape[2]
|
| 664 |
+
gh, gw = ground_frames.shape[1], ground_frames.shape[2]
|
| 665 |
+
eh, ew = edit_frames.shape[1], edit_frames.shape[2]
|
| 666 |
+
target_h = min(oh, gh, eh)
|
| 667 |
+
target_w = min(ow, gw, ew)
|
| 668 |
+
if (oh != target_h or ow != target_w):
|
| 669 |
+
orig_frames = resize_and_center_crop_batch(orig_frames, target_h, target_w)
|
| 670 |
+
if (gh != target_h or gw != target_w):
|
| 671 |
+
ground_frames = resize_and_center_crop_batch(ground_frames, target_h, target_w)
|
| 672 |
+
if (eh != target_h or ew != target_w):
|
| 673 |
+
edit_frames = resize_and_center_crop_batch(edit_frames, target_h, target_w)
|
| 674 |
+
|
| 675 |
+
if self.enable_bucket:
|
| 676 |
+
return np.concatenate([orig_frames, ground_frames, edit_frames], axis=0)
|
| 677 |
+
else:
|
| 678 |
+
orig_frames = torch.from_numpy(orig_frames).permute(0, 3, 1, 2).contiguous() / 255.
|
| 679 |
+
ground_frames = torch.from_numpy(ground_frames).permute(0, 3, 1, 2).contiguous() / 255.
|
| 680 |
+
edit_frames = torch.from_numpy(edit_frames).permute(0, 3, 1, 2).contiguous() / 255.
|
| 681 |
+
return torch.cat([orig_frames, ground_frames, edit_frames], dim=0)
|
| 682 |
+
|
| 683 |
+
def __len__(self):
|
| 684 |
+
return self.length
|
| 685 |
+
|
| 686 |
+
def __getitem__(self, idx):
|
| 687 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 688 |
+
|
| 689 |
+
while True:
|
| 690 |
+
try:
|
| 691 |
+
pixel_values = self.load_video_pair(
|
| 692 |
+
data_info['original_video'],
|
| 693 |
+
data_info.get('grounded_video', data_info.get('ground_video')),
|
| 694 |
+
data_info['edited_video'],
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
# Prepare instructions
|
| 698 |
+
edit_text = data_info.get('edit_instruction', data_info.get('text', ''))
|
| 699 |
+
ground_instr = derive_ground_object_from_instruction(edit_text)
|
| 700 |
+
|
| 701 |
+
text = edit_text
|
| 702 |
+
if self.instruction_template:
|
| 703 |
+
text = self.instruction_template.format(edit_instruction=edit_text, ground_instrction=ground_instr)
|
| 704 |
+
|
| 705 |
+
if random.random() < self.text_drop_ratio:
|
| 706 |
+
text = ''
|
| 707 |
+
|
| 708 |
+
sample = {
|
| 709 |
+
"pixel_values": pixel_values,
|
| 710 |
+
"text": text,
|
| 711 |
+
"data_type": "video",
|
| 712 |
+
"idx": idx,
|
| 713 |
+
}
|
| 714 |
+
|
| 715 |
+
if self.enable_inpaint and not self.enable_bucket:
|
| 716 |
+
pass
|
| 717 |
+
|
| 718 |
+
return sample
|
| 719 |
+
|
| 720 |
+
except Exception as e:
|
| 721 |
+
print(f"Error loading video triplet: {e}")
|
| 722 |
+
idx = random.randint(0, self.length-1)
|
| 723 |
+
|
| 724 |
+
class ImageVideoDataset(Dataset):
|
| 725 |
+
def __init__(
|
| 726 |
+
self,
|
| 727 |
+
ann_path, data_root=None,
|
| 728 |
+
video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
|
| 729 |
+
image_sample_size=512,
|
| 730 |
+
video_repeat=0,
|
| 731 |
+
text_drop_ratio=0.1,
|
| 732 |
+
enable_bucket=False,
|
| 733 |
+
video_length_drop_start=0.0,
|
| 734 |
+
video_length_drop_end=1.0,
|
| 735 |
+
enable_inpaint=False,
|
| 736 |
+
return_file_name=False,
|
| 737 |
+
):
|
| 738 |
+
# Loading annotations from files
|
| 739 |
+
print(f"loading annotations from {ann_path} ...")
|
| 740 |
+
if ann_path.endswith('.csv'):
|
| 741 |
+
with open(ann_path, 'r') as csvfile:
|
| 742 |
+
dataset = list(csv.DictReader(csvfile))
|
| 743 |
+
elif ann_path.endswith('.json'):
|
| 744 |
+
dataset = json.load(open(ann_path))
|
| 745 |
+
|
| 746 |
+
self.data_root = data_root
|
| 747 |
+
|
| 748 |
+
# It's used to balance num of images and videos.
|
| 749 |
+
if video_repeat > 0:
|
| 750 |
+
self.dataset = []
|
| 751 |
+
for data in dataset:
|
| 752 |
+
if data.get('type', 'image') != 'video':
|
| 753 |
+
self.dataset.append(data)
|
| 754 |
+
|
| 755 |
+
for _ in range(video_repeat):
|
| 756 |
+
for data in dataset:
|
| 757 |
+
if data.get('type', 'image') == 'video':
|
| 758 |
+
self.dataset.append(data)
|
| 759 |
+
else:
|
| 760 |
+
self.dataset = dataset
|
| 761 |
+
del dataset
|
| 762 |
+
|
| 763 |
+
self.length = len(self.dataset)
|
| 764 |
+
print(f"data scale: {self.length}")
|
| 765 |
+
# TODO: enable bucket training
|
| 766 |
+
self.enable_bucket = enable_bucket
|
| 767 |
+
self.text_drop_ratio = text_drop_ratio
|
| 768 |
+
self.enable_inpaint = enable_inpaint
|
| 769 |
+
self.return_file_name = return_file_name
|
| 770 |
+
|
| 771 |
+
self.video_length_drop_start = video_length_drop_start
|
| 772 |
+
self.video_length_drop_end = video_length_drop_end
|
| 773 |
+
|
| 774 |
+
# Video params
|
| 775 |
+
self.video_sample_stride = video_sample_stride
|
| 776 |
+
self.video_sample_n_frames = video_sample_n_frames
|
| 777 |
+
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
|
| 778 |
+
self.video_transforms = transforms.Compose(
|
| 779 |
+
[
|
| 780 |
+
transforms.Resize(min(self.video_sample_size)),
|
| 781 |
+
transforms.CenterCrop(self.video_sample_size),
|
| 782 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 783 |
+
]
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
# Image params
|
| 787 |
+
self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
|
| 788 |
+
self.image_transforms = transforms.Compose([
|
| 789 |
+
transforms.Resize(min(self.image_sample_size)),
|
| 790 |
+
transforms.CenterCrop(self.image_sample_size),
|
| 791 |
+
transforms.ToTensor(),
|
| 792 |
+
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
|
| 793 |
+
])
|
| 794 |
+
|
| 795 |
+
self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
|
| 796 |
+
|
| 797 |
+
def get_batch(self, idx):
|
| 798 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 799 |
+
|
| 800 |
+
if data_info.get('type', 'image')=='video':
|
| 801 |
+
video_id, text = data_info['file_path'], data_info['text']
|
| 802 |
+
|
| 803 |
+
if self.data_root is None:
|
| 804 |
+
video_dir = video_id
|
| 805 |
+
else:
|
| 806 |
+
video_dir = os.path.join(self.data_root, video_id)
|
| 807 |
+
|
| 808 |
+
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
| 809 |
+
min_sample_n_frames = min(
|
| 810 |
+
self.video_sample_n_frames,
|
| 811 |
+
int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
|
| 812 |
+
)
|
| 813 |
+
if min_sample_n_frames == 0:
|
| 814 |
+
raise ValueError(f"No Frames in video.")
|
| 815 |
+
|
| 816 |
+
video_length = int(self.video_length_drop_end * len(video_reader))
|
| 817 |
+
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
|
| 818 |
+
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
|
| 819 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
|
| 820 |
+
|
| 821 |
+
try:
|
| 822 |
+
sample_args = (video_reader, batch_index)
|
| 823 |
+
pixel_values = func_timeout(
|
| 824 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 825 |
+
)
|
| 826 |
+
resized_frames = []
|
| 827 |
+
for i in range(len(pixel_values)):
|
| 828 |
+
frame = pixel_values[i]
|
| 829 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 830 |
+
resized_frames.append(resized_frame)
|
| 831 |
+
pixel_values = np.array(resized_frames)
|
| 832 |
+
except FunctionTimedOut:
|
| 833 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 834 |
+
except Exception as e:
|
| 835 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 836 |
+
|
| 837 |
+
if not self.enable_bucket:
|
| 838 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 839 |
+
pixel_values = pixel_values / 255.
|
| 840 |
+
del video_reader
|
| 841 |
+
else:
|
| 842 |
+
pixel_values = pixel_values
|
| 843 |
+
|
| 844 |
+
if not self.enable_bucket:
|
| 845 |
+
pixel_values = self.video_transforms(pixel_values)
|
| 846 |
+
|
| 847 |
+
# Random use no text generation
|
| 848 |
+
if random.random() < self.text_drop_ratio:
|
| 849 |
+
text = ''
|
| 850 |
+
return pixel_values, text, 'video', video_dir
|
| 851 |
+
else:
|
| 852 |
+
image_path, text = data_info['file_path'], data_info['text']
|
| 853 |
+
if self.data_root is not None:
|
| 854 |
+
image_path = os.path.join(self.data_root, image_path)
|
| 855 |
+
image = Image.open(image_path).convert('RGB')
|
| 856 |
+
if not self.enable_bucket:
|
| 857 |
+
image = self.image_transforms(image).unsqueeze(0)
|
| 858 |
+
else:
|
| 859 |
+
image = np.expand_dims(np.array(image), 0)
|
| 860 |
+
if random.random() < self.text_drop_ratio:
|
| 861 |
+
text = ''
|
| 862 |
+
return image, text, 'image', image_path
|
| 863 |
+
|
| 864 |
+
def __len__(self):
|
| 865 |
+
return self.length
|
| 866 |
+
|
| 867 |
+
def __getitem__(self, idx):
|
| 868 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 869 |
+
data_type = data_info.get('type', 'image')
|
| 870 |
+
while True:
|
| 871 |
+
sample = {}
|
| 872 |
+
try:
|
| 873 |
+
data_info_local = self.dataset[idx % len(self.dataset)]
|
| 874 |
+
data_type_local = data_info_local.get('type', 'image')
|
| 875 |
+
if data_type_local != data_type:
|
| 876 |
+
raise ValueError("data_type_local != data_type")
|
| 877 |
+
|
| 878 |
+
pixel_values, name, data_type, file_path = self.get_batch(idx)
|
| 879 |
+
sample["pixel_values"] = pixel_values
|
| 880 |
+
sample["text"] = name
|
| 881 |
+
sample["data_type"] = data_type
|
| 882 |
+
sample["idx"] = idx
|
| 883 |
+
if self.return_file_name:
|
| 884 |
+
sample["file_name"] = os.path.basename(file_path)
|
| 885 |
+
|
| 886 |
+
if len(sample) > 0:
|
| 887 |
+
break
|
| 888 |
+
except Exception as e:
|
| 889 |
+
print(e, self.dataset[idx % len(self.dataset)])
|
| 890 |
+
idx = random.randint(0, self.length-1)
|
| 891 |
+
|
| 892 |
+
class ImageVideoEditDataset(Dataset):
|
| 893 |
+
def __init__(
|
| 894 |
+
self,
|
| 895 |
+
ann_path,
|
| 896 |
+
data_root=None,
|
| 897 |
+
video_sample_size=512,
|
| 898 |
+
video_sample_stride=1,
|
| 899 |
+
source_frames=33,
|
| 900 |
+
target_frames=32,
|
| 901 |
+
text_drop_ratio=0.1,
|
| 902 |
+
enable_bucket=False,
|
| 903 |
+
enable_inpaint=False,
|
| 904 |
+
video_length_drop_start=0.0,
|
| 905 |
+
video_length_drop_end=1.0,
|
| 906 |
+
instruction_template="A video sequence showing two parts: the first half shows the original scene, and the second half shows the same scene but {edit_instruction}",
|
| 907 |
+
):
|
| 908 |
+
dataset = json.load(open(ann_path))
|
| 909 |
+
if isinstance(dataset, dict):
|
| 910 |
+
new_dataset = []
|
| 911 |
+
for _, info in dataset.items():
|
| 912 |
+
# Keep original keys, just standardize text field
|
| 913 |
+
data_type = info.get("type", "video")
|
| 914 |
+
entry = dict(info) # Copy original entry
|
| 915 |
+
# Standardize text field name and handle None/empty values
|
| 916 |
+
if "edit_instruction" in entry:
|
| 917 |
+
entry["text"] = entry["edit_instruction"]
|
| 918 |
+
elif "instruction" in entry:
|
| 919 |
+
entry["text"] = entry["instruction"]
|
| 920 |
+
elif "text" not in entry:
|
| 921 |
+
entry["text"] = ""
|
| 922 |
+
|
| 923 |
+
# Ensure text is not None (convert None to empty string)
|
| 924 |
+
if entry["text"] is None:
|
| 925 |
+
entry["text"] = ""
|
| 926 |
+
|
| 927 |
+
# Add file_path for bucket sampler compatibility
|
| 928 |
+
# Bucket sampler expects 'file_path' to get dimensions
|
| 929 |
+
if data_type == "video":
|
| 930 |
+
entry["file_path"] = entry.get("original_video", "")
|
| 931 |
+
else: # image
|
| 932 |
+
entry["file_path"] = entry.get("original_image", "")
|
| 933 |
+
|
| 934 |
+
new_dataset.append(entry)
|
| 935 |
+
dataset = new_dataset
|
| 936 |
+
|
| 937 |
+
self.data_root = data_root
|
| 938 |
+
self.dataset = dataset
|
| 939 |
+
self.length = len(self.dataset)
|
| 940 |
+
|
| 941 |
+
# sampling params
|
| 942 |
+
self.video_sample_stride = video_sample_stride
|
| 943 |
+
self.source_frames = source_frames
|
| 944 |
+
self.target_frames = target_frames
|
| 945 |
+
self.video_length_drop_start = video_length_drop_start
|
| 946 |
+
self.video_length_drop_end = video_length_drop_end
|
| 947 |
+
|
| 948 |
+
# transforms params (match ImageVideoDataset)
|
| 949 |
+
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
|
| 950 |
+
self.video_transforms = transforms.Compose(
|
| 951 |
+
[
|
| 952 |
+
transforms.Resize(min(self.video_sample_size)),
|
| 953 |
+
transforms.CenterCrop(self.video_sample_size),
|
| 954 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 955 |
+
]
|
| 956 |
+
)
|
| 957 |
+
|
| 958 |
+
# Image transforms for non-bucket mode
|
| 959 |
+
self.image_transforms = transforms.Compose([
|
| 960 |
+
transforms.Resize(min(self.video_sample_size)),
|
| 961 |
+
transforms.CenterCrop(self.video_sample_size),
|
| 962 |
+
transforms.ToTensor(),
|
| 963 |
+
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
|
| 964 |
+
])
|
| 965 |
+
|
| 966 |
+
self.instruction_template = instruction_template
|
| 967 |
+
self.enable_bucket = enable_bucket
|
| 968 |
+
self.text_drop_ratio = text_drop_ratio
|
| 969 |
+
self.enable_inpaint = enable_inpaint
|
| 970 |
+
|
| 971 |
+
# For pre-resize like ImageVideoDataset
|
| 972 |
+
self.larger_side_of_image_and_video = min(self.video_sample_size)
|
| 973 |
+
|
| 974 |
+
def _resize_and_center_crop_batch(self, frames_np, target_h, target_w):
|
| 975 |
+
resized = []
|
| 976 |
+
for i in range(frames_np.shape[0]):
|
| 977 |
+
frame = frames_np[i]
|
| 978 |
+
h, w = frame.shape[0], frame.shape[1]
|
| 979 |
+
scale = max(target_h / h, target_w / w)
|
| 980 |
+
new_h = int(round(h * scale))
|
| 981 |
+
new_w = int(round(w * scale))
|
| 982 |
+
frame_resized = cv2.resize(frame, (new_w, new_h))
|
| 983 |
+
y0 = max((new_h - target_h) // 2, 0)
|
| 984 |
+
x0 = max((new_w - target_w) // 2, 0)
|
| 985 |
+
frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
|
| 986 |
+
resized.append(frame_cropped)
|
| 987 |
+
return np.stack(resized, axis=0)
|
| 988 |
+
|
| 989 |
+
def _resize_and_center_crop_image(self, image_np, target_h, target_w):
|
| 990 |
+
h, w = image_np.shape[0], image_np.shape[1]
|
| 991 |
+
scale = max(target_h / h, target_w / w)
|
| 992 |
+
new_h = int(round(h * scale))
|
| 993 |
+
new_w = int(round(w * scale))
|
| 994 |
+
image_resized = cv2.resize(image_np, (new_w, new_h))
|
| 995 |
+
y0 = max((new_h - target_h) // 2, 0)
|
| 996 |
+
x0 = max((new_w - target_w) // 2, 0)
|
| 997 |
+
image_cropped = image_resized[y0:y0 + target_h, x0:x0 + target_w]
|
| 998 |
+
return image_cropped
|
| 999 |
+
|
| 1000 |
+
def get_batch(self, idx):
|
| 1001 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 1002 |
+
|
| 1003 |
+
data_type = data_info.get('type', 'video')
|
| 1004 |
+
|
| 1005 |
+
# Handle None or empty instruction with safety fallback
|
| 1006 |
+
raw_text = data_info.get('text', '')
|
| 1007 |
+
if raw_text is None or (isinstance(raw_text, str) and not raw_text.strip()):
|
| 1008 |
+
# Use a generic fallback description if instruction is missing
|
| 1009 |
+
raw_text = "the content has been modified"
|
| 1010 |
+
|
| 1011 |
+
# Apply instruction template if available
|
| 1012 |
+
if self.instruction_template and "{edit_instruction}" in self.instruction_template:
|
| 1013 |
+
text = self.instruction_template.format(edit_instruction=raw_text)
|
| 1014 |
+
else:
|
| 1015 |
+
text = raw_text
|
| 1016 |
+
|
| 1017 |
+
if data_type == 'video':
|
| 1018 |
+
# video pair branch (default)
|
| 1019 |
+
src_rel, tgt_rel = data_info['original_video'], data_info['edited_video']
|
| 1020 |
+
|
| 1021 |
+
if self.data_root is not None:
|
| 1022 |
+
src_path = os.path.join(self.data_root, src_rel)
|
| 1023 |
+
tgt_path = os.path.join(self.data_root, tgt_rel)
|
| 1024 |
+
else:
|
| 1025 |
+
src_path = src_rel
|
| 1026 |
+
tgt_path = tgt_rel
|
| 1027 |
+
|
| 1028 |
+
# Force use CPU decoder to read all frames instead of just keyframes
|
| 1029 |
+
from decord import cpu
|
| 1030 |
+
with VideoReader_contextmanager(src_path, num_threads=2, ctx=cpu(0)) as src_reader, \
|
| 1031 |
+
VideoReader_contextmanager(tgt_path, num_threads=2, ctx=cpu(0)) as tgt_reader:
|
| 1032 |
+
|
| 1033 |
+
# Get video lengths
|
| 1034 |
+
src_length = len(src_reader)
|
| 1035 |
+
tgt_length = len(tgt_reader)
|
| 1036 |
+
|
| 1037 |
+
# Check if video has enough frames
|
| 1038 |
+
if src_length < self.source_frames:
|
| 1039 |
+
raise ValueError(f"Source video only has {src_length} frames, but requested {self.source_frames}")
|
| 1040 |
+
if tgt_length < self.target_frames:
|
| 1041 |
+
raise ValueError(f"Target video only has {tgt_length} frames, but requested {self.target_frames}")
|
| 1042 |
+
|
| 1043 |
+
# Unified sampling strategy: start from beginning (same as VideoEditDataset)
|
| 1044 |
+
start_idx = 0
|
| 1045 |
+
|
| 1046 |
+
src_indices = np.linspace(
|
| 1047 |
+
start_idx,
|
| 1048 |
+
min(start_idx + (self.source_frames - 1) * self.video_sample_stride, src_length - 1),
|
| 1049 |
+
self.source_frames,
|
| 1050 |
+
dtype=int
|
| 1051 |
+
)
|
| 1052 |
+
|
| 1053 |
+
tgt_indices = np.linspace(
|
| 1054 |
+
start_idx,
|
| 1055 |
+
min(start_idx + (self.target_frames - 1) * self.video_sample_stride, tgt_length - 1),
|
| 1056 |
+
self.target_frames,
|
| 1057 |
+
dtype=int
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
# read batches with timeout
|
| 1061 |
+
try:
|
| 1062 |
+
src_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(src_reader, src_indices))
|
| 1063 |
+
tgt_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(tgt_reader, tgt_indices))
|
| 1064 |
+
except FunctionTimedOut:
|
| 1065 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 1066 |
+
except Exception as e:
|
| 1067 |
+
raise ValueError(f"Failed to extract frames from pair. Error is {e}.")
|
| 1068 |
+
|
| 1069 |
+
# align HxW between source and target to enable concat
|
| 1070 |
+
sh, sw = src_frames.shape[1], src_frames.shape[2]
|
| 1071 |
+
th, tw = tgt_frames.shape[1], tgt_frames.shape[2]
|
| 1072 |
+
target_h = min(sh, th)
|
| 1073 |
+
target_w = min(sw, tw)
|
| 1074 |
+
if (sh != target_h or sw != target_w):
|
| 1075 |
+
src_frames = self._resize_and_center_crop_batch(src_frames, target_h, target_w)
|
| 1076 |
+
if (th != target_h or tw != target_w):
|
| 1077 |
+
tgt_frames = self._resize_and_center_crop_batch(tgt_frames, target_h, target_w)
|
| 1078 |
+
|
| 1079 |
+
if not self.enable_bucket:
|
| 1080 |
+
src_tensor = torch.from_numpy(src_frames).permute(0, 3, 1, 2).contiguous() / 255.
|
| 1081 |
+
tgt_tensor = torch.from_numpy(tgt_frames).permute(0, 3, 1, 2).contiguous() / 255.
|
| 1082 |
+
|
| 1083 |
+
src_tensor = self.video_transforms(src_tensor)
|
| 1084 |
+
tgt_tensor = self.video_transforms(tgt_tensor)
|
| 1085 |
+
else:
|
| 1086 |
+
src_tensor = src_frames
|
| 1087 |
+
tgt_tensor = tgt_frames
|
| 1088 |
+
|
| 1089 |
+
# Random text drop
|
| 1090 |
+
if random.random() < self.text_drop_ratio:
|
| 1091 |
+
text = ''
|
| 1092 |
+
|
| 1093 |
+
return src_tensor, tgt_tensor, text, 'video'
|
| 1094 |
+
else:
|
| 1095 |
+
# image pair branch (simple like ImageVideoDataset image path)
|
| 1096 |
+
src_img_rel = data_info.get('original_image')
|
| 1097 |
+
tgt_img_rel = data_info.get('edited_image')
|
| 1098 |
+
if src_img_rel is None or tgt_img_rel is None:
|
| 1099 |
+
raise ValueError('Missing original_image/edited_image for image sample')
|
| 1100 |
+
|
| 1101 |
+
if self.data_root is not None:
|
| 1102 |
+
src_img_path = os.path.join(self.data_root, src_img_rel)
|
| 1103 |
+
tgt_img_path = os.path.join(self.data_root, tgt_img_rel)
|
| 1104 |
+
else:
|
| 1105 |
+
src_img_path = src_img_rel
|
| 1106 |
+
tgt_img_path = tgt_img_rel
|
| 1107 |
+
|
| 1108 |
+
src_img = Image.open(src_img_path).convert('RGB')
|
| 1109 |
+
tgt_img = Image.open(tgt_img_path).convert('RGB')
|
| 1110 |
+
|
| 1111 |
+
if not self.enable_bucket:
|
| 1112 |
+
# Apply transforms and add frame dimension
|
| 1113 |
+
src_tensor = self.image_transforms(src_img).unsqueeze(0) # (1, C, H, W)
|
| 1114 |
+
tgt_tensor = self.image_transforms(tgt_img).unsqueeze(0) # (1, C, H, W)
|
| 1115 |
+
else:
|
| 1116 |
+
# For bucket mode, keep as numpy and add frame dimension
|
| 1117 |
+
src_tensor = np.expand_dims(np.array(src_img), axis=0) # (1, H, W, C)
|
| 1118 |
+
tgt_tensor = np.expand_dims(np.array(tgt_img), axis=0) # (1, H, W, C)
|
| 1119 |
+
|
| 1120 |
+
if random.random() < self.text_drop_ratio:
|
| 1121 |
+
text = ''
|
| 1122 |
+
|
| 1123 |
+
return src_tensor, tgt_tensor, text, 'image'
|
| 1124 |
+
|
| 1125 |
+
def __len__(self):
|
| 1126 |
+
return self.length
|
| 1127 |
+
|
| 1128 |
+
def __getitem__(self, idx):
|
| 1129 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 1130 |
+
data_type = data_info.get('type', 'video')
|
| 1131 |
+
while True:
|
| 1132 |
+
sample = {}
|
| 1133 |
+
try:
|
| 1134 |
+
data_info_local = self.dataset[idx % len(self.dataset)]
|
| 1135 |
+
data_type_local = data_info_local.get('type', 'video')
|
| 1136 |
+
if data_type_local != data_type:
|
| 1137 |
+
raise ValueError("data_type_local != data_type")
|
| 1138 |
+
|
| 1139 |
+
src_vals, tgt_vals, name, data_type = self.get_batch(idx)
|
| 1140 |
+
if data_type == 'video':
|
| 1141 |
+
sample["pixel_values_src_video"] = src_vals
|
| 1142 |
+
sample["pixel_values_tgt_video"] = tgt_vals
|
| 1143 |
+
else:
|
| 1144 |
+
sample["pixel_values_src_image"] = src_vals
|
| 1145 |
+
sample["pixel_values_tgt_image"] = tgt_vals
|
| 1146 |
+
sample["text"] = name
|
| 1147 |
+
sample["data_type"] = data_type
|
| 1148 |
+
sample["idx"] = idx
|
| 1149 |
+
|
| 1150 |
+
if len(sample) > 0:
|
| 1151 |
+
break
|
| 1152 |
+
except Exception as e:
|
| 1153 |
+
print(e, self.dataset[idx % len(self.dataset)])
|
| 1154 |
+
idx = random.randint(0, self.length-1)
|
| 1155 |
+
|
| 1156 |
+
# Inpaint not applied here to avoid ambiguity across src/tgt branches
|
| 1157 |
+
|
| 1158 |
+
return sample
|
| 1159 |
+
|
| 1160 |
+
|
| 1161 |
+
class ImageVideoCoTDataset(Dataset):
|
| 1162 |
+
"""
|
| 1163 |
+
Dataset for Chain-of-Thought (CoT) style image/video editing.
|
| 1164 |
+
- For videos: loads original_video, grounded_video, and edited_video (3-part)
|
| 1165 |
+
- For images: loads original_image and edited_image (2-part, same as ImageVideoEditDataset)
|
| 1166 |
+
"""
|
| 1167 |
+
def __init__(
|
| 1168 |
+
self,
|
| 1169 |
+
ann_path,
|
| 1170 |
+
data_root=None,
|
| 1171 |
+
video_sample_size=512,
|
| 1172 |
+
video_sample_stride=1,
|
| 1173 |
+
source_frames=33,
|
| 1174 |
+
reasoning_frames=4,
|
| 1175 |
+
target_frames=33,
|
| 1176 |
+
text_drop_ratio=0.1,
|
| 1177 |
+
enable_bucket=False,
|
| 1178 |
+
enable_inpaint=False,
|
| 1179 |
+
video_length_drop_start=0.0,
|
| 1180 |
+
video_length_drop_end=1.0,
|
| 1181 |
+
instruction_template="A video sequence showing three parts: first the original scene, then grounded {ground_instruction}, and finally the same scene but {edit_instruction}",
|
| 1182 |
+
enable_gradual_ground=False,
|
| 1183 |
+
enable_gray_red_mask=False,
|
| 1184 |
+
enable_gray_black_background=False,
|
| 1185 |
+
enable_gray_alpha_overlay=False,
|
| 1186 |
+
gray_alpha=0.5,
|
| 1187 |
+
gray_intensity_range=(96, 160),
|
| 1188 |
+
gray_tolerance=12,
|
| 1189 |
+
):
|
| 1190 |
+
dataset = json.load(open(ann_path))
|
| 1191 |
+
if isinstance(dataset, dict):
|
| 1192 |
+
new_dataset = []
|
| 1193 |
+
for _, info in dataset.items():
|
| 1194 |
+
data_type = info.get("type", "video")
|
| 1195 |
+
entry = dict(info) # Copy original entry
|
| 1196 |
+
|
| 1197 |
+
# Standardize text field name and handle None/empty values
|
| 1198 |
+
if "edit_instruction" in entry:
|
| 1199 |
+
entry["text"] = entry["edit_instruction"]
|
| 1200 |
+
elif "instruction" in entry:
|
| 1201 |
+
entry["text"] = entry["instruction"]
|
| 1202 |
+
elif "text" not in entry:
|
| 1203 |
+
entry["text"] = ""
|
| 1204 |
+
|
| 1205 |
+
# Ensure text is not None
|
| 1206 |
+
if entry["text"] is None:
|
| 1207 |
+
entry["text"] = ""
|
| 1208 |
+
|
| 1209 |
+
# Add file_path for bucket sampler compatibility
|
| 1210 |
+
if data_type == "video":
|
| 1211 |
+
entry["file_path"] = entry.get("original_video", "")
|
| 1212 |
+
else: # image
|
| 1213 |
+
entry["file_path"] = entry.get("original_image", "")
|
| 1214 |
+
|
| 1215 |
+
new_dataset.append(entry)
|
| 1216 |
+
dataset = new_dataset
|
| 1217 |
+
|
| 1218 |
+
self.data_root = data_root
|
| 1219 |
+
self.dataset = dataset
|
| 1220 |
+
self.length = len(self.dataset)
|
| 1221 |
+
|
| 1222 |
+
# sampling params
|
| 1223 |
+
self.video_sample_stride = video_sample_stride
|
| 1224 |
+
self.source_frames = source_frames
|
| 1225 |
+
self.reasoning_frames = reasoning_frames
|
| 1226 |
+
self.target_frames = target_frames
|
| 1227 |
+
self.video_length_drop_start = video_length_drop_start
|
| 1228 |
+
self.video_length_drop_end = video_length_drop_end
|
| 1229 |
+
|
| 1230 |
+
# transforms params
|
| 1231 |
+
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
|
| 1232 |
+
self.video_transforms = transforms.Compose(
|
| 1233 |
+
[
|
| 1234 |
+
transforms.Resize(min(self.video_sample_size)),
|
| 1235 |
+
transforms.CenterCrop(self.video_sample_size),
|
| 1236 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 1237 |
+
]
|
| 1238 |
+
)
|
| 1239 |
+
|
| 1240 |
+
# Image transforms for non-bucket mode
|
| 1241 |
+
self.image_transforms = transforms.Compose([
|
| 1242 |
+
transforms.Resize(min(self.video_sample_size)),
|
| 1243 |
+
transforms.CenterCrop(self.video_sample_size),
|
| 1244 |
+
transforms.ToTensor(),
|
| 1245 |
+
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
|
| 1246 |
+
])
|
| 1247 |
+
|
| 1248 |
+
self.instruction_template = instruction_template
|
| 1249 |
+
self.enable_bucket = enable_bucket
|
| 1250 |
+
self.text_drop_ratio = text_drop_ratio
|
| 1251 |
+
self.enable_inpaint = enable_inpaint
|
| 1252 |
+
self.enable_gradual_ground = enable_gradual_ground
|
| 1253 |
+
# only one visualization mode at a time
|
| 1254 |
+
enabled_modes = int(bool(enable_gray_red_mask)) + int(bool(enable_gray_black_background)) + int(bool(enable_gray_alpha_overlay))
|
| 1255 |
+
if enabled_modes > 1:
|
| 1256 |
+
raise ValueError("enable_gray_red_mask, enable_gray_black_background and enable_gray_alpha_overlay cannot be enabled simultaneously.")
|
| 1257 |
+
self.enable_gray_red_mask = enable_gray_red_mask
|
| 1258 |
+
self.enable_gray_black_background = enable_gray_black_background
|
| 1259 |
+
self.enable_gray_alpha_overlay = enable_gray_alpha_overlay
|
| 1260 |
+
self.gray_alpha = float(gray_alpha)
|
| 1261 |
+
if not (0.0 <= self.gray_alpha <= 1.0):
|
| 1262 |
+
raise ValueError("gray_alpha must be in [0,1].")
|
| 1263 |
+
if not isinstance(gray_intensity_range, (list, tuple)) or len(gray_intensity_range) != 2:
|
| 1264 |
+
raise ValueError("gray_intensity_range must contain exactly two values (min and max intensity).")
|
| 1265 |
+
self.gray_intensity_range = (int(gray_intensity_range[0]), int(gray_intensity_range[1]))
|
| 1266 |
+
if self.gray_intensity_range[0] > self.gray_intensity_range[1]:
|
| 1267 |
+
raise ValueError("gray_intensity_range min value cannot be greater than max value.")
|
| 1268 |
+
self.gray_tolerance = int(gray_tolerance)
|
| 1269 |
+
|
| 1270 |
+
# For pre-resize like ImageVideoDataset
|
| 1271 |
+
self.larger_side_of_image_and_video = min(self.video_sample_size)
|
| 1272 |
+
|
| 1273 |
+
def _resize_and_center_crop_batch(self, frames_np, target_h, target_w):
|
| 1274 |
+
resized = []
|
| 1275 |
+
for i in range(frames_np.shape[0]):
|
| 1276 |
+
frame = frames_np[i]
|
| 1277 |
+
h, w = frame.shape[0], frame.shape[1]
|
| 1278 |
+
scale = max(target_h / h, target_w / w)
|
| 1279 |
+
new_h = int(round(h * scale))
|
| 1280 |
+
new_w = int(round(w * scale))
|
| 1281 |
+
frame_resized = cv2.resize(frame, (new_w, new_h))
|
| 1282 |
+
y0 = max((new_h - target_h) // 2, 0)
|
| 1283 |
+
x0 = max((new_w - target_w) // 2, 0)
|
| 1284 |
+
frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
|
| 1285 |
+
resized.append(frame_cropped)
|
| 1286 |
+
return np.stack(resized, axis=0)
|
| 1287 |
+
|
| 1288 |
+
def _resize_and_center_crop_image(self, image_np, target_h, target_w):
|
| 1289 |
+
h, w = image_np.shape[0], image_np.shape[1]
|
| 1290 |
+
scale = max(target_h / h, target_w / w)
|
| 1291 |
+
new_h = int(round(h * scale))
|
| 1292 |
+
new_w = int(round(w * scale))
|
| 1293 |
+
image_resized = cv2.resize(image_np, (new_w, new_h))
|
| 1294 |
+
y0 = max((new_h - target_h) // 2, 0)
|
| 1295 |
+
x0 = max((new_w - target_w) // 2, 0)
|
| 1296 |
+
image_cropped = image_resized[y0:y0 + target_h, x0:x0 + target_w]
|
| 1297 |
+
return image_cropped
|
| 1298 |
+
|
| 1299 |
+
def _derive_ground_instruction(self, edit_instruction_text: str) -> str:
|
| 1300 |
+
"""Derive grounded object phrase from instruction using shared rules."""
|
| 1301 |
+
return derive_ground_object_from_instruction(edit_instruction_text)
|
| 1302 |
+
|
| 1303 |
+
def _ensure_same_size_pair(self, img_a: np.ndarray, img_b: np.ndarray) -> tuple:
|
| 1304 |
+
"""Resize img_b to img_a's size if needed to enable per-pixel interpolation."""
|
| 1305 |
+
ha, wa = img_a.shape[:2]
|
| 1306 |
+
hb, wb = img_b.shape[:2]
|
| 1307 |
+
if (ha, wa) == (hb, wb):
|
| 1308 |
+
return img_a, img_b
|
| 1309 |
+
resized_b = cv2.resize(img_b, (wa, ha), interpolation=cv2.INTER_LINEAR)
|
| 1310 |
+
return img_a, resized_b
|
| 1311 |
+
|
| 1312 |
+
def _interpolate_ground_frames(self, ground_first: np.ndarray, target_first: np.ndarray,
|
| 1313 |
+
total_steps: int = 16,
|
| 1314 |
+
pick_indices: tuple = (0, 4, 8, 12)) -> np.ndarray:
|
| 1315 |
+
"""
|
| 1316 |
+
Create grounding frames by linearly interpolating between the first frame of
|
| 1317 |
+
the grounding video and the first frame of the edited video, then picking
|
| 1318 |
+
specific indices.
|
| 1319 |
+
Returns array of shape (len(pick_indices), H, W, 3) in uint8.
|
| 1320 |
+
"""
|
| 1321 |
+
a_np, b_np = self._ensure_same_size_pair(ground_first, target_first)
|
| 1322 |
+
|
| 1323 |
+
a_t = torch.from_numpy(a_np).float() / 255.0 # H, W, C
|
| 1324 |
+
b_t = torch.from_numpy(b_np).float() / 255.0 # H, W, C
|
| 1325 |
+
|
| 1326 |
+
a_t = a_t.permute(2, 0, 1).contiguous() # C, H, W
|
| 1327 |
+
b_t = b_t.permute(2, 0, 1).contiguous() # C, H, W
|
| 1328 |
+
|
| 1329 |
+
c, h, w = a_t.shape
|
| 1330 |
+
pair = torch.stack([a_t, b_t], dim=0) # 2, C, H, W
|
| 1331 |
+
pair_chw_t = pair.permute(1, 2, 3, 0).contiguous() # C, H, W, 2
|
| 1332 |
+
seq = pair_chw_t.view(1, c * h * w, 2) # 1, (C*H*W), 2
|
| 1333 |
+
with torch.no_grad():
|
| 1334 |
+
seq_interp = F.interpolate(seq, size=int(total_steps), mode="linear", align_corners=True)
|
| 1335 |
+
seq_interp = seq_interp.view(c, h, w, int(total_steps)).permute(3, 0, 1, 2).contiguous() # T, C, H, W
|
| 1336 |
+
|
| 1337 |
+
out_frames = []
|
| 1338 |
+
t_steps = int(total_steps)
|
| 1339 |
+
for idx in pick_indices:
|
| 1340 |
+
safe_idx = max(0, min(int(idx), t_steps - 1))
|
| 1341 |
+
img = (seq_interp[safe_idx].clamp(0.0, 1.0) * 255.0).byte().permute(1, 2, 0).cpu().numpy()
|
| 1342 |
+
out_frames.append(img)
|
| 1343 |
+
return np.stack(out_frames, axis=0)
|
| 1344 |
+
|
| 1345 |
+
def _build_gray_mask(self, frame: np.ndarray) -> np.ndarray:
|
| 1346 |
+
"""Detect gray regions in a frame using intensity range and tolerance."""
|
| 1347 |
+
frame_float = frame.astype(np.float32)
|
| 1348 |
+
if frame_float.max() <= 1.0:
|
| 1349 |
+
frame_float = frame_float * 255.0
|
| 1350 |
+
channel_max = frame_float.max(axis=2)
|
| 1351 |
+
channel_min = frame_float.min(axis=2)
|
| 1352 |
+
min_intensity, max_intensity = self.gray_intensity_range
|
| 1353 |
+
tone_flatness = channel_max - channel_min
|
| 1354 |
+
mask = tone_flatness <= float(self.gray_tolerance)
|
| 1355 |
+
mask &= channel_max >= float(min_intensity)
|
| 1356 |
+
mask &= channel_max <= float(max_intensity)
|
| 1357 |
+
return mask
|
| 1358 |
+
|
| 1359 |
+
def _apply_gray_region_effect(self, frames_np: np.ndarray, mode: str) -> np.ndarray:
|
| 1360 |
+
"""Apply requested effect on detected gray regions for a batch of frames."""
|
| 1361 |
+
processed_frames = []
|
| 1362 |
+
for frame in frames_np:
|
| 1363 |
+
mask = self._build_gray_mask(frame)
|
| 1364 |
+
if not np.any(mask):
|
| 1365 |
+
processed_frames.append(frame)
|
| 1366 |
+
continue
|
| 1367 |
+
frame_out = frame.copy()
|
| 1368 |
+
if np.issubdtype(frame_out.dtype, np.floating) and frame_out.max() <= 1.0:
|
| 1369 |
+
red_value = np.array([1.0, 0.0, 0.0], dtype=frame_out.dtype)
|
| 1370 |
+
else:
|
| 1371 |
+
red_value = np.array([255, 0, 0], dtype=frame_out.dtype)
|
| 1372 |
+
if mode == "red":
|
| 1373 |
+
frame_out[mask] = red_value
|
| 1374 |
+
else:
|
| 1375 |
+
frame_out[:] = 0
|
| 1376 |
+
frame_out[mask] = frame[mask]
|
| 1377 |
+
processed_frames.append(frame_out)
|
| 1378 |
+
return np.stack(processed_frames, axis=0)
|
| 1379 |
+
|
| 1380 |
+
def _apply_gray_overlay_from_reference(self, src_frames_np: np.ndarray, ref_frames_np: np.ndarray,
|
| 1381 |
+
alpha: float = 0.5, gray_value: float = 0.5, num_frames: int = 4) -> np.ndarray:
|
| 1382 |
+
"""
|
| 1383 |
+
Detect gray regions on ref frames, and overlay gray with alpha onto the
|
| 1384 |
+
first `num_frames` frames of src frames at the same positions.
|
| 1385 |
+
"""
|
| 1386 |
+
n = min(int(num_frames), int(src_frames_np.shape[0]), int(ref_frames_np.shape[0]))
|
| 1387 |
+
if n <= 0:
|
| 1388 |
+
return src_frames_np
|
| 1389 |
+
out = src_frames_np.copy()
|
| 1390 |
+
a = float(alpha)
|
| 1391 |
+
a = 0.0 if a < 0.0 else (1.0 if a > 1.0 else a)
|
| 1392 |
+
gv = float(gray_value)
|
| 1393 |
+
gv = 0.0 if gv < 0.0 else (1.0 if gv > 1.0 else gv)
|
| 1394 |
+
for i in range(n):
|
| 1395 |
+
mask = self._build_gray_mask(ref_frames_np[i])
|
| 1396 |
+
if not np.any(mask):
|
| 1397 |
+
continue
|
| 1398 |
+
src = out[i]
|
| 1399 |
+
# normalize to 0..1 float
|
| 1400 |
+
if np.issubdtype(src.dtype, np.floating):
|
| 1401 |
+
f = src.astype(np.float32)
|
| 1402 |
+
if f.max() > 1.0:
|
| 1403 |
+
f = np.clip(f / 255.0, 0.0, 1.0)
|
| 1404 |
+
back_to_uint8 = False
|
| 1405 |
+
else:
|
| 1406 |
+
f = src.astype(np.float32) / 255.0
|
| 1407 |
+
back_to_uint8 = True
|
| 1408 |
+
gray_color = np.array([gv, gv, gv], dtype=np.float32)
|
| 1409 |
+
# boolean mask is (H,W); f[mask] -> (K,3), broadcast with gray_color (3,)
|
| 1410 |
+
f[mask] = (1.0 - a) * f[mask] + a * gray_color
|
| 1411 |
+
if back_to_uint8:
|
| 1412 |
+
out[i] = (f * 255.0).clip(0, 255).astype(src.dtype)
|
| 1413 |
+
else:
|
| 1414 |
+
out[i] = f.astype(src.dtype)
|
| 1415 |
+
return out
|
| 1416 |
+
|
| 1417 |
+
def get_batch(self, idx):
|
| 1418 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 1419 |
+
data_type = data_info.get('type', 'video')
|
| 1420 |
+
|
| 1421 |
+
# Handle None or empty instruction with safety fallback
|
| 1422 |
+
raw_text = data_info.get('text', '')
|
| 1423 |
+
if raw_text is None or (isinstance(raw_text, str) and not raw_text.strip()):
|
| 1424 |
+
raw_text = "the content has been modified"
|
| 1425 |
+
|
| 1426 |
+
if data_type == 'video':
|
| 1427 |
+
# Video triplet branch: original + grounded + edited
|
| 1428 |
+
src_rel = data_info['original_video']
|
| 1429 |
+
# Support both 'grounded_video' and 'ground_video' keys
|
| 1430 |
+
ground_rel = data_info.get('grounded_video', data_info.get('ground_video'))
|
| 1431 |
+
tgt_rel = data_info['edited_video']
|
| 1432 |
+
|
| 1433 |
+
if self.data_root is not None:
|
| 1434 |
+
src_path = os.path.join(self.data_root, src_rel)
|
| 1435 |
+
ground_path = os.path.join(self.data_root, ground_rel)
|
| 1436 |
+
tgt_path = os.path.join(self.data_root, tgt_rel)
|
| 1437 |
+
else:
|
| 1438 |
+
src_path = src_rel
|
| 1439 |
+
ground_path = ground_rel
|
| 1440 |
+
tgt_path = tgt_rel
|
| 1441 |
+
|
| 1442 |
+
# Force use CPU decoder to read all frames
|
| 1443 |
+
from decord import cpu
|
| 1444 |
+
with VideoReader_contextmanager(src_path, num_threads=2, ctx=cpu(0)) as src_reader, \
|
| 1445 |
+
VideoReader_contextmanager(ground_path, num_threads=2, ctx=cpu(0)) as ground_reader, \
|
| 1446 |
+
VideoReader_contextmanager(tgt_path, num_threads=2, ctx=cpu(0)) as tgt_reader:
|
| 1447 |
+
|
| 1448 |
+
# Get video lengths
|
| 1449 |
+
src_length = len(src_reader)
|
| 1450 |
+
ground_length = len(ground_reader)
|
| 1451 |
+
tgt_length = len(tgt_reader)
|
| 1452 |
+
|
| 1453 |
+
# Check if video has enough frames
|
| 1454 |
+
if src_length < self.source_frames:
|
| 1455 |
+
raise ValueError(f"Source video only has {src_length} frames, but requested {self.source_frames}")
|
| 1456 |
+
if tgt_length < self.target_frames:
|
| 1457 |
+
raise ValueError(f"Target video only has {tgt_length} frames, but requested {self.target_frames}")
|
| 1458 |
+
|
| 1459 |
+
# Unified sampling strategy: start from beginning
|
| 1460 |
+
start_idx = 0
|
| 1461 |
+
|
| 1462 |
+
# Sample source frames
|
| 1463 |
+
src_indices = np.linspace(
|
| 1464 |
+
start_idx,
|
| 1465 |
+
min(start_idx + (self.source_frames - 1) * self.video_sample_stride, src_length - 1),
|
| 1466 |
+
self.source_frames,
|
| 1467 |
+
dtype=int
|
| 1468 |
+
)
|
| 1469 |
+
|
| 1470 |
+
# Sample target frames
|
| 1471 |
+
tgt_indices = np.linspace(
|
| 1472 |
+
start_idx,
|
| 1473 |
+
min(start_idx + (self.target_frames - 1) * self.video_sample_stride, tgt_length - 1),
|
| 1474 |
+
self.target_frames,
|
| 1475 |
+
dtype=int
|
| 1476 |
+
)
|
| 1477 |
+
|
| 1478 |
+
# Read batches with timeout
|
| 1479 |
+
try:
|
| 1480 |
+
src_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(src_reader, src_indices))
|
| 1481 |
+
tgt_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(tgt_reader, tgt_indices))
|
| 1482 |
+
|
| 1483 |
+
if self.enable_gradual_ground:
|
| 1484 |
+
# Interpolate between first frame of grounded and edited videos
|
| 1485 |
+
ground_first = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(ground_reader, [0]))
|
| 1486 |
+
# Use the first decoded edited frame if available to avoid double decode
|
| 1487 |
+
tgt_first_frame = tgt_frames[0]
|
| 1488 |
+
# steps: 0..15, pick 0,3,6,9,12 -> 5 grounding frames
|
| 1489 |
+
ground_frames = self._interpolate_ground_frames(
|
| 1490 |
+
ground_first=ground_first[0],
|
| 1491 |
+
target_first=tgt_first_frame,
|
| 1492 |
+
total_steps=16,
|
| 1493 |
+
pick_indices=(0, 3, 6, 9, 12),
|
| 1494 |
+
)
|
| 1495 |
+
else:
|
| 1496 |
+
# # Original behavior: sample grounding frames evenly by stride
|
| 1497 |
+
# ground_indices = np.linspace(
|
| 1498 |
+
# start_idx,
|
| 1499 |
+
# min(start_idx + (self.reasoning_frames - 1) * self.video_sample_stride, ground_length - 1),
|
| 1500 |
+
# self.reasoning_frames,
|
| 1501 |
+
# dtype=int
|
| 1502 |
+
# )
|
| 1503 |
+
|
| 1504 |
+
#==============================================================
|
| 1505 |
+
# New behavior: ground_indices are the first 'reasoning_frames' from src_indices
|
| 1506 |
+
ground_indices = src_indices[:self.reasoning_frames]
|
| 1507 |
+
|
| 1508 |
+
# --- 增加这个重要的安全检查 ---
|
| 1509 |
+
# 确保我们想采样的最后一帧 (ground_indices[-1])
|
| 1510 |
+
# 没有超出 ground_video 的总长度 (ground_length)
|
| 1511 |
+
if len(ground_indices) > 0 and ground_indices[-1] >= ground_length:
|
| 1512 |
+
raise ValueError(
|
| 1513 |
+
f"Data inconsistency error: Ground video has only {ground_length} frames, "
|
| 1514 |
+
f"but the source-based sampling (stride={self.video_sample_stride}) "
|
| 1515 |
+
f"requires reading up to frame {ground_indices[-1]}. "
|
| 1516 |
+
f"File: {ground_path}"
|
| 1517 |
+
)
|
| 1518 |
+
ground_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(ground_reader, ground_indices))
|
| 1519 |
+
except FunctionTimedOut:
|
| 1520 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 1521 |
+
except Exception as e:
|
| 1522 |
+
raise ValueError(f"Failed to extract frames from triplet. Error is {e}.")
|
| 1523 |
+
|
| 1524 |
+
# Align HxW among source, ground, and target to enable concat
|
| 1525 |
+
sh, sw = src_frames.shape[1], src_frames.shape[2]
|
| 1526 |
+
gh, gw = ground_frames.shape[1], ground_frames.shape[2]
|
| 1527 |
+
th, tw = tgt_frames.shape[1], tgt_frames.shape[2]
|
| 1528 |
+
target_h = min(sh, gh, th)
|
| 1529 |
+
target_w = min(sw, gw, tw)
|
| 1530 |
+
|
| 1531 |
+
if (sh != target_h or sw != target_w):
|
| 1532 |
+
src_frames = self._resize_and_center_crop_batch(src_frames, target_h, target_w)
|
| 1533 |
+
if (gh != target_h or gw != target_w):
|
| 1534 |
+
ground_frames = self._resize_and_center_crop_batch(ground_frames, target_h, target_w)
|
| 1535 |
+
if (th != target_h or tw != target_w):
|
| 1536 |
+
tgt_frames = self._resize_and_center_crop_batch(tgt_frames, target_h, target_w)
|
| 1537 |
+
|
| 1538 |
+
if self.enable_gray_red_mask or self.enable_gray_black_background:
|
| 1539 |
+
effect_mode = "red" if self.enable_gray_red_mask else "black"
|
| 1540 |
+
ground_frames = self._apply_gray_region_effect(ground_frames, effect_mode)
|
| 1541 |
+
elif self.enable_gray_alpha_overlay:
|
| 1542 |
+
# Use gray regions detected on grounding frames to overlay 50% gray on the
|
| 1543 |
+
# first 4 frames of the original video.
|
| 1544 |
+
ground_frames = self._apply_gray_overlay_from_reference(
|
| 1545 |
+
src_frames, ground_frames, alpha=self.gray_alpha, gray_value=0.5, num_frames=4
|
| 1546 |
+
)
|
| 1547 |
+
|
| 1548 |
+
if not self.enable_bucket:
|
| 1549 |
+
src_tensor = torch.from_numpy(src_frames).permute(0, 3, 1, 2).contiguous() / 255.
|
| 1550 |
+
ground_tensor = torch.from_numpy(ground_frames).permute(0, 3, 1, 2).contiguous() / 255.
|
| 1551 |
+
tgt_tensor = torch.from_numpy(tgt_frames).permute(0, 3, 1, 2).contiguous() / 255.
|
| 1552 |
+
|
| 1553 |
+
src_tensor = self.video_transforms(src_tensor)
|
| 1554 |
+
ground_tensor = self.video_transforms(ground_tensor)
|
| 1555 |
+
tgt_tensor = self.video_transforms(tgt_tensor)
|
| 1556 |
+
else:
|
| 1557 |
+
src_tensor = src_frames
|
| 1558 |
+
ground_tensor = ground_frames
|
| 1559 |
+
tgt_tensor = tgt_frames
|
| 1560 |
+
# Prepare text with template
|
| 1561 |
+
ground_instr = self._derive_ground_instruction(raw_text)
|
| 1562 |
+
if self.instruction_template and "{edit_instruction}" in self.instruction_template:
|
| 1563 |
+
text = self.instruction_template.format(
|
| 1564 |
+
edit_instruction=raw_text,
|
| 1565 |
+
ground_instruction=ground_instr
|
| 1566 |
+
)
|
| 1567 |
+
else:
|
| 1568 |
+
text = raw_text
|
| 1569 |
+
|
| 1570 |
+
# Random text drop
|
| 1571 |
+
if random.random() < self.text_drop_ratio:
|
| 1572 |
+
text = ''
|
| 1573 |
+
|
| 1574 |
+
return src_tensor, ground_tensor, tgt_tensor, text, 'video'
|
| 1575 |
+
|
| 1576 |
+
else:
|
| 1577 |
+
# Image pair branch (simple like ImageVideoEditDataset)
|
| 1578 |
+
src_img_rel = data_info.get('original_image')
|
| 1579 |
+
tgt_img_rel = data_info.get('edited_image')
|
| 1580 |
+
if src_img_rel is None or tgt_img_rel is None:
|
| 1581 |
+
raise ValueError('Missing original_image/edited_image for image sample')
|
| 1582 |
+
|
| 1583 |
+
if self.data_root is not None:
|
| 1584 |
+
src_img_path = os.path.join(self.data_root, src_img_rel)
|
| 1585 |
+
tgt_img_path = os.path.join(self.data_root, tgt_img_rel)
|
| 1586 |
+
else:
|
| 1587 |
+
src_img_path = src_img_rel
|
| 1588 |
+
tgt_img_path = tgt_img_rel
|
| 1589 |
+
|
| 1590 |
+
src_img = Image.open(src_img_path).convert('RGB')
|
| 1591 |
+
tgt_img = Image.open(tgt_img_path).convert('RGB')
|
| 1592 |
+
|
| 1593 |
+
if not self.enable_bucket:
|
| 1594 |
+
# Apply transforms and add frame dimension
|
| 1595 |
+
src_tensor = self.image_transforms(src_img).unsqueeze(0) # (1, C, H, W)
|
| 1596 |
+
tgt_tensor = self.image_transforms(tgt_img).unsqueeze(0) # (1, C, H, W)
|
| 1597 |
+
else:
|
| 1598 |
+
# For bucket mode, keep as numpy and add frame dimension
|
| 1599 |
+
src_tensor = np.expand_dims(np.array(src_img), axis=0) # (1, H, W, C)
|
| 1600 |
+
tgt_tensor = np.expand_dims(np.array(tgt_img), axis=0) # (1, H, W, C)
|
| 1601 |
+
|
| 1602 |
+
# Apply instruction template if available
|
| 1603 |
+
if self.instruction_template and "{edit_instruction}" in self.instruction_template:
|
| 1604 |
+
text = self.instruction_template.format(edit_instruction=raw_text, ground_instruction="")
|
| 1605 |
+
else:
|
| 1606 |
+
text = raw_text
|
| 1607 |
+
|
| 1608 |
+
if random.random() < self.text_drop_ratio:
|
| 1609 |
+
text = ''
|
| 1610 |
+
|
| 1611 |
+
# For images, ground_tensor is None
|
| 1612 |
+
return src_tensor, None, tgt_tensor, text, 'image'
|
| 1613 |
+
|
| 1614 |
+
def __len__(self):
|
| 1615 |
+
return self.length
|
| 1616 |
+
|
| 1617 |
+
def __getitem__(self, idx):
|
| 1618 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 1619 |
+
data_type = data_info.get('type', 'video')
|
| 1620 |
+
while True:
|
| 1621 |
+
sample = {}
|
| 1622 |
+
try:
|
| 1623 |
+
data_info_local = self.dataset[idx % len(self.dataset)]
|
| 1624 |
+
data_type_local = data_info_local.get('type', 'video')
|
| 1625 |
+
if data_type_local != data_type:
|
| 1626 |
+
raise ValueError("data_type_local != data_type")
|
| 1627 |
+
|
| 1628 |
+
result = self.get_batch(idx)
|
| 1629 |
+
|
| 1630 |
+
if data_type == 'video':
|
| 1631 |
+
src_vals, ground_vals, tgt_vals, name, data_type = result
|
| 1632 |
+
sample["pixel_values_src_video"] = src_vals
|
| 1633 |
+
sample["pixel_values_ground_video"] = ground_vals
|
| 1634 |
+
sample["pixel_values_tgt_video"] = tgt_vals
|
| 1635 |
+
else:
|
| 1636 |
+
src_vals, _, tgt_vals, name, data_type = result
|
| 1637 |
+
sample["pixel_values_src_image"] = src_vals
|
| 1638 |
+
sample["pixel_values_tgt_image"] = tgt_vals
|
| 1639 |
+
|
| 1640 |
+
sample["text"] = name
|
| 1641 |
+
sample["data_type"] = data_type
|
| 1642 |
+
sample["idx"] = idx
|
| 1643 |
+
|
| 1644 |
+
if len(sample) > 0:
|
| 1645 |
+
break
|
| 1646 |
+
except Exception as e:
|
| 1647 |
+
print(e, self.dataset[idx % len(self.dataset)])
|
| 1648 |
+
idx = random.randint(0, self.length-1)
|
| 1649 |
+
|
| 1650 |
+
return sample
|
| 1651 |
+
|
| 1652 |
+
def padding_image(images, new_width, new_height):
|
| 1653 |
+
new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
|
| 1654 |
+
|
| 1655 |
+
aspect_ratio = images.width / images.height
|
| 1656 |
+
if new_width / new_height > 1:
|
| 1657 |
+
if aspect_ratio > new_width / new_height:
|
| 1658 |
+
new_img_width = new_width
|
| 1659 |
+
new_img_height = int(new_img_width / aspect_ratio)
|
| 1660 |
+
else:
|
| 1661 |
+
new_img_height = new_height
|
| 1662 |
+
new_img_width = int(new_img_height * aspect_ratio)
|
| 1663 |
+
else:
|
| 1664 |
+
if aspect_ratio > new_width / new_height:
|
| 1665 |
+
new_img_width = new_width
|
| 1666 |
+
new_img_height = int(new_img_width / aspect_ratio)
|
| 1667 |
+
else:
|
| 1668 |
+
new_img_height = new_height
|
| 1669 |
+
new_img_width = int(new_img_height * aspect_ratio)
|
| 1670 |
+
|
| 1671 |
+
resized_img = images.resize((new_img_width, new_img_height))
|
| 1672 |
+
|
| 1673 |
+
paste_x = (new_width - new_img_width) // 2
|
| 1674 |
+
paste_y = (new_height - new_img_height) // 2
|
| 1675 |
+
|
| 1676 |
+
new_image.paste(resized_img, (paste_x, paste_y))
|
| 1677 |
+
|
| 1678 |
+
return new_image
|
| 1679 |
+
|
| 1680 |
+
class ImageVideoControlDataset(Dataset):
|
| 1681 |
+
def __init__(
|
| 1682 |
+
self,
|
| 1683 |
+
ann_path, data_root=None,
|
| 1684 |
+
video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
|
| 1685 |
+
image_sample_size=512,
|
| 1686 |
+
video_repeat=0,
|
| 1687 |
+
text_drop_ratio=0.1,
|
| 1688 |
+
enable_bucket=False,
|
| 1689 |
+
video_length_drop_start=0.1,
|
| 1690 |
+
video_length_drop_end=0.9,
|
| 1691 |
+
enable_inpaint=False,
|
| 1692 |
+
enable_camera_info=False,
|
| 1693 |
+
):
|
| 1694 |
+
# Loading annotations from files
|
| 1695 |
+
if ann_path.endswith('.csv'):
|
| 1696 |
+
with open(ann_path, 'r') as csvfile:
|
| 1697 |
+
dataset = list(csv.DictReader(csvfile))
|
| 1698 |
+
elif ann_path.endswith('.json'):
|
| 1699 |
+
dataset = json.load(open(ann_path))
|
| 1700 |
+
|
| 1701 |
+
self.data_root = data_root
|
| 1702 |
+
|
| 1703 |
+
# It's used to balance num of images and videos.
|
| 1704 |
+
if video_repeat > 0:
|
| 1705 |
+
self.dataset = []
|
| 1706 |
+
for data in dataset:
|
| 1707 |
+
if data.get('type', 'image') != 'video':
|
| 1708 |
+
self.dataset.append(data)
|
| 1709 |
+
|
| 1710 |
+
for _ in range(video_repeat):
|
| 1711 |
+
for data in dataset:
|
| 1712 |
+
if data.get('type', 'image') == 'video':
|
| 1713 |
+
self.dataset.append(data)
|
| 1714 |
+
else:
|
| 1715 |
+
self.dataset = dataset
|
| 1716 |
+
del dataset
|
| 1717 |
+
|
| 1718 |
+
self.length = len(self.dataset)
|
| 1719 |
+
print(f"data scale: {self.length}")
|
| 1720 |
+
# TODO: enable bucket training
|
| 1721 |
+
self.enable_bucket = enable_bucket
|
| 1722 |
+
self.text_drop_ratio = text_drop_ratio
|
| 1723 |
+
self.enable_inpaint = enable_inpaint
|
| 1724 |
+
self.enable_camera_info = enable_camera_info
|
| 1725 |
+
|
| 1726 |
+
self.video_length_drop_start = video_length_drop_start
|
| 1727 |
+
self.video_length_drop_end = video_length_drop_end
|
| 1728 |
+
|
| 1729 |
+
# Video params
|
| 1730 |
+
self.video_sample_stride = video_sample_stride
|
| 1731 |
+
self.video_sample_n_frames = video_sample_n_frames
|
| 1732 |
+
self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
|
| 1733 |
+
self.video_transforms = transforms.Compose(
|
| 1734 |
+
[
|
| 1735 |
+
transforms.Resize(min(self.video_sample_size)),
|
| 1736 |
+
transforms.CenterCrop(self.video_sample_size),
|
| 1737 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 1738 |
+
]
|
| 1739 |
+
)
|
| 1740 |
+
if self.enable_camera_info:
|
| 1741 |
+
self.video_transforms_camera = transforms.Compose(
|
| 1742 |
+
[
|
| 1743 |
+
transforms.Resize(min(self.video_sample_size)),
|
| 1744 |
+
transforms.CenterCrop(self.video_sample_size)
|
| 1745 |
+
]
|
| 1746 |
+
)
|
| 1747 |
+
|
| 1748 |
+
# Image params
|
| 1749 |
+
self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
|
| 1750 |
+
self.image_transforms = transforms.Compose([
|
| 1751 |
+
transforms.Resize(min(self.image_sample_size)),
|
| 1752 |
+
transforms.CenterCrop(self.image_sample_size),
|
| 1753 |
+
transforms.ToTensor(),
|
| 1754 |
+
transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
|
| 1755 |
+
])
|
| 1756 |
+
|
| 1757 |
+
self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
|
| 1758 |
+
|
| 1759 |
+
def get_batch(self, idx):
|
| 1760 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 1761 |
+
video_id, text = data_info['file_path'], data_info['text']
|
| 1762 |
+
|
| 1763 |
+
if data_info.get('type', 'image')=='video':
|
| 1764 |
+
if self.data_root is None:
|
| 1765 |
+
video_dir = video_id
|
| 1766 |
+
else:
|
| 1767 |
+
video_dir = os.path.join(self.data_root, video_id)
|
| 1768 |
+
|
| 1769 |
+
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
| 1770 |
+
min_sample_n_frames = min(
|
| 1771 |
+
self.video_sample_n_frames,
|
| 1772 |
+
int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
|
| 1773 |
+
)
|
| 1774 |
+
if min_sample_n_frames == 0:
|
| 1775 |
+
raise ValueError(f"No Frames in video.")
|
| 1776 |
+
|
| 1777 |
+
video_length = int(self.video_length_drop_end * len(video_reader))
|
| 1778 |
+
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
|
| 1779 |
+
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
|
| 1780 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
|
| 1781 |
+
|
| 1782 |
+
try:
|
| 1783 |
+
sample_args = (video_reader, batch_index)
|
| 1784 |
+
pixel_values = func_timeout(
|
| 1785 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 1786 |
+
)
|
| 1787 |
+
resized_frames = []
|
| 1788 |
+
for i in range(len(pixel_values)):
|
| 1789 |
+
frame = pixel_values[i]
|
| 1790 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 1791 |
+
resized_frames.append(resized_frame)
|
| 1792 |
+
pixel_values = np.array(resized_frames)
|
| 1793 |
+
except FunctionTimedOut:
|
| 1794 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 1795 |
+
except Exception as e:
|
| 1796 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 1797 |
+
|
| 1798 |
+
if not self.enable_bucket:
|
| 1799 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 1800 |
+
pixel_values = pixel_values / 255.
|
| 1801 |
+
del video_reader
|
| 1802 |
+
else:
|
| 1803 |
+
pixel_values = pixel_values
|
| 1804 |
+
|
| 1805 |
+
if not self.enable_bucket:
|
| 1806 |
+
pixel_values = self.video_transforms(pixel_values)
|
| 1807 |
+
|
| 1808 |
+
# Random use no text generation
|
| 1809 |
+
if random.random() < self.text_drop_ratio:
|
| 1810 |
+
text = ''
|
| 1811 |
+
|
| 1812 |
+
control_video_id = data_info['control_file_path']
|
| 1813 |
+
|
| 1814 |
+
if self.data_root is None:
|
| 1815 |
+
control_video_id = control_video_id
|
| 1816 |
+
else:
|
| 1817 |
+
control_video_id = os.path.join(self.data_root, control_video_id)
|
| 1818 |
+
|
| 1819 |
+
if self.enable_camera_info:
|
| 1820 |
+
if control_video_id.lower().endswith('.txt'):
|
| 1821 |
+
if not self.enable_bucket:
|
| 1822 |
+
control_pixel_values = torch.zeros_like(pixel_values)
|
| 1823 |
+
|
| 1824 |
+
control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
|
| 1825 |
+
control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
|
| 1826 |
+
control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)
|
| 1827 |
+
control_camera_values = self.video_transforms_camera(control_camera_values)
|
| 1828 |
+
else:
|
| 1829 |
+
control_pixel_values = np.zeros_like(pixel_values)
|
| 1830 |
+
|
| 1831 |
+
control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
|
| 1832 |
+
control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
|
| 1833 |
+
control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
|
| 1834 |
+
control_camera_values = np.array([control_camera_values[index] for index in batch_index])
|
| 1835 |
+
else:
|
| 1836 |
+
if not self.enable_bucket:
|
| 1837 |
+
control_pixel_values = torch.zeros_like(pixel_values)
|
| 1838 |
+
control_camera_values = None
|
| 1839 |
+
else:
|
| 1840 |
+
control_pixel_values = np.zeros_like(pixel_values)
|
| 1841 |
+
control_camera_values = None
|
| 1842 |
+
else:
|
| 1843 |
+
with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
|
| 1844 |
+
try:
|
| 1845 |
+
sample_args = (control_video_reader, batch_index)
|
| 1846 |
+
control_pixel_values = func_timeout(
|
| 1847 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 1848 |
+
)
|
| 1849 |
+
resized_frames = []
|
| 1850 |
+
for i in range(len(control_pixel_values)):
|
| 1851 |
+
frame = control_pixel_values[i]
|
| 1852 |
+
resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
|
| 1853 |
+
resized_frames.append(resized_frame)
|
| 1854 |
+
control_pixel_values = np.array(resized_frames)
|
| 1855 |
+
except FunctionTimedOut:
|
| 1856 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 1857 |
+
except Exception as e:
|
| 1858 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 1859 |
+
|
| 1860 |
+
if not self.enable_bucket:
|
| 1861 |
+
control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 1862 |
+
control_pixel_values = control_pixel_values / 255.
|
| 1863 |
+
del control_video_reader
|
| 1864 |
+
else:
|
| 1865 |
+
control_pixel_values = control_pixel_values
|
| 1866 |
+
|
| 1867 |
+
if not self.enable_bucket:
|
| 1868 |
+
control_pixel_values = self.video_transforms(control_pixel_values)
|
| 1869 |
+
control_camera_values = None
|
| 1870 |
+
|
| 1871 |
+
return pixel_values, control_pixel_values, control_camera_values, text, "video"
|
| 1872 |
+
else:
|
| 1873 |
+
image_path, text = data_info['file_path'], data_info['text']
|
| 1874 |
+
if self.data_root is not None:
|
| 1875 |
+
image_path = os.path.join(self.data_root, image_path)
|
| 1876 |
+
image = Image.open(image_path).convert('RGB')
|
| 1877 |
+
if not self.enable_bucket:
|
| 1878 |
+
image = self.image_transforms(image).unsqueeze(0)
|
| 1879 |
+
else:
|
| 1880 |
+
image = np.expand_dims(np.array(image), 0)
|
| 1881 |
+
|
| 1882 |
+
if random.random() < self.text_drop_ratio:
|
| 1883 |
+
text = ''
|
| 1884 |
+
|
| 1885 |
+
control_image_id = data_info['control_file_path']
|
| 1886 |
+
|
| 1887 |
+
if self.image_root is None:
|
| 1888 |
+
control_image_id = control_image_id
|
| 1889 |
+
else:
|
| 1890 |
+
control_image_id = os.path.join(self.image_root, control_image_id)
|
| 1891 |
+
|
| 1892 |
+
control_image = Image.open(control_image_id).convert('RGB')
|
| 1893 |
+
if not self.enable_bucket:
|
| 1894 |
+
control_image = self.image_transforms(control_image).unsqueeze(0)
|
| 1895 |
+
else:
|
| 1896 |
+
control_image = np.expand_dims(np.array(control_image), 0)
|
| 1897 |
+
return image, control_image, None, text, 'image'
|
| 1898 |
+
def __len__(self):
|
| 1899 |
+
return self.length
|
| 1900 |
+
|
| 1901 |
+
def __getitem__(self, idx):
|
| 1902 |
+
data_info = self.dataset[idx % len(self.dataset)]
|
| 1903 |
+
data_type = data_info.get('type', 'image')
|
| 1904 |
+
while True:
|
| 1905 |
+
sample = {}
|
| 1906 |
+
try:
|
| 1907 |
+
data_info_local = self.dataset[idx % len(self.dataset)]
|
| 1908 |
+
data_type_local = data_info_local.get('type', 'image')
|
| 1909 |
+
if data_type_local != data_type:
|
| 1910 |
+
raise ValueError("data_type_local != data_type")
|
| 1911 |
+
|
| 1912 |
+
pixel_values, control_pixel_values, control_camera_values, name, data_type = self.get_batch(idx)
|
| 1913 |
+
|
| 1914 |
+
sample["pixel_values"] = pixel_values
|
| 1915 |
+
sample["control_pixel_values"] = control_pixel_values
|
| 1916 |
+
sample["text"] = name
|
| 1917 |
+
sample["data_type"] = data_type
|
| 1918 |
+
sample["idx"] = idx
|
| 1919 |
+
|
| 1920 |
+
if self.enable_camera_info:
|
| 1921 |
+
sample["control_camera_values"] = control_camera_values
|
| 1922 |
+
|
| 1923 |
+
if len(sample) > 0:
|
| 1924 |
+
break
|
| 1925 |
+
except Exception as e:
|
| 1926 |
+
print(e, self.dataset[idx % len(self.dataset)])
|
| 1927 |
+
idx = random.randint(0, self.length-1)
|
| 1928 |
+
|
| 1929 |
+
if self.enable_inpaint and not self.enable_bucket:
|
| 1930 |
+
mask = get_random_mask(pixel_values.size())
|
| 1931 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
|
| 1932 |
+
sample["mask_pixel_values"] = mask_pixel_values
|
| 1933 |
+
sample["mask"] = mask
|
| 1934 |
+
|
| 1935 |
+
clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
|
| 1936 |
+
clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
|
| 1937 |
+
sample["clip_pixel_values"] = clip_pixel_values
|
| 1938 |
+
|
| 1939 |
+
return sample
|
videox_fun/data/dataset_video.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import gc
|
| 3 |
+
import io
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from threading import Thread
|
| 10 |
+
|
| 11 |
+
import albumentations
|
| 12 |
+
import cv2
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torchvision.transforms as transforms
|
| 16 |
+
from decord import VideoReader
|
| 17 |
+
from einops import rearrange
|
| 18 |
+
from func_timeout import FunctionTimedOut, func_timeout
|
| 19 |
+
from PIL import Image
|
| 20 |
+
from torch.utils.data import BatchSampler, Sampler
|
| 21 |
+
from torch.utils.data.dataset import Dataset
|
| 22 |
+
|
| 23 |
+
VIDEO_READER_TIMEOUT = 20
|
| 24 |
+
|
| 25 |
+
def get_random_mask(shape):
|
| 26 |
+
f, c, h, w = shape
|
| 27 |
+
|
| 28 |
+
mask_index = np.random.randint(0, 4)
|
| 29 |
+
mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
|
| 30 |
+
if mask_index == 0:
|
| 31 |
+
mask[1:, :, :, :] = 1
|
| 32 |
+
elif mask_index == 1:
|
| 33 |
+
mask_frame_index = 1
|
| 34 |
+
mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
|
| 35 |
+
elif mask_index == 2:
|
| 36 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 37 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 38 |
+
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
|
| 39 |
+
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
|
| 40 |
+
|
| 41 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
| 42 |
+
end_x = min(center_x + block_size_x // 2, w)
|
| 43 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
| 44 |
+
end_y = min(center_y + block_size_y // 2, h)
|
| 45 |
+
mask[:, :, start_y:end_y, start_x:end_x] = 1
|
| 46 |
+
elif mask_index == 3:
|
| 47 |
+
center_x = torch.randint(0, w, (1,)).item()
|
| 48 |
+
center_y = torch.randint(0, h, (1,)).item()
|
| 49 |
+
block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围
|
| 50 |
+
block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围
|
| 51 |
+
|
| 52 |
+
start_x = max(center_x - block_size_x // 2, 0)
|
| 53 |
+
end_x = min(center_x + block_size_x // 2, w)
|
| 54 |
+
start_y = max(center_y - block_size_y // 2, 0)
|
| 55 |
+
end_y = min(center_y + block_size_y // 2, h)
|
| 56 |
+
|
| 57 |
+
mask_frame_before = np.random.randint(0, f // 2)
|
| 58 |
+
mask_frame_after = np.random.randint(f // 2, f)
|
| 59 |
+
mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
|
| 60 |
+
else:
|
| 61 |
+
raise ValueError(f"The mask_index {mask_index} is not define")
|
| 62 |
+
return mask
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@contextmanager
|
| 66 |
+
def VideoReader_contextmanager(*args, **kwargs):
|
| 67 |
+
vr = VideoReader(*args, **kwargs)
|
| 68 |
+
try:
|
| 69 |
+
yield vr
|
| 70 |
+
finally:
|
| 71 |
+
del vr
|
| 72 |
+
gc.collect()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_video_reader_batch(video_reader, batch_index):
|
| 76 |
+
frames = video_reader.get_batch(batch_index).asnumpy()
|
| 77 |
+
return frames
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class WebVid10M(Dataset):
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
csv_path, video_folder,
|
| 84 |
+
sample_size=256, sample_stride=4, sample_n_frames=16,
|
| 85 |
+
enable_bucket=False, enable_inpaint=False, is_image=False,
|
| 86 |
+
):
|
| 87 |
+
print(f"loading annotations from {csv_path} ...")
|
| 88 |
+
with open(csv_path, 'r') as csvfile:
|
| 89 |
+
self.dataset = list(csv.DictReader(csvfile))
|
| 90 |
+
self.length = len(self.dataset)
|
| 91 |
+
print(f"data scale: {self.length}")
|
| 92 |
+
|
| 93 |
+
self.video_folder = video_folder
|
| 94 |
+
self.sample_stride = sample_stride
|
| 95 |
+
self.sample_n_frames = sample_n_frames
|
| 96 |
+
self.enable_bucket = enable_bucket
|
| 97 |
+
self.enable_inpaint = enable_inpaint
|
| 98 |
+
self.is_image = is_image
|
| 99 |
+
|
| 100 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
| 101 |
+
self.pixel_transforms = transforms.Compose([
|
| 102 |
+
transforms.Resize(sample_size[0]),
|
| 103 |
+
transforms.CenterCrop(sample_size),
|
| 104 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 105 |
+
])
|
| 106 |
+
|
| 107 |
+
def get_batch(self, idx):
|
| 108 |
+
video_dict = self.dataset[idx]
|
| 109 |
+
videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
|
| 110 |
+
|
| 111 |
+
video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
|
| 112 |
+
video_reader = VideoReader(video_dir)
|
| 113 |
+
video_length = len(video_reader)
|
| 114 |
+
|
| 115 |
+
if not self.is_image:
|
| 116 |
+
clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
|
| 117 |
+
start_idx = random.randint(0, video_length - clip_length)
|
| 118 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
|
| 119 |
+
else:
|
| 120 |
+
batch_index = [random.randint(0, video_length - 1)]
|
| 121 |
+
|
| 122 |
+
if not self.enable_bucket:
|
| 123 |
+
pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
|
| 124 |
+
pixel_values = pixel_values / 255.
|
| 125 |
+
del video_reader
|
| 126 |
+
else:
|
| 127 |
+
pixel_values = video_reader.get_batch(batch_index).asnumpy()
|
| 128 |
+
|
| 129 |
+
if self.is_image:
|
| 130 |
+
pixel_values = pixel_values[0]
|
| 131 |
+
return pixel_values, name
|
| 132 |
+
|
| 133 |
+
def __len__(self):
|
| 134 |
+
return self.length
|
| 135 |
+
|
| 136 |
+
def __getitem__(self, idx):
|
| 137 |
+
while True:
|
| 138 |
+
try:
|
| 139 |
+
pixel_values, name = self.get_batch(idx)
|
| 140 |
+
break
|
| 141 |
+
|
| 142 |
+
except Exception as e:
|
| 143 |
+
print("Error info:", e)
|
| 144 |
+
idx = random.randint(0, self.length-1)
|
| 145 |
+
|
| 146 |
+
if not self.enable_bucket:
|
| 147 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
| 148 |
+
if self.enable_inpaint:
|
| 149 |
+
mask = get_random_mask(pixel_values.size())
|
| 150 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
|
| 151 |
+
sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
|
| 152 |
+
else:
|
| 153 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
| 154 |
+
return sample
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class VideoDataset(Dataset):
|
| 158 |
+
def __init__(
|
| 159 |
+
self,
|
| 160 |
+
json_path, video_folder=None,
|
| 161 |
+
sample_size=256, sample_stride=4, sample_n_frames=16,
|
| 162 |
+
enable_bucket=False, enable_inpaint=False
|
| 163 |
+
):
|
| 164 |
+
print(f"loading annotations from {json_path} ...")
|
| 165 |
+
self.dataset = json.load(open(json_path, 'r'))
|
| 166 |
+
self.length = len(self.dataset)
|
| 167 |
+
print(f"data scale: {self.length}")
|
| 168 |
+
|
| 169 |
+
self.video_folder = video_folder
|
| 170 |
+
self.sample_stride = sample_stride
|
| 171 |
+
self.sample_n_frames = sample_n_frames
|
| 172 |
+
self.enable_bucket = enable_bucket
|
| 173 |
+
self.enable_inpaint = enable_inpaint
|
| 174 |
+
|
| 175 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
| 176 |
+
self.pixel_transforms = transforms.Compose(
|
| 177 |
+
[
|
| 178 |
+
transforms.Resize(sample_size[0]),
|
| 179 |
+
transforms.CenterCrop(sample_size),
|
| 180 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
| 181 |
+
]
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def get_batch(self, idx):
|
| 185 |
+
video_dict = self.dataset[idx]
|
| 186 |
+
video_id, name = video_dict['file_path'], video_dict['text']
|
| 187 |
+
|
| 188 |
+
if self.video_folder is None:
|
| 189 |
+
video_dir = video_id
|
| 190 |
+
else:
|
| 191 |
+
video_dir = os.path.join(self.video_folder, video_id)
|
| 192 |
+
|
| 193 |
+
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
| 194 |
+
video_length = len(video_reader)
|
| 195 |
+
|
| 196 |
+
clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
|
| 197 |
+
start_idx = random.randint(0, video_length - clip_length)
|
| 198 |
+
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
|
| 199 |
+
|
| 200 |
+
try:
|
| 201 |
+
sample_args = (video_reader, batch_index)
|
| 202 |
+
pixel_values = func_timeout(
|
| 203 |
+
VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
|
| 204 |
+
)
|
| 205 |
+
except FunctionTimedOut:
|
| 206 |
+
raise ValueError(f"Read {idx} timeout.")
|
| 207 |
+
except Exception as e:
|
| 208 |
+
raise ValueError(f"Failed to extract frames from video. Error is {e}.")
|
| 209 |
+
|
| 210 |
+
if not self.enable_bucket:
|
| 211 |
+
pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
|
| 212 |
+
pixel_values = pixel_values / 255.
|
| 213 |
+
del video_reader
|
| 214 |
+
else:
|
| 215 |
+
pixel_values = pixel_values
|
| 216 |
+
|
| 217 |
+
return pixel_values, name
|
| 218 |
+
|
| 219 |
+
def __len__(self):
|
| 220 |
+
return self.length
|
| 221 |
+
|
| 222 |
+
def __getitem__(self, idx):
|
| 223 |
+
while True:
|
| 224 |
+
try:
|
| 225 |
+
pixel_values, name = self.get_batch(idx)
|
| 226 |
+
break
|
| 227 |
+
|
| 228 |
+
except Exception as e:
|
| 229 |
+
print("Error info:", e)
|
| 230 |
+
idx = random.randint(0, self.length-1)
|
| 231 |
+
|
| 232 |
+
if not self.enable_bucket:
|
| 233 |
+
pixel_values = self.pixel_transforms(pixel_values)
|
| 234 |
+
if self.enable_inpaint:
|
| 235 |
+
mask = get_random_mask(pixel_values.size())
|
| 236 |
+
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
|
| 237 |
+
sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
|
| 238 |
+
else:
|
| 239 |
+
sample = dict(pixel_values=pixel_values, text=name)
|
| 240 |
+
return sample
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
if __name__ == "__main__":
|
| 244 |
+
if 1:
|
| 245 |
+
dataset = VideoDataset(
|
| 246 |
+
json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
|
| 247 |
+
sample_size=256,
|
| 248 |
+
sample_stride=4, sample_n_frames=16,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
if 0:
|
| 252 |
+
dataset = WebVid10M(
|
| 253 |
+
csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
|
| 254 |
+
video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
|
| 255 |
+
sample_size=256,
|
| 256 |
+
sample_stride=4, sample_n_frames=16,
|
| 257 |
+
is_image=False,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
|
| 261 |
+
for idx, batch in enumerate(dataloader):
|
| 262 |
+
print(batch["pixel_values"].shape, len(batch["text"]))
|
videox_fun/dist/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib.util
|
| 2 |
+
|
| 3 |
+
from .cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
|
| 4 |
+
from .fsdp import shard_model
|
| 5 |
+
from .fuser import (get_sequence_parallel_rank,
|
| 6 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 7 |
+
get_world_group, init_distributed_environment,
|
| 8 |
+
initialize_model_parallel, set_multi_gpus_devices,
|
| 9 |
+
xFuserLongContextAttention)
|
| 10 |
+
from .wan_xfuser import usp_attn_forward, usp_attn_s2v_forward
|
| 11 |
+
from .qwen_xfuser import QwenImageMultiGPUsAttnProcessor2_0
|
| 12 |
+
from .flux_xfuser import FluxMultiGPUsAttnProcessor2_0
|
| 13 |
+
|
| 14 |
+
# The pai_fuser is an internally developed acceleration package, which can be used on PAI.
|
| 15 |
+
if importlib.util.find_spec("paifuser") is not None:
|
| 16 |
+
# --------------------------------------------------------------- #
|
| 17 |
+
# The simple_wrapper is used to solve the problem
|
| 18 |
+
# about conflicts between cython and torch.compile
|
| 19 |
+
# --------------------------------------------------------------- #
|
| 20 |
+
def simple_wrapper(func):
|
| 21 |
+
def inner(*args, **kwargs):
|
| 22 |
+
return func(*args, **kwargs)
|
| 23 |
+
return inner
|
| 24 |
+
|
| 25 |
+
# --------------------------------------------------------------- #
|
| 26 |
+
# Sparse Attention Kernel
|
| 27 |
+
# --------------------------------------------------------------- #
|
| 28 |
+
from paifuser.models import parallel_magvit_vae
|
| 29 |
+
from paifuser.ops import wan_usp_sparse_attention_wrapper
|
| 30 |
+
from . import wan_xfuser
|
| 31 |
+
|
| 32 |
+
# --------------------------------------------------------------- #
|
| 33 |
+
# Sparse Attention
|
| 34 |
+
# --------------------------------------------------------------- #
|
| 35 |
+
usp_sparse_attn_wrap_forward = simple_wrapper(wan_usp_sparse_attention_wrapper()(wan_xfuser.usp_attn_forward))
|
| 36 |
+
wan_xfuser.usp_attn_forward = usp_sparse_attn_wrap_forward
|
| 37 |
+
usp_attn_forward = usp_sparse_attn_wrap_forward
|
| 38 |
+
print("Import PAI VAE Turbo and Sparse Attention")
|
| 39 |
+
|
| 40 |
+
# --------------------------------------------------------------- #
|
| 41 |
+
# Fast Rope Kernel
|
| 42 |
+
# --------------------------------------------------------------- #
|
| 43 |
+
import types
|
| 44 |
+
import torch
|
| 45 |
+
from paifuser.ops import (ENABLE_KERNEL, usp_fast_rope_apply_qk,
|
| 46 |
+
usp_rope_apply_real_qk)
|
| 47 |
+
|
| 48 |
+
def deepcopy_function(f):
|
| 49 |
+
return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__,closure=f.__closure__)
|
| 50 |
+
|
| 51 |
+
local_rope_apply_qk = deepcopy_function(wan_xfuser.rope_apply_qk)
|
| 52 |
+
|
| 53 |
+
if ENABLE_KERNEL:
|
| 54 |
+
def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
|
| 55 |
+
if torch.is_grad_enabled():
|
| 56 |
+
return local_rope_apply_qk(q, k, grid_sizes, freqs)
|
| 57 |
+
else:
|
| 58 |
+
return usp_fast_rope_apply_qk(q, k, grid_sizes, freqs)
|
| 59 |
+
|
| 60 |
+
else:
|
| 61 |
+
def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
|
| 62 |
+
return usp_rope_apply_real_qk(q, k, grid_sizes, freqs)
|
| 63 |
+
|
| 64 |
+
wan_xfuser.rope_apply_qk = adaptive_fast_usp_rope_apply_qk
|
| 65 |
+
rope_apply_qk = adaptive_fast_usp_rope_apply_qk
|
| 66 |
+
print("Import PAI Fast rope")
|
videox_fun/dist/cogvideox_xfuser.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from diffusers.models.attention import Attention
|
| 6 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 7 |
+
|
| 8 |
+
from .fuser import (get_sequence_parallel_rank,
|
| 9 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 10 |
+
init_distributed_environment, initialize_model_parallel,
|
| 11 |
+
xFuserLongContextAttention)
|
| 12 |
+
|
| 13 |
+
class CogVideoXMultiGPUsAttnProcessor2_0:
|
| 14 |
+
r"""
|
| 15 |
+
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
| 16 |
+
query and key vectors, but does not include spatial normalization.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self):
|
| 20 |
+
if xFuserLongContextAttention is not None:
|
| 21 |
+
try:
|
| 22 |
+
self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
|
| 23 |
+
except Exception:
|
| 24 |
+
self.hybrid_seq_parallel_attn = None
|
| 25 |
+
else:
|
| 26 |
+
self.hybrid_seq_parallel_attn = None
|
| 27 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 28 |
+
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 29 |
+
|
| 30 |
+
def __call__(
|
| 31 |
+
self,
|
| 32 |
+
attn: Attention,
|
| 33 |
+
hidden_states: torch.Tensor,
|
| 34 |
+
encoder_hidden_states: torch.Tensor,
|
| 35 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 36 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 37 |
+
) -> torch.Tensor:
|
| 38 |
+
text_seq_length = encoder_hidden_states.size(1)
|
| 39 |
+
|
| 40 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 41 |
+
|
| 42 |
+
batch_size, sequence_length, _ = (
|
| 43 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
if attention_mask is not None:
|
| 47 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 48 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
| 49 |
+
|
| 50 |
+
query = attn.to_q(hidden_states)
|
| 51 |
+
key = attn.to_k(hidden_states)
|
| 52 |
+
value = attn.to_v(hidden_states)
|
| 53 |
+
|
| 54 |
+
inner_dim = key.shape[-1]
|
| 55 |
+
head_dim = inner_dim // attn.heads
|
| 56 |
+
|
| 57 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 58 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 59 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 60 |
+
|
| 61 |
+
if attn.norm_q is not None:
|
| 62 |
+
query = attn.norm_q(query)
|
| 63 |
+
if attn.norm_k is not None:
|
| 64 |
+
key = attn.norm_k(key)
|
| 65 |
+
|
| 66 |
+
# Apply RoPE if needed
|
| 67 |
+
if image_rotary_emb is not None:
|
| 68 |
+
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
| 69 |
+
if not attn.is_cross_attention:
|
| 70 |
+
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
| 71 |
+
|
| 72 |
+
if self.hybrid_seq_parallel_attn is None:
|
| 73 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 74 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 75 |
+
)
|
| 76 |
+
hidden_states = hidden_states
|
| 77 |
+
else:
|
| 78 |
+
img_q = query[:, :, text_seq_length:].transpose(1, 2)
|
| 79 |
+
txt_q = query[:, :, :text_seq_length].transpose(1, 2)
|
| 80 |
+
img_k = key[:, :, text_seq_length:].transpose(1, 2)
|
| 81 |
+
txt_k = key[:, :, :text_seq_length].transpose(1, 2)
|
| 82 |
+
img_v = value[:, :, text_seq_length:].transpose(1, 2)
|
| 83 |
+
txt_v = value[:, :, :text_seq_length].transpose(1, 2)
|
| 84 |
+
|
| 85 |
+
hidden_states = self.hybrid_seq_parallel_attn(
|
| 86 |
+
None,
|
| 87 |
+
img_q, img_k, img_v, dropout_p=0.0, causal=False,
|
| 88 |
+
joint_tensor_query=txt_q,
|
| 89 |
+
joint_tensor_key=txt_k,
|
| 90 |
+
joint_tensor_value=txt_v,
|
| 91 |
+
joint_strategy='front',
|
| 92 |
+
).transpose(1, 2)
|
| 93 |
+
|
| 94 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
| 95 |
+
|
| 96 |
+
# linear proj
|
| 97 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 98 |
+
# dropout
|
| 99 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 100 |
+
|
| 101 |
+
encoder_hidden_states, hidden_states = hidden_states.split(
|
| 102 |
+
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
| 103 |
+
)
|
| 104 |
+
return hidden_states, encoder_hidden_states
|
| 105 |
+
|
videox_fun/dist/flux_xfuser.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from diffusers.models.attention_processor import Attention
|
| 6 |
+
|
| 7 |
+
from .fuser import xFuserLongContextAttention
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
| 11 |
+
query = attn.to_q(hidden_states)
|
| 12 |
+
key = attn.to_k(hidden_states)
|
| 13 |
+
value = attn.to_v(hidden_states)
|
| 14 |
+
|
| 15 |
+
encoder_query = encoder_key = encoder_value = None
|
| 16 |
+
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
| 17 |
+
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
| 18 |
+
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
| 19 |
+
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
| 20 |
+
|
| 21 |
+
return query, key, value, encoder_query, encoder_key, encoder_value
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
|
| 25 |
+
return _get_projections(attn, hidden_states, encoder_hidden_states)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def apply_rotary_emb(
|
| 29 |
+
x: torch.Tensor,
|
| 30 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 31 |
+
use_real: bool = True,
|
| 32 |
+
use_real_unbind_dim: int = -1,
|
| 33 |
+
sequence_dim: int = 2,
|
| 34 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 35 |
+
"""
|
| 36 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
| 37 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
| 38 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
| 39 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
x (`torch.Tensor`):
|
| 43 |
+
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
|
| 44 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
| 48 |
+
"""
|
| 49 |
+
if use_real:
|
| 50 |
+
cos, sin = freqs_cis # [S, D]
|
| 51 |
+
if sequence_dim == 2:
|
| 52 |
+
cos = cos[None, None, :, :]
|
| 53 |
+
sin = sin[None, None, :, :]
|
| 54 |
+
elif sequence_dim == 1:
|
| 55 |
+
cos = cos[None, :, None, :]
|
| 56 |
+
sin = sin[None, :, None, :]
|
| 57 |
+
else:
|
| 58 |
+
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
|
| 59 |
+
|
| 60 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 61 |
+
|
| 62 |
+
if use_real_unbind_dim == -1:
|
| 63 |
+
# Used for flux, cogvideox, hunyuan-dit
|
| 64 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
|
| 65 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 66 |
+
elif use_real_unbind_dim == -2:
|
| 67 |
+
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
| 68 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
|
| 69 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 70 |
+
else:
|
| 71 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 72 |
+
|
| 73 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 74 |
+
|
| 75 |
+
return out
|
| 76 |
+
else:
|
| 77 |
+
# used for lumina
|
| 78 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 79 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
| 80 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 81 |
+
|
| 82 |
+
return x_out.type_as(x)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class FluxMultiGPUsAttnProcessor2_0:
|
| 86 |
+
r"""
|
| 87 |
+
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
| 88 |
+
query and key vectors, but does not include spatial normalization.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(self):
|
| 92 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 93 |
+
raise ImportError("FluxMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 94 |
+
|
| 95 |
+
def __call__(
|
| 96 |
+
self,
|
| 97 |
+
attn: "FluxAttention",
|
| 98 |
+
hidden_states: torch.Tensor,
|
| 99 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 100 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 101 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 102 |
+
text_seq_len: int = None,
|
| 103 |
+
) -> torch.FloatTensor:
|
| 104 |
+
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
| 105 |
+
attn, hidden_states, encoder_hidden_states
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
query = query.unflatten(-1, (attn.heads, -1))
|
| 109 |
+
key = key.unflatten(-1, (attn.heads, -1))
|
| 110 |
+
value = value.unflatten(-1, (attn.heads, -1))
|
| 111 |
+
|
| 112 |
+
query = attn.norm_q(query)
|
| 113 |
+
key = attn.norm_k(key)
|
| 114 |
+
|
| 115 |
+
if attn.added_kv_proj_dim is not None:
|
| 116 |
+
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
| 117 |
+
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
| 118 |
+
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
| 119 |
+
|
| 120 |
+
encoder_query = attn.norm_added_q(encoder_query)
|
| 121 |
+
encoder_key = attn.norm_added_k(encoder_key)
|
| 122 |
+
|
| 123 |
+
query = torch.cat([encoder_query, query], dim=1)
|
| 124 |
+
key = torch.cat([encoder_key, key], dim=1)
|
| 125 |
+
value = torch.cat([encoder_value, value], dim=1)
|
| 126 |
+
|
| 127 |
+
if image_rotary_emb is not None:
|
| 128 |
+
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
| 129 |
+
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
| 130 |
+
|
| 131 |
+
text_seq_len = encoder_query.shape[1]
|
| 132 |
+
txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
|
| 133 |
+
img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
|
| 134 |
+
else:
|
| 135 |
+
if image_rotary_emb is not None:
|
| 136 |
+
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
| 137 |
+
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
| 138 |
+
txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
|
| 139 |
+
img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
|
| 140 |
+
|
| 141 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 142 |
+
def half(x):
|
| 143 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 144 |
+
|
| 145 |
+
hidden_states = xFuserLongContextAttention()(
|
| 146 |
+
None,
|
| 147 |
+
half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
|
| 148 |
+
joint_tensor_query=half(txt_query) if txt_query is not None else None,
|
| 149 |
+
joint_tensor_key=half(txt_key) if txt_key is not None else None,
|
| 150 |
+
joint_tensor_value=half(txt_value) if txt_value is not None else None,
|
| 151 |
+
joint_strategy='front',
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Reshape back
|
| 155 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 156 |
+
hidden_states = hidden_states.to(img_query.dtype)
|
| 157 |
+
|
| 158 |
+
if encoder_hidden_states is not None:
|
| 159 |
+
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
| 160 |
+
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
| 161 |
+
)
|
| 162 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 163 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 164 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 165 |
+
|
| 166 |
+
return hidden_states, encoder_hidden_states
|
| 167 |
+
else:
|
| 168 |
+
return hidden_states
|
videox_fun/dist/fsdp.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyied from https://github.com/Wan-Video/Wan2.1/blob/main/wan/distributed/fsdp.py
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import gc
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 8 |
+
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
| 9 |
+
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
| 10 |
+
from torch.distributed.utils import _free_storage
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def shard_model(
|
| 14 |
+
model,
|
| 15 |
+
device_id,
|
| 16 |
+
param_dtype=torch.bfloat16,
|
| 17 |
+
reduce_dtype=torch.float32,
|
| 18 |
+
buffer_dtype=torch.float32,
|
| 19 |
+
process_group=None,
|
| 20 |
+
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
| 21 |
+
sync_module_states=True,
|
| 22 |
+
module_to_wrapper=None,
|
| 23 |
+
):
|
| 24 |
+
model = FSDP(
|
| 25 |
+
module=model,
|
| 26 |
+
process_group=process_group,
|
| 27 |
+
sharding_strategy=sharding_strategy,
|
| 28 |
+
auto_wrap_policy=partial(
|
| 29 |
+
lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks if module_to_wrapper is None else module_to_wrapper),
|
| 30 |
+
mixed_precision=MixedPrecision(
|
| 31 |
+
param_dtype=param_dtype,
|
| 32 |
+
reduce_dtype=reduce_dtype,
|
| 33 |
+
buffer_dtype=buffer_dtype),
|
| 34 |
+
device_id=device_id,
|
| 35 |
+
sync_module_states=sync_module_states)
|
| 36 |
+
return model
|
| 37 |
+
|
| 38 |
+
def free_model(model):
|
| 39 |
+
for m in model.modules():
|
| 40 |
+
if isinstance(m, FSDP):
|
| 41 |
+
_free_storage(m._handle.flat_param.data)
|
| 42 |
+
del model
|
| 43 |
+
gc.collect()
|
| 44 |
+
torch.cuda.empty_cache()
|
videox_fun/dist/fuser.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib.util
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
# The pai_fuser is an internally developed acceleration package, which can be used on PAI.
|
| 8 |
+
if importlib.util.find_spec("paifuser") is not None:
|
| 9 |
+
import paifuser
|
| 10 |
+
from paifuser.xfuser.core.distributed import (
|
| 11 |
+
get_sequence_parallel_rank, get_sequence_parallel_world_size,
|
| 12 |
+
get_sp_group, get_world_group, init_distributed_environment,
|
| 13 |
+
initialize_model_parallel)
|
| 14 |
+
from paifuser.xfuser.core.long_ctx_attention import \
|
| 15 |
+
xFuserLongContextAttention
|
| 16 |
+
print("Import PAI DiT Turbo")
|
| 17 |
+
else:
|
| 18 |
+
import xfuser
|
| 19 |
+
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
| 20 |
+
get_sequence_parallel_world_size,
|
| 21 |
+
get_sp_group, get_world_group,
|
| 22 |
+
init_distributed_environment,
|
| 23 |
+
initialize_model_parallel)
|
| 24 |
+
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
| 25 |
+
print("Xfuser import sucessful")
|
| 26 |
+
except Exception as ex:
|
| 27 |
+
get_sequence_parallel_world_size = None
|
| 28 |
+
get_sequence_parallel_rank = None
|
| 29 |
+
xFuserLongContextAttention = None
|
| 30 |
+
get_sp_group = None
|
| 31 |
+
get_world_group = None
|
| 32 |
+
init_distributed_environment = None
|
| 33 |
+
initialize_model_parallel = None
|
| 34 |
+
|
| 35 |
+
def set_multi_gpus_devices(ulysses_degree, ring_degree, classifier_free_guidance_degree=1):
|
| 36 |
+
if ulysses_degree > 1 or ring_degree > 1 or classifier_free_guidance_degree > 1:
|
| 37 |
+
if get_sp_group is None:
|
| 38 |
+
raise RuntimeError("xfuser is not installed.")
|
| 39 |
+
dist.init_process_group("nccl")
|
| 40 |
+
print('parallel inference enabled: ulysses_degree=%d ring_degree=%d classifier_free_guidance_degree=% rank=%d world_size=%d' % (
|
| 41 |
+
ulysses_degree, ring_degree, classifier_free_guidance_degree, dist.get_rank(),
|
| 42 |
+
dist.get_world_size()))
|
| 43 |
+
assert dist.get_world_size() == ring_degree * ulysses_degree * classifier_free_guidance_degree, \
|
| 44 |
+
"number of GPUs(%d) should be equal to ring_degree * ulysses_degree * classifier_free_guidance_degree." % dist.get_world_size()
|
| 45 |
+
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
|
| 46 |
+
initialize_model_parallel(sequence_parallel_degree=ring_degree * ulysses_degree,
|
| 47 |
+
classifier_free_guidance_degree=classifier_free_guidance_degree,
|
| 48 |
+
ring_degree=ring_degree,
|
| 49 |
+
ulysses_degree=ulysses_degree)
|
| 50 |
+
# device = torch.device("cuda:%d" % dist.get_rank())
|
| 51 |
+
device = torch.device(f"cuda:{get_world_group().local_rank}")
|
| 52 |
+
print('rank=%d device=%s' % (get_world_group().rank, str(device)))
|
| 53 |
+
else:
|
| 54 |
+
device = "cuda"
|
| 55 |
+
return device
|
videox_fun/dist/qwen_xfuser.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import glob
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import types
|
| 7 |
+
import warnings
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.cuda.amp as amp
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 16 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 17 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin
|
| 18 |
+
from diffusers.models.attention import FeedForward
|
| 19 |
+
from diffusers.models.attention_processor import Attention
|
| 20 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
| 21 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 22 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 23 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
|
| 24 |
+
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
|
| 25 |
+
scale_lora_layers, unscale_lora_layers)
|
| 26 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 27 |
+
from torch import nn
|
| 28 |
+
from .fuser import (get_sequence_parallel_rank,
|
| 29 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 30 |
+
init_distributed_environment, initialize_model_parallel,
|
| 31 |
+
xFuserLongContextAttention)
|
| 32 |
+
|
| 33 |
+
def apply_rotary_emb_qwen(
|
| 34 |
+
x: torch.Tensor,
|
| 35 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 36 |
+
use_real: bool = True,
|
| 37 |
+
use_real_unbind_dim: int = -1,
|
| 38 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 39 |
+
"""
|
| 40 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
| 41 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
| 42 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
| 43 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
x (`torch.Tensor`):
|
| 47 |
+
Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
|
| 48 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
| 52 |
+
"""
|
| 53 |
+
if use_real:
|
| 54 |
+
cos, sin = freqs_cis # [S, D]
|
| 55 |
+
cos = cos[None, None]
|
| 56 |
+
sin = sin[None, None]
|
| 57 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 58 |
+
|
| 59 |
+
if use_real_unbind_dim == -1:
|
| 60 |
+
# Used for flux, cogvideox, hunyuan-dit
|
| 61 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
| 62 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 63 |
+
elif use_real_unbind_dim == -2:
|
| 64 |
+
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
| 65 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
| 66 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 67 |
+
else:
|
| 68 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 69 |
+
|
| 70 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 71 |
+
|
| 72 |
+
return out
|
| 73 |
+
else:
|
| 74 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 75 |
+
freqs_cis = freqs_cis.unsqueeze(1)
|
| 76 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 77 |
+
|
| 78 |
+
return x_out.type_as(x)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class QwenImageMultiGPUsAttnProcessor2_0:
|
| 82 |
+
r"""
|
| 83 |
+
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
| 84 |
+
query and key vectors, but does not include spatial normalization.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def __init__(self):
|
| 88 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 89 |
+
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
| 90 |
+
|
| 91 |
+
def __call__(
|
| 92 |
+
self,
|
| 93 |
+
attn: Attention,
|
| 94 |
+
hidden_states: torch.FloatTensor, # Image stream
|
| 95 |
+
encoder_hidden_states: torch.FloatTensor = None, # Text stream
|
| 96 |
+
encoder_hidden_states_mask: torch.FloatTensor = None,
|
| 97 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 98 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 99 |
+
) -> torch.FloatTensor:
|
| 100 |
+
if encoder_hidden_states is None:
|
| 101 |
+
raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
|
| 102 |
+
|
| 103 |
+
seq_txt = encoder_hidden_states.shape[1]
|
| 104 |
+
|
| 105 |
+
# Compute QKV for image stream (sample projections)
|
| 106 |
+
img_query = attn.to_q(hidden_states)
|
| 107 |
+
img_key = attn.to_k(hidden_states)
|
| 108 |
+
img_value = attn.to_v(hidden_states)
|
| 109 |
+
|
| 110 |
+
# Compute QKV for text stream (context projections)
|
| 111 |
+
txt_query = attn.add_q_proj(encoder_hidden_states)
|
| 112 |
+
txt_key = attn.add_k_proj(encoder_hidden_states)
|
| 113 |
+
txt_value = attn.add_v_proj(encoder_hidden_states)
|
| 114 |
+
|
| 115 |
+
# Reshape for multi-head attention
|
| 116 |
+
img_query = img_query.unflatten(-1, (attn.heads, -1))
|
| 117 |
+
img_key = img_key.unflatten(-1, (attn.heads, -1))
|
| 118 |
+
img_value = img_value.unflatten(-1, (attn.heads, -1))
|
| 119 |
+
|
| 120 |
+
txt_query = txt_query.unflatten(-1, (attn.heads, -1))
|
| 121 |
+
txt_key = txt_key.unflatten(-1, (attn.heads, -1))
|
| 122 |
+
txt_value = txt_value.unflatten(-1, (attn.heads, -1))
|
| 123 |
+
|
| 124 |
+
# Apply QK normalization
|
| 125 |
+
if attn.norm_q is not None:
|
| 126 |
+
img_query = attn.norm_q(img_query)
|
| 127 |
+
if attn.norm_k is not None:
|
| 128 |
+
img_key = attn.norm_k(img_key)
|
| 129 |
+
if attn.norm_added_q is not None:
|
| 130 |
+
txt_query = attn.norm_added_q(txt_query)
|
| 131 |
+
if attn.norm_added_k is not None:
|
| 132 |
+
txt_key = attn.norm_added_k(txt_key)
|
| 133 |
+
|
| 134 |
+
# Apply RoPE
|
| 135 |
+
if image_rotary_emb is not None:
|
| 136 |
+
img_freqs, txt_freqs = image_rotary_emb
|
| 137 |
+
img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
|
| 138 |
+
img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
|
| 139 |
+
txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
|
| 140 |
+
txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
|
| 141 |
+
|
| 142 |
+
# Concatenate for joint attention
|
| 143 |
+
# Order: [text, image]
|
| 144 |
+
# joint_query = torch.cat([txt_query, img_query], dim=1)
|
| 145 |
+
# joint_key = torch.cat([txt_key, img_key], dim=1)
|
| 146 |
+
# joint_value = torch.cat([txt_value, img_value], dim=1)
|
| 147 |
+
|
| 148 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 149 |
+
def half(x):
|
| 150 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 151 |
+
|
| 152 |
+
joint_hidden_states = xFuserLongContextAttention()(
|
| 153 |
+
None,
|
| 154 |
+
half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
|
| 155 |
+
joint_tensor_query=half(txt_query),
|
| 156 |
+
joint_tensor_key=half(txt_key),
|
| 157 |
+
joint_tensor_value=half(txt_value),
|
| 158 |
+
joint_strategy='front',
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Reshape back
|
| 162 |
+
joint_hidden_states = joint_hidden_states.flatten(2, 3)
|
| 163 |
+
joint_hidden_states = joint_hidden_states.to(img_query.dtype)
|
| 164 |
+
|
| 165 |
+
# Split attention outputs back
|
| 166 |
+
txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
|
| 167 |
+
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
|
| 168 |
+
|
| 169 |
+
# Apply output projections
|
| 170 |
+
img_attn_output = attn.to_out[0](img_attn_output)
|
| 171 |
+
if len(attn.to_out) > 1:
|
| 172 |
+
img_attn_output = attn.to_out[1](img_attn_output) # dropout
|
| 173 |
+
|
| 174 |
+
txt_attn_output = attn.to_add_out(txt_attn_output)
|
| 175 |
+
|
| 176 |
+
return img_attn_output, txt_attn_output
|
videox_fun/dist/wan_xfuser.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.cuda.amp as amp
|
| 3 |
+
|
| 4 |
+
from .fuser import (get_sequence_parallel_rank,
|
| 5 |
+
get_sequence_parallel_world_size, get_sp_group,
|
| 6 |
+
init_distributed_environment, initialize_model_parallel,
|
| 7 |
+
xFuserLongContextAttention)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def pad_freqs(original_tensor, target_len):
|
| 11 |
+
seq_len, s1, s2 = original_tensor.shape
|
| 12 |
+
pad_size = target_len - seq_len
|
| 13 |
+
padding_tensor = torch.ones(
|
| 14 |
+
pad_size,
|
| 15 |
+
s1,
|
| 16 |
+
s2,
|
| 17 |
+
dtype=original_tensor.dtype,
|
| 18 |
+
device=original_tensor.device)
|
| 19 |
+
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
| 20 |
+
return padded_tensor
|
| 21 |
+
|
| 22 |
+
@amp.autocast(enabled=False)
|
| 23 |
+
@torch.compiler.disable()
|
| 24 |
+
def rope_apply(x, grid_sizes, freqs):
|
| 25 |
+
"""
|
| 26 |
+
x: [B, L, N, C].
|
| 27 |
+
grid_sizes: [B, 3].
|
| 28 |
+
freqs: [M, C // 2].
|
| 29 |
+
"""
|
| 30 |
+
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
| 31 |
+
# split freqs
|
| 32 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
| 33 |
+
|
| 34 |
+
# loop over samples
|
| 35 |
+
output = []
|
| 36 |
+
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
| 37 |
+
seq_len = f * h * w
|
| 38 |
+
|
| 39 |
+
# precompute multipliers
|
| 40 |
+
x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
|
| 41 |
+
s, n, -1, 2))
|
| 42 |
+
freqs_i = torch.cat([
|
| 43 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 44 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 45 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 46 |
+
],
|
| 47 |
+
dim=-1).reshape(seq_len, 1, -1)
|
| 48 |
+
|
| 49 |
+
# apply rotary embedding
|
| 50 |
+
sp_size = get_sequence_parallel_world_size()
|
| 51 |
+
sp_rank = get_sequence_parallel_rank()
|
| 52 |
+
freqs_i = pad_freqs(freqs_i, s * sp_size)
|
| 53 |
+
s_per_rank = s
|
| 54 |
+
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
|
| 55 |
+
s_per_rank), :, :]
|
| 56 |
+
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
|
| 57 |
+
x_i = torch.cat([x_i, x[i, s:]])
|
| 58 |
+
|
| 59 |
+
# append to collection
|
| 60 |
+
output.append(x_i)
|
| 61 |
+
return torch.stack(output)
|
| 62 |
+
|
| 63 |
+
def rope_apply_qk(q, k, grid_sizes, freqs):
|
| 64 |
+
q = rope_apply(q, grid_sizes, freqs)
|
| 65 |
+
k = rope_apply(k, grid_sizes, freqs)
|
| 66 |
+
return q, k
|
| 67 |
+
|
| 68 |
+
def usp_attn_forward(self,
|
| 69 |
+
x,
|
| 70 |
+
seq_lens,
|
| 71 |
+
grid_sizes,
|
| 72 |
+
freqs,
|
| 73 |
+
dtype=torch.bfloat16,
|
| 74 |
+
t=0):
|
| 75 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 76 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 77 |
+
|
| 78 |
+
def half(x):
|
| 79 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 80 |
+
|
| 81 |
+
# query, key, value function
|
| 82 |
+
def qkv_fn(x):
|
| 83 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 84 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 85 |
+
v = self.v(x).view(b, s, n, d)
|
| 86 |
+
return q, k, v
|
| 87 |
+
|
| 88 |
+
q, k, v = qkv_fn(x)
|
| 89 |
+
q, k = rope_apply_qk(q, k, grid_sizes, freqs)
|
| 90 |
+
|
| 91 |
+
# TODO: We should use unpaded q,k,v for attention.
|
| 92 |
+
# k_lens = seq_lens // get_sequence_parallel_world_size()
|
| 93 |
+
# if k_lens is not None:
|
| 94 |
+
# q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
|
| 95 |
+
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
|
| 96 |
+
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
|
| 97 |
+
|
| 98 |
+
x = xFuserLongContextAttention()(
|
| 99 |
+
None,
|
| 100 |
+
query=half(q),
|
| 101 |
+
key=half(k),
|
| 102 |
+
value=half(v),
|
| 103 |
+
window_size=self.window_size)
|
| 104 |
+
|
| 105 |
+
# TODO: padding after attention.
|
| 106 |
+
# x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
|
| 107 |
+
|
| 108 |
+
# output
|
| 109 |
+
x = x.flatten(2)
|
| 110 |
+
x = self.o(x)
|
| 111 |
+
return x
|
| 112 |
+
|
| 113 |
+
@amp.autocast(enabled=False)
|
| 114 |
+
@torch.compiler.disable()
|
| 115 |
+
def s2v_rope_apply(x, grid_sizes, freqs):
|
| 116 |
+
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
| 117 |
+
# loop over samples
|
| 118 |
+
output = []
|
| 119 |
+
for i, _ in enumerate(x):
|
| 120 |
+
s = x.size(1)
|
| 121 |
+
# precompute multipliers
|
| 122 |
+
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
|
| 123 |
+
s, n, -1, 2))
|
| 124 |
+
freqs_i = freqs[i]
|
| 125 |
+
freqs_i_rank = pad_freqs(freqs_i, s)
|
| 126 |
+
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
|
| 127 |
+
x_i = torch.cat([x_i, x[i, s:]])
|
| 128 |
+
# append to collection
|
| 129 |
+
output.append(x_i)
|
| 130 |
+
return torch.stack(output).float()
|
| 131 |
+
|
| 132 |
+
def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
|
| 133 |
+
q = s2v_rope_apply(q, grid_sizes, freqs)
|
| 134 |
+
k = s2v_rope_apply(k, grid_sizes, freqs)
|
| 135 |
+
return q, k
|
| 136 |
+
|
| 137 |
+
def usp_attn_s2v_forward(self,
|
| 138 |
+
x,
|
| 139 |
+
seq_lens,
|
| 140 |
+
grid_sizes,
|
| 141 |
+
freqs,
|
| 142 |
+
dtype=torch.bfloat16,
|
| 143 |
+
t=0):
|
| 144 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 145 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
| 146 |
+
|
| 147 |
+
def half(x):
|
| 148 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
| 149 |
+
|
| 150 |
+
# query, key, value function
|
| 151 |
+
def qkv_fn(x):
|
| 152 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 153 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 154 |
+
v = self.v(x).view(b, s, n, d)
|
| 155 |
+
return q, k, v
|
| 156 |
+
|
| 157 |
+
q, k, v = qkv_fn(x)
|
| 158 |
+
q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)
|
| 159 |
+
|
| 160 |
+
# TODO: We should use unpaded q,k,v for attention.
|
| 161 |
+
# k_lens = seq_lens // get_sequence_parallel_world_size()
|
| 162 |
+
# if k_lens is not None:
|
| 163 |
+
# q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
|
| 164 |
+
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
|
| 165 |
+
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
|
| 166 |
+
|
| 167 |
+
x = xFuserLongContextAttention()(
|
| 168 |
+
None,
|
| 169 |
+
query=half(q),
|
| 170 |
+
key=half(k),
|
| 171 |
+
value=half(v),
|
| 172 |
+
window_size=self.window_size)
|
| 173 |
+
|
| 174 |
+
# TODO: padding after attention.
|
| 175 |
+
# x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
|
| 176 |
+
|
| 177 |
+
# output
|
| 178 |
+
x = x.flatten(2)
|
| 179 |
+
x = self.o(x)
|
| 180 |
+
return x
|
videox_fun/pipeline/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .pipeline_wan import WanPipeline
|
| 2 |
+
from .pipeline_wan2_2 import Wan2_2Pipeline
|
| 3 |
+
|
| 4 |
+
WanFunPipeline = WanPipeline
|
| 5 |
+
Wan2_2FunPipeline = Wan2_2Pipeline
|
| 6 |
+
|
| 7 |
+
import importlib.util
|
| 8 |
+
|
| 9 |
+
if importlib.util.find_spec("paifuser") is not None:
|
| 10 |
+
# --------------------------------------------------------------- #
|
| 11 |
+
# Sparse Attention
|
| 12 |
+
# --------------------------------------------------------------- #
|
| 13 |
+
from paifuser.ops import sparse_reset
|
| 14 |
+
|
| 15 |
+
# Wan2.1
|
| 16 |
+
WanFunPipeline.__call__ = sparse_reset(WanFunPipeline.__call__)
|
| 17 |
+
WanPipeline.__call__ = sparse_reset(WanPipeline.__call__)
|
| 18 |
+
|
| 19 |
+
# Wan2.2
|
| 20 |
+
Wan2_2FunPipeline.__call__ = sparse_reset(Wan2_2FunPipeline.__call__)
|
| 21 |
+
Wan2_2Pipeline.__call__ = sparse_reset(Wan2_2Pipeline.__call__)
|
videox_fun/pipeline/pipeline_wan.py
ADDED
|
@@ -0,0 +1,799 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 9 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 10 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 11 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 12 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 13 |
+
from diffusers.video_processor import VideoProcessor
|
| 14 |
+
|
| 15 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer,
|
| 16 |
+
WanT5EncoderModel, WanTransformer3DModel)
|
| 17 |
+
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
| 18 |
+
get_sampling_sigmas)
|
| 19 |
+
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
EXAMPLE_DOC_STRING = """
|
| 25 |
+
Examples:
|
| 26 |
+
```python
|
| 27 |
+
pass
|
| 28 |
+
```
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 33 |
+
def retrieve_timesteps(
|
| 34 |
+
scheduler,
|
| 35 |
+
num_inference_steps: Optional[int] = None,
|
| 36 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 37 |
+
timesteps: Optional[List[int]] = None,
|
| 38 |
+
sigmas: Optional[List[float]] = None,
|
| 39 |
+
**kwargs,
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 43 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
scheduler (`SchedulerMixin`):
|
| 47 |
+
The scheduler to get timesteps from.
|
| 48 |
+
num_inference_steps (`int`):
|
| 49 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 50 |
+
must be `None`.
|
| 51 |
+
device (`str` or `torch.device`, *optional*):
|
| 52 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 53 |
+
timesteps (`List[int]`, *optional*):
|
| 54 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 55 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 56 |
+
sigmas (`List[float]`, *optional*):
|
| 57 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 58 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 62 |
+
second element is the number of inference steps.
|
| 63 |
+
"""
|
| 64 |
+
if timesteps is not None and sigmas is not None:
|
| 65 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 66 |
+
if timesteps is not None:
|
| 67 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 68 |
+
if not accepts_timesteps:
|
| 69 |
+
raise ValueError(
|
| 70 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 71 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 72 |
+
)
|
| 73 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 74 |
+
timesteps = scheduler.timesteps
|
| 75 |
+
num_inference_steps = len(timesteps)
|
| 76 |
+
elif sigmas is not None:
|
| 77 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 78 |
+
if not accept_sigmas:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 81 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 82 |
+
)
|
| 83 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 84 |
+
timesteps = scheduler.timesteps
|
| 85 |
+
num_inference_steps = len(timesteps)
|
| 86 |
+
else:
|
| 87 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 88 |
+
timesteps = scheduler.timesteps
|
| 89 |
+
return timesteps, num_inference_steps
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclass
|
| 93 |
+
class WanPipelineOutput(BaseOutput):
|
| 94 |
+
r"""
|
| 95 |
+
Output class for Wan pipelines.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
videos: full decoded video tensor
|
| 99 |
+
ground_videos: decoded grounding segment (optional)
|
| 100 |
+
edit_videos: decoded edited segment (optional)
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
videos: torch.Tensor
|
| 104 |
+
ground_videos: Optional[torch.Tensor] = None
|
| 105 |
+
edit_videos: Optional[torch.Tensor] = None
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class WanPipeline(DiffusionPipeline):
|
| 109 |
+
r"""
|
| 110 |
+
Pipeline for text-to-video generation using Wan.
|
| 111 |
+
|
| 112 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 113 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
_optional_components = []
|
| 117 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 118 |
+
|
| 119 |
+
_callback_tensor_inputs = [
|
| 120 |
+
"latents",
|
| 121 |
+
"prompt_embeds",
|
| 122 |
+
"negative_prompt_embeds",
|
| 123 |
+
]
|
| 124 |
+
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
tokenizer: AutoTokenizer,
|
| 128 |
+
text_encoder: WanT5EncoderModel,
|
| 129 |
+
vae: AutoencoderKLWan,
|
| 130 |
+
transformer: WanTransformer3DModel,
|
| 131 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 132 |
+
):
|
| 133 |
+
super().__init__()
|
| 134 |
+
|
| 135 |
+
self.register_modules(
|
| 136 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
| 137 |
+
)
|
| 138 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 139 |
+
|
| 140 |
+
def _get_t5_prompt_embeds(
|
| 141 |
+
self,
|
| 142 |
+
prompt: Union[str, List[str]] = None,
|
| 143 |
+
num_videos_per_prompt: int = 1,
|
| 144 |
+
max_sequence_length: int = 512,
|
| 145 |
+
device: Optional[torch.device] = None,
|
| 146 |
+
dtype: Optional[torch.dtype] = None,
|
| 147 |
+
):
|
| 148 |
+
device = device or self._execution_device
|
| 149 |
+
dtype = dtype or self.text_encoder.dtype
|
| 150 |
+
|
| 151 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 152 |
+
batch_size = len(prompt)
|
| 153 |
+
|
| 154 |
+
text_inputs = self.tokenizer(
|
| 155 |
+
prompt,
|
| 156 |
+
padding="max_length",
|
| 157 |
+
max_length=max_sequence_length,
|
| 158 |
+
truncation=True,
|
| 159 |
+
add_special_tokens=True,
|
| 160 |
+
return_tensors="pt",
|
| 161 |
+
)
|
| 162 |
+
text_input_ids = text_inputs.input_ids
|
| 163 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 164 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 165 |
+
|
| 166 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 167 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 168 |
+
logger.warning(
|
| 169 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 170 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
|
| 174 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
|
| 175 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 176 |
+
|
| 177 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 178 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 179 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 180 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 181 |
+
|
| 182 |
+
return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 183 |
+
|
| 184 |
+
def encode_prompt(
|
| 185 |
+
self,
|
| 186 |
+
prompt: Union[str, List[str]],
|
| 187 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 188 |
+
do_classifier_free_guidance: bool = True,
|
| 189 |
+
num_videos_per_prompt: int = 1,
|
| 190 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 191 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 192 |
+
max_sequence_length: int = 512,
|
| 193 |
+
device: Optional[torch.device] = None,
|
| 194 |
+
dtype: Optional[torch.dtype] = None,
|
| 195 |
+
):
|
| 196 |
+
r"""
|
| 197 |
+
Encodes the prompt into text encoder hidden states.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 201 |
+
prompt to be encoded
|
| 202 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 203 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 204 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 205 |
+
less than `1`).
|
| 206 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 207 |
+
Whether to use classifier free guidance or not.
|
| 208 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 209 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 210 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 211 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 212 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 213 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 214 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 215 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 216 |
+
argument.
|
| 217 |
+
device: (`torch.device`, *optional*):
|
| 218 |
+
torch device
|
| 219 |
+
dtype: (`torch.dtype`, *optional*):
|
| 220 |
+
torch dtype
|
| 221 |
+
"""
|
| 222 |
+
device = device or self._execution_device
|
| 223 |
+
|
| 224 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 225 |
+
if prompt is not None:
|
| 226 |
+
batch_size = len(prompt)
|
| 227 |
+
else:
|
| 228 |
+
batch_size = prompt_embeds.shape[0]
|
| 229 |
+
|
| 230 |
+
if prompt_embeds is None:
|
| 231 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 232 |
+
prompt=prompt,
|
| 233 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 234 |
+
max_sequence_length=max_sequence_length,
|
| 235 |
+
device=device,
|
| 236 |
+
dtype=dtype,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 240 |
+
negative_prompt = negative_prompt or ""
|
| 241 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 242 |
+
|
| 243 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 244 |
+
raise TypeError(
|
| 245 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 246 |
+
f" {type(prompt)}."
|
| 247 |
+
)
|
| 248 |
+
elif batch_size != len(negative_prompt):
|
| 249 |
+
raise ValueError(
|
| 250 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 251 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 252 |
+
" the batch size of `prompt`."
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 256 |
+
prompt=negative_prompt,
|
| 257 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 258 |
+
max_sequence_length=max_sequence_length,
|
| 259 |
+
device=device,
|
| 260 |
+
dtype=dtype,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
return prompt_embeds, negative_prompt_embeds
|
| 264 |
+
|
| 265 |
+
def prepare_latents(
|
| 266 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 267 |
+
):
|
| 268 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 269 |
+
raise ValueError(
|
| 270 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 271 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
shape = (
|
| 275 |
+
batch_size,
|
| 276 |
+
num_channels_latents,
|
| 277 |
+
(num_frames - 1) // self.vae.temporal_compression_ratio + 1,
|
| 278 |
+
height // self.vae.spatial_compression_ratio,
|
| 279 |
+
width // self.vae.spatial_compression_ratio,
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
if latents is None:
|
| 283 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 284 |
+
else:
|
| 285 |
+
latents = latents.to(device)
|
| 286 |
+
|
| 287 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 288 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 289 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 290 |
+
return latents
|
| 291 |
+
|
| 292 |
+
def prepare_video_latents(
|
| 293 |
+
self,
|
| 294 |
+
video: torch.Tensor,
|
| 295 |
+
batch_size: int = 1,
|
| 296 |
+
num_channels_latents: int = 16,
|
| 297 |
+
height: int = 480,
|
| 298 |
+
width: int = 832,
|
| 299 |
+
dtype: torch.dtype = torch.float32,
|
| 300 |
+
device: torch.device = None,
|
| 301 |
+
generator: torch.Generator = None,
|
| 302 |
+
condition_count: int = None,
|
| 303 |
+
latents: torch.Tensor = None,
|
| 304 |
+
timestep: torch.Tensor = None,
|
| 305 |
+
):
|
| 306 |
+
|
| 307 |
+
video = video.to(device=device, dtype=dtype)
|
| 308 |
+
num_latent_frames = (video.shape[2] - 1) // self.vae.temporal_compression_ratio + 1
|
| 309 |
+
|
| 310 |
+
shape = (
|
| 311 |
+
batch_size,
|
| 312 |
+
num_channels_latents,
|
| 313 |
+
num_latent_frames,
|
| 314 |
+
height // self.vae.spatial_compression_ratio,
|
| 315 |
+
width // self.vae.spatial_compression_ratio,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
if latents is not None:
|
| 319 |
+
return latents.to(device=device, dtype=dtype)
|
| 320 |
+
|
| 321 |
+
video_latents = []
|
| 322 |
+
print('video',video.shape)
|
| 323 |
+
for i in range(video.shape[0]):
|
| 324 |
+
# 假设 self.vae.encode 返回的是 (LatentDistribution, …)
|
| 325 |
+
latent_dist = self.vae.encode(video[i : i + 1])[0]
|
| 326 |
+
latent = latent_dist.mode() # 直接取 mode,不做 mean/std
|
| 327 |
+
video_latents.append(latent)
|
| 328 |
+
init_latents = torch.cat(video_latents, dim=0) # (B, C, T, H', W')
|
| 329 |
+
|
| 330 |
+
# 再往前 condition_count 帧注入随机 noise
|
| 331 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 332 |
+
init_latents[:, :, condition_count:, :, :] = noise[:, :, condition_count:, :, :]
|
| 333 |
+
|
| 334 |
+
# 现在可以正确调用 add_noise
|
| 335 |
+
# init_latents[:, :, condition_count:, :, :] = self.scheduler.add_noise(
|
| 336 |
+
# init_latents[:, :, condition_count:, :, :],
|
| 337 |
+
# noise[:, :, condition_count:, :, :],
|
| 338 |
+
# timestep
|
| 339 |
+
# )
|
| 340 |
+
# print('init_latents shape',init_latents.shape)
|
| 341 |
+
return init_latents
|
| 342 |
+
|
| 343 |
+
def prepare_video_latents_new(
|
| 344 |
+
self,
|
| 345 |
+
video: torch.Tensor,
|
| 346 |
+
batch_size: int = 1,
|
| 347 |
+
num_channels_latents: int = 16,
|
| 348 |
+
height: int = 480,
|
| 349 |
+
width: int = 832,
|
| 350 |
+
dtype: torch.dtype = torch.float32,
|
| 351 |
+
device: torch.device = None,
|
| 352 |
+
generator: torch.Generator = None,
|
| 353 |
+
condition_count: int = None,
|
| 354 |
+
latents: torch.Tensor = None,
|
| 355 |
+
timestep: torch.Tensor = None,
|
| 356 |
+
):
|
| 357 |
+
|
| 358 |
+
video = video.to(device=device, dtype=dtype)
|
| 359 |
+
|
| 360 |
+
if latents is not None:
|
| 361 |
+
return latents.to(device=device, dtype=dtype)
|
| 362 |
+
|
| 363 |
+
video_latents = []
|
| 364 |
+
print('video',video.shape)
|
| 365 |
+
for i in range(video.shape[0]):
|
| 366 |
+
# 假设 self.vae.encode 返回的是 (LatentDistribution, …)
|
| 367 |
+
latent_dist = self.vae.encode(video[i : i + 1])[0]
|
| 368 |
+
latent = latent_dist.mode() # 直接取 mode,不做 mean/std
|
| 369 |
+
video_latents.append(latent)
|
| 370 |
+
org_latents = torch.cat(video_latents, dim=0) # (B, C, T, H', W')
|
| 371 |
+
print('org_latents',org_latents.shape)
|
| 372 |
+
|
| 373 |
+
# 再往后 condition_count 帧注入随机 noise,shape和org_latents一样
|
| 374 |
+
noise = randn_tensor(org_latents.shape, generator=generator, device=device, dtype=dtype)
|
| 375 |
+
print('noise',noise.shape)
|
| 376 |
+
init_latents = torch.cat([org_latents, noise], dim=2)
|
| 377 |
+
print('init_latents',init_latents.shape)
|
| 378 |
+
return init_latents
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def prepare_cot_video_latents(
|
| 382 |
+
self,
|
| 383 |
+
video: torch.Tensor,
|
| 384 |
+
reasoning_latent_count: int = 1,
|
| 385 |
+
batch_size: int = 1,
|
| 386 |
+
num_channels_latents: int = 16,
|
| 387 |
+
height: int = 480,
|
| 388 |
+
width: int = 832,
|
| 389 |
+
dtype: torch.dtype = torch.float32,
|
| 390 |
+
device: torch.device = None,
|
| 391 |
+
generator: torch.Generator = None,
|
| 392 |
+
condition_count: int = None,
|
| 393 |
+
latents: torch.Tensor = None,
|
| 394 |
+
timestep: torch.Tensor = None,
|
| 395 |
+
):
|
| 396 |
+
|
| 397 |
+
video = video.to(device=device, dtype=dtype)
|
| 398 |
+
|
| 399 |
+
if latents is not None:
|
| 400 |
+
return latents.to(device=device, dtype=dtype)
|
| 401 |
+
|
| 402 |
+
video_latents = []
|
| 403 |
+
#print('video',video.shape)
|
| 404 |
+
for i in range(video.shape[0]):
|
| 405 |
+
# 假设 self.vae.encode 返回的是 (LatentDistribution, …)
|
| 406 |
+
latent_dist = self.vae.encode(video[i : i + 1])[0]
|
| 407 |
+
latent = latent_dist.mode() # 直接取 mode,不做 mean/std
|
| 408 |
+
video_latents.append(latent)
|
| 409 |
+
org_latents = torch.cat(video_latents, dim=0) # (B, C, T, H', W')
|
| 410 |
+
print('org_latents',org_latents.shape)
|
| 411 |
+
batch_size, num_channels_latents, num_frames_latent, height_latent, width_latent = org_latents.shape
|
| 412 |
+
tgt_frames = num_frames_latent + reasoning_latent_count
|
| 413 |
+
noise_latents_shape = (batch_size, num_channels_latents, tgt_frames, height_latent, width_latent)
|
| 414 |
+
# 再往后 condition_count 帧注入随机 noise,shape和org_latents一样
|
| 415 |
+
noise = randn_tensor(noise_latents_shape, generator=generator, device=device, dtype=dtype)
|
| 416 |
+
print('noise',noise.shape)
|
| 417 |
+
init_latents = torch.cat([org_latents, noise], dim=2)
|
| 418 |
+
print('init_latents',init_latents.shape)
|
| 419 |
+
return init_latents
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 424 |
+
frames = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 425 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 426 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
| 427 |
+
frames = frames.cpu().float().numpy()
|
| 428 |
+
return frames
|
| 429 |
+
|
| 430 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 431 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 432 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 433 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 434 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 435 |
+
# and should be between [0, 1]
|
| 436 |
+
|
| 437 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 438 |
+
extra_step_kwargs = {}
|
| 439 |
+
if accepts_eta:
|
| 440 |
+
extra_step_kwargs["eta"] = eta
|
| 441 |
+
|
| 442 |
+
# check if the scheduler accepts generator
|
| 443 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 444 |
+
if accepts_generator:
|
| 445 |
+
extra_step_kwargs["generator"] = generator
|
| 446 |
+
return extra_step_kwargs
|
| 447 |
+
|
| 448 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 449 |
+
def check_inputs(
|
| 450 |
+
self,
|
| 451 |
+
prompt,
|
| 452 |
+
height,
|
| 453 |
+
width,
|
| 454 |
+
negative_prompt,
|
| 455 |
+
callback_on_step_end_tensor_inputs,
|
| 456 |
+
prompt_embeds=None,
|
| 457 |
+
negative_prompt_embeds=None,
|
| 458 |
+
):
|
| 459 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 460 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 461 |
+
|
| 462 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 463 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 464 |
+
):
|
| 465 |
+
raise ValueError(
|
| 466 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 467 |
+
)
|
| 468 |
+
if prompt is not None and prompt_embeds is not None:
|
| 469 |
+
raise ValueError(
|
| 470 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 471 |
+
" only forward one of the two."
|
| 472 |
+
)
|
| 473 |
+
elif prompt is None and prompt_embeds is None:
|
| 474 |
+
raise ValueError(
|
| 475 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 476 |
+
)
|
| 477 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 478 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 479 |
+
|
| 480 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 481 |
+
raise ValueError(
|
| 482 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 483 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 487 |
+
raise ValueError(
|
| 488 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 489 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 493 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 494 |
+
raise ValueError(
|
| 495 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 496 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 497 |
+
f" {negative_prompt_embeds.shape}."
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
@property
|
| 501 |
+
def guidance_scale(self):
|
| 502 |
+
return self._guidance_scale
|
| 503 |
+
|
| 504 |
+
@property
|
| 505 |
+
def num_timesteps(self):
|
| 506 |
+
return self._num_timesteps
|
| 507 |
+
|
| 508 |
+
@property
|
| 509 |
+
def attention_kwargs(self):
|
| 510 |
+
return self._attention_kwargs
|
| 511 |
+
|
| 512 |
+
@property
|
| 513 |
+
def interrupt(self):
|
| 514 |
+
return self._interrupt
|
| 515 |
+
|
| 516 |
+
@torch.no_grad()
|
| 517 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 518 |
+
def __call__(
|
| 519 |
+
self,
|
| 520 |
+
video: Union[torch.FloatTensor] = None,
|
| 521 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 522 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 523 |
+
height: int = 480,
|
| 524 |
+
width: int = 720,
|
| 525 |
+
num_frames: int = 49,
|
| 526 |
+
source_frames: int = 33,
|
| 527 |
+
reasoning_frames: int = 4,
|
| 528 |
+
num_inference_steps: int = 50,
|
| 529 |
+
timesteps: Optional[List[int]] = None,
|
| 530 |
+
guidance_scale: float = 6,
|
| 531 |
+
num_videos_per_prompt: int = 1,
|
| 532 |
+
eta: float = 0.0,
|
| 533 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 534 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 535 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 536 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 537 |
+
output_type: str = "numpy",
|
| 538 |
+
return_dict: bool = False,
|
| 539 |
+
callback_on_step_end: Optional[
|
| 540 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 541 |
+
] = None,
|
| 542 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 543 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 544 |
+
max_sequence_length: int = 512,
|
| 545 |
+
comfyui_progressbar: bool = False,
|
| 546 |
+
shift: int = 5,
|
| 547 |
+
repeat_rope: bool = True,
|
| 548 |
+
cot: bool = False,
|
| 549 |
+
) -> Union[WanPipelineOutput, Tuple]:
|
| 550 |
+
"""
|
| 551 |
+
Function invoked when calling the pipeline for generation.
|
| 552 |
+
Args:
|
| 553 |
+
|
| 554 |
+
Examples:
|
| 555 |
+
|
| 556 |
+
Returns:
|
| 557 |
+
|
| 558 |
+
"""
|
| 559 |
+
|
| 560 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 561 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 562 |
+
num_videos_per_prompt = 1
|
| 563 |
+
|
| 564 |
+
# 1. Check inputs. Raise error if not correct
|
| 565 |
+
self.check_inputs(
|
| 566 |
+
prompt,
|
| 567 |
+
height,
|
| 568 |
+
width,
|
| 569 |
+
negative_prompt,
|
| 570 |
+
callback_on_step_end_tensor_inputs,
|
| 571 |
+
prompt_embeds,
|
| 572 |
+
negative_prompt_embeds,
|
| 573 |
+
)
|
| 574 |
+
self._guidance_scale = guidance_scale
|
| 575 |
+
self._attention_kwargs = attention_kwargs
|
| 576 |
+
self._interrupt = False
|
| 577 |
+
|
| 578 |
+
# 2. Default call parameters
|
| 579 |
+
if prompt is not None and isinstance(prompt, str):
|
| 580 |
+
batch_size = 1
|
| 581 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 582 |
+
batch_size = len(prompt)
|
| 583 |
+
else:
|
| 584 |
+
batch_size = prompt_embeds.shape[0]
|
| 585 |
+
|
| 586 |
+
device = self._execution_device
|
| 587 |
+
weight_dtype = self.text_encoder.dtype
|
| 588 |
+
|
| 589 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 590 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 591 |
+
# corresponds to doing no classifier free guidance.
|
| 592 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 593 |
+
|
| 594 |
+
# 3. Encode input prompt
|
| 595 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 596 |
+
prompt,
|
| 597 |
+
negative_prompt,
|
| 598 |
+
do_classifier_free_guidance,
|
| 599 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 600 |
+
prompt_embeds=prompt_embeds,
|
| 601 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 602 |
+
max_sequence_length=max_sequence_length,
|
| 603 |
+
device=device,
|
| 604 |
+
)
|
| 605 |
+
if do_classifier_free_guidance:
|
| 606 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 607 |
+
else:
|
| 608 |
+
in_prompt_embeds = prompt_embeds
|
| 609 |
+
|
| 610 |
+
# 4. Prepare timesteps
|
| 611 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 612 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
| 613 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 614 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 615 |
+
timesteps = self.scheduler.timesteps
|
| 616 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 617 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 618 |
+
timesteps, _ = retrieve_timesteps(
|
| 619 |
+
self.scheduler,
|
| 620 |
+
device=device,
|
| 621 |
+
sigmas=sampling_sigmas)
|
| 622 |
+
else:
|
| 623 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 624 |
+
self._num_timesteps = len(timesteps)
|
| 625 |
+
if comfyui_progressbar:
|
| 626 |
+
from comfy.utils import ProgressBar
|
| 627 |
+
pbar = ProgressBar(num_inference_steps + 1)
|
| 628 |
+
|
| 629 |
+
# compute latent source length consistent with training: (F-1)//ratio + 1, or 1 when F==1
|
| 630 |
+
compression_ratio = getattr(self.vae, "temporal_compression_ratio", 4)
|
| 631 |
+
condition_count = 1 if source_frames == 1 else (source_frames - 1) // compression_ratio + 1
|
| 632 |
+
|
| 633 |
+
# 5. Prepare latents (unified across org/repeat/cot)
|
| 634 |
+
latent_channels = self.transformer.config.in_channels
|
| 635 |
+
if cot:
|
| 636 |
+
# latent grounding segment length from pixel-space reasoning_frames (used only when cot=True)
|
| 637 |
+
ground_latent_count = 1 if reasoning_frames <= 1 else (reasoning_frames - 1) // compression_ratio + 1
|
| 638 |
+
print('ground_latent_count',ground_latent_count)
|
| 639 |
+
latents = self.prepare_cot_video_latents(
|
| 640 |
+
video,
|
| 641 |
+
ground_latent_count,
|
| 642 |
+
batch_size,
|
| 643 |
+
latent_channels,
|
| 644 |
+
height,
|
| 645 |
+
width,
|
| 646 |
+
weight_dtype,
|
| 647 |
+
device,
|
| 648 |
+
generator,
|
| 649 |
+
condition_count,
|
| 650 |
+
latents,
|
| 651 |
+
)
|
| 652 |
+
elif repeat_rope:
|
| 653 |
+
latents = self.prepare_video_latents_new(
|
| 654 |
+
video,
|
| 655 |
+
batch_size,
|
| 656 |
+
latent_channels,
|
| 657 |
+
height,
|
| 658 |
+
width,
|
| 659 |
+
weight_dtype,
|
| 660 |
+
device,
|
| 661 |
+
generator,
|
| 662 |
+
condition_count,
|
| 663 |
+
latents,
|
| 664 |
+
)
|
| 665 |
+
else:
|
| 666 |
+
latents = self.prepare_video_latents_new(
|
| 667 |
+
video,
|
| 668 |
+
batch_size,
|
| 669 |
+
latent_channels,
|
| 670 |
+
height,
|
| 671 |
+
width,
|
| 672 |
+
weight_dtype,
|
| 673 |
+
device,
|
| 674 |
+
generator,
|
| 675 |
+
condition_count,
|
| 676 |
+
latents,
|
| 677 |
+
)
|
| 678 |
+
if comfyui_progressbar:
|
| 679 |
+
pbar.update(1)
|
| 680 |
+
|
| 681 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 682 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 683 |
+
|
| 684 |
+
# Get actual latent dimensions (consistent with training)
|
| 685 |
+
#print('latents',latents.shape)
|
| 686 |
+
bsz, channel, actual_num_frames, actual_height, actual_width = latents.size()
|
| 687 |
+
target_shape = (self.vae.latent_channels, actual_num_frames, actual_height, actual_width)
|
| 688 |
+
#print('target_shape',target_shape)
|
| 689 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
|
| 690 |
+
# 7. Denoising loop
|
| 691 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 692 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 693 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 694 |
+
for i, t in enumerate(timesteps):
|
| 695 |
+
self.transformer.current_steps = i
|
| 696 |
+
|
| 697 |
+
if self.interrupt:
|
| 698 |
+
continue
|
| 699 |
+
|
| 700 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 701 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 702 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 703 |
+
|
| 704 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 705 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 706 |
+
|
| 707 |
+
# predict noise model_output
|
| 708 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 709 |
+
# frame_split_indices enables repeat temporal RoPE for paired (src+tgt) inputs
|
| 710 |
+
frame_split_indices = None
|
| 711 |
+
ground_frame_indices = None
|
| 712 |
+
if repeat_rope and video is not None:
|
| 713 |
+
frame_split_indices = [condition_count] * latent_model_input.shape[0]
|
| 714 |
+
if cot:
|
| 715 |
+
# grounding frames should use temporal RoPE position 0
|
| 716 |
+
ground_frame_indices = [
|
| 717 |
+
(condition_count, condition_count + ground_latent_count)
|
| 718 |
+
] * latent_model_input.shape[0]
|
| 719 |
+
# print('ground_frame_indices',ground_frame_indices)
|
| 720 |
+
# print('frame_split_indices',frame_split_indices)
|
| 721 |
+
noise_pred = self.transformer(
|
| 722 |
+
x=latent_model_input,
|
| 723 |
+
context=in_prompt_embeds,
|
| 724 |
+
t=timestep,
|
| 725 |
+
seq_len=seq_len,
|
| 726 |
+
frame_split_indices=frame_split_indices,
|
| 727 |
+
ground_frame_indices=ground_frame_indices,
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
# perform guidance
|
| 731 |
+
if do_classifier_free_guidance:
|
| 732 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 733 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 734 |
+
|
| 735 |
+
######source video no noise pred################
|
| 736 |
+
noise_pred[:, :, :condition_count] = 0
|
| 737 |
+
######source video no noise pred################
|
| 738 |
+
|
| 739 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 740 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 741 |
+
|
| 742 |
+
if callback_on_step_end is not None:
|
| 743 |
+
callback_kwargs = {}
|
| 744 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 745 |
+
callback_kwargs[k] = locals()[k]
|
| 746 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 747 |
+
|
| 748 |
+
latents = callback_outputs.pop("latents", latents)
|
| 749 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 750 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 751 |
+
|
| 752 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 753 |
+
progress_bar.update()
|
| 754 |
+
if comfyui_progressbar:
|
| 755 |
+
pbar.update(1)
|
| 756 |
+
|
| 757 |
+
# Optionally decode outputs. For cot=True, segment into src/ground/edit; otherwise decode whole latents
|
| 758 |
+
ground_video = None
|
| 759 |
+
edit_video = None
|
| 760 |
+
if cot:
|
| 761 |
+
if output_type == "numpy":
|
| 762 |
+
ground_start = condition_count
|
| 763 |
+
ground_end = condition_count + ground_latent_count
|
| 764 |
+
src_lat = latents[:, :, :ground_start] if ground_start > 0 else None
|
| 765 |
+
ground_lat = latents[:, :, ground_start:ground_end] if ground_end > ground_start and ground_start < latents.shape[2] else None
|
| 766 |
+
edit_lat = latents[:, :, ground_end:] if ground_end < latents.shape[2] else None
|
| 767 |
+
|
| 768 |
+
parts = []
|
| 769 |
+
## only ground and edit
|
| 770 |
+
if ground_lat is not None and ground_lat.shape[2] > 0:
|
| 771 |
+
ground_video = self.decode_latents(ground_lat)
|
| 772 |
+
parts.append(ground_video)
|
| 773 |
+
if edit_lat is not None and edit_lat.shape[2] > 0:
|
| 774 |
+
edit_video = self.decode_latents(edit_lat)
|
| 775 |
+
parts.append(edit_video)
|
| 776 |
+
print('ground_video',ground_video.shape, 'edit_video',edit_video.shape)
|
| 777 |
+
video = np.concatenate(parts, axis=2)
|
| 778 |
+
else:
|
| 779 |
+
# org/repeat: split by condition_count -> src + edit, then temporal concat
|
| 780 |
+
if output_type == "numpy":
|
| 781 |
+
src_lat = latents[:, :, :condition_count] if condition_count > 0 else None
|
| 782 |
+
edit_lat = latents[:, :, condition_count:] if condition_count < latents.shape[2] else None
|
| 783 |
+
## only decode edit video
|
| 784 |
+
if edit_lat is not None and edit_lat.shape[2] > 0:
|
| 785 |
+
edit_video = self.decode_latents(edit_lat)
|
| 786 |
+
video = edit_video
|
| 787 |
+
|
| 788 |
+
# Offload all models
|
| 789 |
+
self.maybe_free_model_hooks()
|
| 790 |
+
|
| 791 |
+
if not return_dict:
|
| 792 |
+
if isinstance(video, np.ndarray):
|
| 793 |
+
video = torch.from_numpy(video)
|
| 794 |
+
if ground_video is not None and isinstance(ground_video, np.ndarray):
|
| 795 |
+
ground_video = torch.from_numpy(ground_video)
|
| 796 |
+
if edit_video is not None and isinstance(edit_video, np.ndarray):
|
| 797 |
+
edit_video = torch.from_numpy(edit_video)
|
| 798 |
+
|
| 799 |
+
return WanPipelineOutput(videos=video, ground_videos=ground_video, edit_videos=edit_video)
|
videox_fun/pipeline/pipeline_wan2_2.py
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 9 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 10 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 11 |
+
from diffusers.utils import BaseOutput, logging, replace_example_docstring
|
| 12 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 13 |
+
from diffusers.video_processor import VideoProcessor
|
| 14 |
+
|
| 15 |
+
from ..models import (AutoencoderKLWan, AutoTokenizer,
|
| 16 |
+
WanT5EncoderModel, Wan2_2Transformer3DModel)
|
| 17 |
+
from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
| 18 |
+
get_sampling_sigmas)
|
| 19 |
+
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
EXAMPLE_DOC_STRING = """
|
| 25 |
+
Examples:
|
| 26 |
+
```python
|
| 27 |
+
pass
|
| 28 |
+
```
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 33 |
+
def retrieve_timesteps(
|
| 34 |
+
scheduler,
|
| 35 |
+
num_inference_steps: Optional[int] = None,
|
| 36 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 37 |
+
timesteps: Optional[List[int]] = None,
|
| 38 |
+
sigmas: Optional[List[float]] = None,
|
| 39 |
+
**kwargs,
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 43 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
scheduler (`SchedulerMixin`):
|
| 47 |
+
The scheduler to get timesteps from.
|
| 48 |
+
num_inference_steps (`int`):
|
| 49 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 50 |
+
must be `None`.
|
| 51 |
+
device (`str` or `torch.device`, *optional*):
|
| 52 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 53 |
+
timesteps (`List[int]`, *optional*):
|
| 54 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 55 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 56 |
+
sigmas (`List[float]`, *optional*):
|
| 57 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 58 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 62 |
+
second element is the number of inference steps.
|
| 63 |
+
"""
|
| 64 |
+
if timesteps is not None and sigmas is not None:
|
| 65 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 66 |
+
if timesteps is not None:
|
| 67 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 68 |
+
if not accepts_timesteps:
|
| 69 |
+
raise ValueError(
|
| 70 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 71 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 72 |
+
)
|
| 73 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 74 |
+
timesteps = scheduler.timesteps
|
| 75 |
+
num_inference_steps = len(timesteps)
|
| 76 |
+
elif sigmas is not None:
|
| 77 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 78 |
+
if not accept_sigmas:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 81 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 82 |
+
)
|
| 83 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 84 |
+
timesteps = scheduler.timesteps
|
| 85 |
+
num_inference_steps = len(timesteps)
|
| 86 |
+
else:
|
| 87 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 88 |
+
timesteps = scheduler.timesteps
|
| 89 |
+
return timesteps, num_inference_steps
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclass
|
| 93 |
+
class WanPipelineOutput(BaseOutput):
|
| 94 |
+
r"""
|
| 95 |
+
Output class for CogVideo pipelines.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
| 99 |
+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
| 100 |
+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
| 101 |
+
`(batch_size, num_frames, channels, height, width)`.
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
videos: torch.Tensor
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class Wan2_2Pipeline(DiffusionPipeline):
|
| 108 |
+
r"""
|
| 109 |
+
Pipeline for text-to-video generation using Wan.
|
| 110 |
+
|
| 111 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 112 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
_optional_components = ["transformer_2"]
|
| 116 |
+
model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
|
| 117 |
+
|
| 118 |
+
_callback_tensor_inputs = [
|
| 119 |
+
"latents",
|
| 120 |
+
"prompt_embeds",
|
| 121 |
+
"negative_prompt_embeds",
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
tokenizer: AutoTokenizer,
|
| 127 |
+
text_encoder: WanT5EncoderModel,
|
| 128 |
+
vae: AutoencoderKLWan,
|
| 129 |
+
transformer: Wan2_2Transformer3DModel,
|
| 130 |
+
transformer_2: Wan2_2Transformer3DModel = None,
|
| 131 |
+
scheduler: FlowMatchEulerDiscreteScheduler = None,
|
| 132 |
+
):
|
| 133 |
+
super().__init__()
|
| 134 |
+
|
| 135 |
+
self.register_modules(
|
| 136 |
+
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
|
| 137 |
+
transformer_2=transformer_2, scheduler=scheduler
|
| 138 |
+
)
|
| 139 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
|
| 140 |
+
|
| 141 |
+
def _get_t5_prompt_embeds(
|
| 142 |
+
self,
|
| 143 |
+
prompt: Union[str, List[str]] = None,
|
| 144 |
+
num_videos_per_prompt: int = 1,
|
| 145 |
+
max_sequence_length: int = 512,
|
| 146 |
+
device: Optional[torch.device] = None,
|
| 147 |
+
dtype: Optional[torch.dtype] = None,
|
| 148 |
+
):
|
| 149 |
+
device = device or self._execution_device
|
| 150 |
+
dtype = dtype or self.text_encoder.dtype
|
| 151 |
+
|
| 152 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 153 |
+
batch_size = len(prompt)
|
| 154 |
+
|
| 155 |
+
text_inputs = self.tokenizer(
|
| 156 |
+
prompt,
|
| 157 |
+
padding="max_length",
|
| 158 |
+
max_length=max_sequence_length,
|
| 159 |
+
truncation=True,
|
| 160 |
+
add_special_tokens=True,
|
| 161 |
+
return_tensors="pt",
|
| 162 |
+
)
|
| 163 |
+
text_input_ids = text_inputs.input_ids
|
| 164 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 165 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 166 |
+
|
| 167 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 168 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 169 |
+
logger.warning(
|
| 170 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 171 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
|
| 175 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
|
| 176 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 177 |
+
|
| 178 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 179 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 180 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 181 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 182 |
+
|
| 183 |
+
return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 184 |
+
|
| 185 |
+
def encode_prompt(
|
| 186 |
+
self,
|
| 187 |
+
prompt: Union[str, List[str]],
|
| 188 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 189 |
+
do_classifier_free_guidance: bool = True,
|
| 190 |
+
num_videos_per_prompt: int = 1,
|
| 191 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 192 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 193 |
+
max_sequence_length: int = 512,
|
| 194 |
+
device: Optional[torch.device] = None,
|
| 195 |
+
dtype: Optional[torch.dtype] = None,
|
| 196 |
+
):
|
| 197 |
+
r"""
|
| 198 |
+
Encodes the prompt into text encoder hidden states.
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 202 |
+
prompt to be encoded
|
| 203 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 204 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 205 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 206 |
+
less than `1`).
|
| 207 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 208 |
+
Whether to use classifier free guidance or not.
|
| 209 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 210 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 211 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 212 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 213 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 214 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 215 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 216 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 217 |
+
argument.
|
| 218 |
+
device: (`torch.device`, *optional*):
|
| 219 |
+
torch device
|
| 220 |
+
dtype: (`torch.dtype`, *optional*):
|
| 221 |
+
torch dtype
|
| 222 |
+
"""
|
| 223 |
+
device = device or self._execution_device
|
| 224 |
+
|
| 225 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 226 |
+
if prompt is not None:
|
| 227 |
+
batch_size = len(prompt)
|
| 228 |
+
else:
|
| 229 |
+
batch_size = prompt_embeds.shape[0]
|
| 230 |
+
|
| 231 |
+
if prompt_embeds is None:
|
| 232 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 233 |
+
prompt=prompt,
|
| 234 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 235 |
+
max_sequence_length=max_sequence_length,
|
| 236 |
+
device=device,
|
| 237 |
+
dtype=dtype,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 241 |
+
negative_prompt = negative_prompt or ""
|
| 242 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 243 |
+
|
| 244 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 245 |
+
raise TypeError(
|
| 246 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 247 |
+
f" {type(prompt)}."
|
| 248 |
+
)
|
| 249 |
+
elif batch_size != len(negative_prompt):
|
| 250 |
+
raise ValueError(
|
| 251 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 252 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 253 |
+
" the batch size of `prompt`."
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 257 |
+
prompt=negative_prompt,
|
| 258 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 259 |
+
max_sequence_length=max_sequence_length,
|
| 260 |
+
device=device,
|
| 261 |
+
dtype=dtype,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
return prompt_embeds, negative_prompt_embeds
|
| 265 |
+
|
| 266 |
+
def prepare_latents(
|
| 267 |
+
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
| 268 |
+
):
|
| 269 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 270 |
+
raise ValueError(
|
| 271 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 272 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
shape = (
|
| 276 |
+
batch_size,
|
| 277 |
+
num_channels_latents,
|
| 278 |
+
(num_frames - 1) // self.vae.temporal_compression_ratio + 1,
|
| 279 |
+
height // self.vae.spatial_compression_ratio,
|
| 280 |
+
width // self.vae.spatial_compression_ratio,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
if latents is None:
|
| 284 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 285 |
+
else:
|
| 286 |
+
latents = latents.to(device)
|
| 287 |
+
|
| 288 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 289 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 290 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 291 |
+
return latents
|
| 292 |
+
|
| 293 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 294 |
+
frames = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 295 |
+
frames = (frames / 2 + 0.5).clamp(0, 1)
|
| 296 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
| 297 |
+
frames = frames.cpu().float().numpy()
|
| 298 |
+
return frames
|
| 299 |
+
|
| 300 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 301 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 302 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 303 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 304 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 305 |
+
# and should be between [0, 1]
|
| 306 |
+
|
| 307 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 308 |
+
extra_step_kwargs = {}
|
| 309 |
+
if accepts_eta:
|
| 310 |
+
extra_step_kwargs["eta"] = eta
|
| 311 |
+
|
| 312 |
+
# check if the scheduler accepts generator
|
| 313 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 314 |
+
if accepts_generator:
|
| 315 |
+
extra_step_kwargs["generator"] = generator
|
| 316 |
+
return extra_step_kwargs
|
| 317 |
+
|
| 318 |
+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
| 319 |
+
def check_inputs(
|
| 320 |
+
self,
|
| 321 |
+
prompt,
|
| 322 |
+
height,
|
| 323 |
+
width,
|
| 324 |
+
negative_prompt,
|
| 325 |
+
callback_on_step_end_tensor_inputs,
|
| 326 |
+
prompt_embeds=None,
|
| 327 |
+
negative_prompt_embeds=None,
|
| 328 |
+
):
|
| 329 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 330 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 331 |
+
|
| 332 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 333 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 334 |
+
):
|
| 335 |
+
raise ValueError(
|
| 336 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 337 |
+
)
|
| 338 |
+
if prompt is not None and prompt_embeds is not None:
|
| 339 |
+
raise ValueError(
|
| 340 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 341 |
+
" only forward one of the two."
|
| 342 |
+
)
|
| 343 |
+
elif prompt is None and prompt_embeds is None:
|
| 344 |
+
raise ValueError(
|
| 345 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 346 |
+
)
|
| 347 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 348 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 349 |
+
|
| 350 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 351 |
+
raise ValueError(
|
| 352 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 353 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 357 |
+
raise ValueError(
|
| 358 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 359 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 363 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 364 |
+
raise ValueError(
|
| 365 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 366 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 367 |
+
f" {negative_prompt_embeds.shape}."
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
@property
|
| 371 |
+
def guidance_scale(self):
|
| 372 |
+
return self._guidance_scale
|
| 373 |
+
|
| 374 |
+
@property
|
| 375 |
+
def num_timesteps(self):
|
| 376 |
+
return self._num_timesteps
|
| 377 |
+
|
| 378 |
+
@property
|
| 379 |
+
def attention_kwargs(self):
|
| 380 |
+
return self._attention_kwargs
|
| 381 |
+
|
| 382 |
+
@property
|
| 383 |
+
def interrupt(self):
|
| 384 |
+
return self._interrupt
|
| 385 |
+
|
| 386 |
+
@torch.no_grad()
|
| 387 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 388 |
+
def __call__(
|
| 389 |
+
self,
|
| 390 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 391 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 392 |
+
height: int = 480,
|
| 393 |
+
width: int = 720,
|
| 394 |
+
num_frames: int = 49,
|
| 395 |
+
num_inference_steps: int = 50,
|
| 396 |
+
timesteps: Optional[List[int]] = None,
|
| 397 |
+
guidance_scale: float = 6,
|
| 398 |
+
num_videos_per_prompt: int = 1,
|
| 399 |
+
eta: float = 0.0,
|
| 400 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 401 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 402 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 403 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 404 |
+
output_type: str = "numpy",
|
| 405 |
+
return_dict: bool = False,
|
| 406 |
+
callback_on_step_end: Optional[
|
| 407 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 408 |
+
] = None,
|
| 409 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 410 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 411 |
+
max_sequence_length: int = 512,
|
| 412 |
+
boundary: float = 0.875,
|
| 413 |
+
comfyui_progressbar: bool = False,
|
| 414 |
+
shift: int = 5,
|
| 415 |
+
) -> Union[WanPipelineOutput, Tuple]:
|
| 416 |
+
"""
|
| 417 |
+
Function invoked when calling the pipeline for generation.
|
| 418 |
+
Args:
|
| 419 |
+
|
| 420 |
+
Examples:
|
| 421 |
+
|
| 422 |
+
Returns:
|
| 423 |
+
|
| 424 |
+
"""
|
| 425 |
+
|
| 426 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 427 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 428 |
+
num_videos_per_prompt = 1
|
| 429 |
+
|
| 430 |
+
# 1. Check inputs. Raise error if not correct
|
| 431 |
+
self.check_inputs(
|
| 432 |
+
prompt,
|
| 433 |
+
height,
|
| 434 |
+
width,
|
| 435 |
+
negative_prompt,
|
| 436 |
+
callback_on_step_end_tensor_inputs,
|
| 437 |
+
prompt_embeds,
|
| 438 |
+
negative_prompt_embeds,
|
| 439 |
+
)
|
| 440 |
+
self._guidance_scale = guidance_scale
|
| 441 |
+
self._attention_kwargs = attention_kwargs
|
| 442 |
+
self._interrupt = False
|
| 443 |
+
|
| 444 |
+
# 2. Default call parameters
|
| 445 |
+
if prompt is not None and isinstance(prompt, str):
|
| 446 |
+
batch_size = 1
|
| 447 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 448 |
+
batch_size = len(prompt)
|
| 449 |
+
else:
|
| 450 |
+
batch_size = prompt_embeds.shape[0]
|
| 451 |
+
|
| 452 |
+
device = self._execution_device
|
| 453 |
+
weight_dtype = self.text_encoder.dtype
|
| 454 |
+
|
| 455 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 456 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 457 |
+
# corresponds to doing no classifier free guidance.
|
| 458 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 459 |
+
|
| 460 |
+
# 3. Encode input prompt
|
| 461 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 462 |
+
prompt,
|
| 463 |
+
negative_prompt,
|
| 464 |
+
do_classifier_free_guidance,
|
| 465 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 466 |
+
prompt_embeds=prompt_embeds,
|
| 467 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 468 |
+
max_sequence_length=max_sequence_length,
|
| 469 |
+
device=device,
|
| 470 |
+
)
|
| 471 |
+
if do_classifier_free_guidance:
|
| 472 |
+
in_prompt_embeds = negative_prompt_embeds + prompt_embeds
|
| 473 |
+
else:
|
| 474 |
+
in_prompt_embeds = prompt_embeds
|
| 475 |
+
|
| 476 |
+
# 4. Prepare timesteps
|
| 477 |
+
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
| 478 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
|
| 479 |
+
elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
|
| 480 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
|
| 481 |
+
timesteps = self.scheduler.timesteps
|
| 482 |
+
elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
|
| 483 |
+
sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
|
| 484 |
+
timesteps, _ = retrieve_timesteps(
|
| 485 |
+
self.scheduler,
|
| 486 |
+
device=device,
|
| 487 |
+
sigmas=sampling_sigmas)
|
| 488 |
+
else:
|
| 489 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 490 |
+
self._num_timesteps = len(timesteps)
|
| 491 |
+
if comfyui_progressbar:
|
| 492 |
+
from comfy.utils import ProgressBar
|
| 493 |
+
pbar = ProgressBar(num_inference_steps + 1)
|
| 494 |
+
|
| 495 |
+
# 5. Prepare latents
|
| 496 |
+
latent_channels = self.transformer.config.in_channels
|
| 497 |
+
latents = self.prepare_latents(
|
| 498 |
+
batch_size * num_videos_per_prompt,
|
| 499 |
+
latent_channels,
|
| 500 |
+
num_frames,
|
| 501 |
+
height,
|
| 502 |
+
width,
|
| 503 |
+
weight_dtype,
|
| 504 |
+
device,
|
| 505 |
+
generator,
|
| 506 |
+
latents,
|
| 507 |
+
)
|
| 508 |
+
if comfyui_progressbar:
|
| 509 |
+
pbar.update(1)
|
| 510 |
+
|
| 511 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 512 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 513 |
+
|
| 514 |
+
target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
|
| 515 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
|
| 516 |
+
# 7. Denoising loop
|
| 517 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 518 |
+
self.transformer.num_inference_steps = num_inference_steps
|
| 519 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 520 |
+
for i, t in enumerate(timesteps):
|
| 521 |
+
self.transformer.current_steps = i
|
| 522 |
+
|
| 523 |
+
if self.interrupt:
|
| 524 |
+
continue
|
| 525 |
+
|
| 526 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 527 |
+
if hasattr(self.scheduler, "scale_model_input"):
|
| 528 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 529 |
+
|
| 530 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 531 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 532 |
+
|
| 533 |
+
if self.transformer_2 is not None:
|
| 534 |
+
if t >= boundary * self.scheduler.config.num_train_timesteps:
|
| 535 |
+
local_transformer = self.transformer_2
|
| 536 |
+
else:
|
| 537 |
+
local_transformer = self.transformer
|
| 538 |
+
else:
|
| 539 |
+
local_transformer = self.transformer
|
| 540 |
+
|
| 541 |
+
# predict noise model_output
|
| 542 |
+
with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
|
| 543 |
+
noise_pred = local_transformer(
|
| 544 |
+
x=latent_model_input,
|
| 545 |
+
context=in_prompt_embeds,
|
| 546 |
+
t=timestep,
|
| 547 |
+
seq_len=seq_len,
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
# perform guidance
|
| 551 |
+
if do_classifier_free_guidance:
|
| 552 |
+
if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
|
| 553 |
+
sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
|
| 554 |
+
else:
|
| 555 |
+
sample_guide_scale = self.guidance_scale
|
| 556 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 557 |
+
noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)
|
| 558 |
+
|
| 559 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 560 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 561 |
+
|
| 562 |
+
if callback_on_step_end is not None:
|
| 563 |
+
callback_kwargs = {}
|
| 564 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 565 |
+
callback_kwargs[k] = locals()[k]
|
| 566 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 567 |
+
|
| 568 |
+
latents = callback_outputs.pop("latents", latents)
|
| 569 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 570 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 571 |
+
|
| 572 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 573 |
+
progress_bar.update()
|
| 574 |
+
if comfyui_progressbar:
|
| 575 |
+
pbar.update(1)
|
| 576 |
+
|
| 577 |
+
if output_type == "numpy":
|
| 578 |
+
video = self.decode_latents(latents)
|
| 579 |
+
elif not output_type == "latent":
|
| 580 |
+
video = self.decode_latents(latents)
|
| 581 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 582 |
+
else:
|
| 583 |
+
video = latents
|
| 584 |
+
|
| 585 |
+
# Offload all models
|
| 586 |
+
self.maybe_free_model_hooks()
|
| 587 |
+
|
| 588 |
+
if not return_dict:
|
| 589 |
+
video = torch.from_numpy(video)
|
| 590 |
+
|
| 591 |
+
return WanPipelineOutput(videos=video)
|
videox_fun/ui/cogvideox_fun_ui.py
ADDED
|
@@ -0,0 +1,722 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
|
| 2 |
+
"""
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from safetensors import safe_open
|
| 12 |
+
|
| 13 |
+
from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
|
| 14 |
+
from ..models import (AutoencoderKLCogVideoX, CogVideoXTransformer3DModel,
|
| 15 |
+
T5EncoderModel, T5Tokenizer)
|
| 16 |
+
from ..pipeline import (CogVideoXFunControlPipeline,
|
| 17 |
+
CogVideoXFunInpaintPipeline, CogVideoXFunPipeline)
|
| 18 |
+
from ..utils.fp8_optimization import (convert_model_weight_to_float8, replace_parameters_by_name,
|
| 19 |
+
convert_weight_dtype_wrapper)
|
| 20 |
+
from ..utils.lora_utils import merge_lora, unmerge_lora
|
| 21 |
+
from ..utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, timer,
|
| 22 |
+
get_video_to_video_latent, save_videos_grid)
|
| 23 |
+
from .controller import (Fun_Controller, Fun_Controller_Client,
|
| 24 |
+
all_cheduler_dict, css, ddpm_scheduler_dict,
|
| 25 |
+
flow_scheduler_dict, gradio_version,
|
| 26 |
+
gradio_version_is_above_4)
|
| 27 |
+
from .ui import (create_cfg_and_seedbox,
|
| 28 |
+
create_fake_finetune_models_checkpoints,
|
| 29 |
+
create_fake_height_width, create_fake_model_checkpoints,
|
| 30 |
+
create_fake_model_type, create_finetune_models_checkpoints,
|
| 31 |
+
create_generation_method,
|
| 32 |
+
create_generation_methods_and_video_length,
|
| 33 |
+
create_height_width, create_model_checkpoints,
|
| 34 |
+
create_model_type, create_prompts, create_samplers,
|
| 35 |
+
create_ui_outputs)
|
| 36 |
+
from ..dist import set_multi_gpus_devices, shard_model
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class CogVideoXFunController(Fun_Controller):
|
| 40 |
+
def update_diffusion_transformer(self, diffusion_transformer_dropdown):
|
| 41 |
+
print(f"Update diffusion transformer: {diffusion_transformer_dropdown}")
|
| 42 |
+
self.diffusion_transformer_dropdown = diffusion_transformer_dropdown
|
| 43 |
+
if diffusion_transformer_dropdown == "none":
|
| 44 |
+
return gr.update()
|
| 45 |
+
self.vae = AutoencoderKLCogVideoX.from_pretrained(
|
| 46 |
+
diffusion_transformer_dropdown,
|
| 47 |
+
subfolder="vae",
|
| 48 |
+
).to(self.weight_dtype)
|
| 49 |
+
|
| 50 |
+
# Get Transformer
|
| 51 |
+
self.transformer = CogVideoXTransformer3DModel.from_pretrained(
|
| 52 |
+
diffusion_transformer_dropdown,
|
| 53 |
+
subfolder="transformer",
|
| 54 |
+
low_cpu_mem_usage=True,
|
| 55 |
+
).to(self.weight_dtype)
|
| 56 |
+
|
| 57 |
+
# Get tokenizer and text_encoder
|
| 58 |
+
tokenizer = T5Tokenizer.from_pretrained(
|
| 59 |
+
diffusion_transformer_dropdown, subfolder="tokenizer"
|
| 60 |
+
)
|
| 61 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
| 62 |
+
diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Get pipeline
|
| 66 |
+
if self.model_type == "Inpaint":
|
| 67 |
+
if self.transformer.config.in_channels != self.vae.config.latent_channels:
|
| 68 |
+
self.pipeline = CogVideoXFunInpaintPipeline(
|
| 69 |
+
tokenizer=tokenizer,
|
| 70 |
+
text_encoder=text_encoder,
|
| 71 |
+
vae=self.vae,
|
| 72 |
+
transformer=self.transformer,
|
| 73 |
+
scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
|
| 74 |
+
)
|
| 75 |
+
else:
|
| 76 |
+
self.pipeline = CogVideoXFunPipeline(
|
| 77 |
+
tokenizer=tokenizer,
|
| 78 |
+
text_encoder=text_encoder,
|
| 79 |
+
vae=self.vae,
|
| 80 |
+
transformer=self.transformer,
|
| 81 |
+
scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
|
| 82 |
+
)
|
| 83 |
+
else:
|
| 84 |
+
self.pipeline = CogVideoXFunControlPipeline(
|
| 85 |
+
diffusion_transformer_dropdown,
|
| 86 |
+
vae=self.vae,
|
| 87 |
+
transformer=self.transformer,
|
| 88 |
+
scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
|
| 89 |
+
torch_dtype=self.weight_dtype
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
if self.ulysses_degree > 1 or self.ring_degree > 1:
|
| 93 |
+
from functools import partial
|
| 94 |
+
self.transformer.enable_multi_gpus_inference()
|
| 95 |
+
if self.fsdp_dit:
|
| 96 |
+
shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype)
|
| 97 |
+
self.pipeline.transformer = shard_fn(self.pipeline.transformer)
|
| 98 |
+
print("Add FSDP DIT")
|
| 99 |
+
if self.fsdp_text_encoder:
|
| 100 |
+
shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype)
|
| 101 |
+
self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder)
|
| 102 |
+
print("Add FSDP TEXT ENCODER")
|
| 103 |
+
|
| 104 |
+
if self.compile_dit:
|
| 105 |
+
for i in range(len(self.pipeline.transformer.transformer_blocks)):
|
| 106 |
+
self.pipeline.transformer.transformer_blocks[i] = torch.compile(self.pipeline.transformer.transformer_blocks[i])
|
| 107 |
+
print("Add Compile")
|
| 108 |
+
|
| 109 |
+
if self.GPU_memory_mode == "sequential_cpu_offload":
|
| 110 |
+
self.pipeline.enable_sequential_cpu_offload(device=self.device)
|
| 111 |
+
elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
|
| 112 |
+
convert_model_weight_to_float8(self.pipeline.transformer, exclude_module_name=[], device=self.device)
|
| 113 |
+
convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
|
| 114 |
+
self.pipeline.enable_model_cpu_offload(device=self.device)
|
| 115 |
+
elif self.GPU_memory_mode == "model_cpu_offload":
|
| 116 |
+
self.pipeline.enable_model_cpu_offload(device=self.device)
|
| 117 |
+
elif self.GPU_memory_mode == "model_full_load_and_qfloat8":
|
| 118 |
+
convert_model_weight_to_float8(self.pipeline.transformer, exclude_module_name=[], device=self.device)
|
| 119 |
+
convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
|
| 120 |
+
self.pipeline.to(self.device)
|
| 121 |
+
else:
|
| 122 |
+
self.pipeline.to(self.device)
|
| 123 |
+
print("Update diffusion transformer done")
|
| 124 |
+
return gr.update()
|
| 125 |
+
|
| 126 |
+
@timer
|
| 127 |
+
def generate(
|
| 128 |
+
self,
|
| 129 |
+
diffusion_transformer_dropdown,
|
| 130 |
+
base_model_dropdown,
|
| 131 |
+
lora_model_dropdown,
|
| 132 |
+
lora_alpha_slider,
|
| 133 |
+
prompt_textbox,
|
| 134 |
+
negative_prompt_textbox,
|
| 135 |
+
sampler_dropdown,
|
| 136 |
+
sample_step_slider,
|
| 137 |
+
resize_method,
|
| 138 |
+
width_slider,
|
| 139 |
+
height_slider,
|
| 140 |
+
base_resolution,
|
| 141 |
+
generation_method,
|
| 142 |
+
length_slider,
|
| 143 |
+
overlap_video_length,
|
| 144 |
+
partial_video_length,
|
| 145 |
+
cfg_scale_slider,
|
| 146 |
+
start_image,
|
| 147 |
+
end_image,
|
| 148 |
+
validation_video,
|
| 149 |
+
validation_video_mask,
|
| 150 |
+
control_video,
|
| 151 |
+
denoise_strength,
|
| 152 |
+
seed_textbox,
|
| 153 |
+
ref_image = None,
|
| 154 |
+
enable_teacache = None,
|
| 155 |
+
teacache_threshold = None,
|
| 156 |
+
num_skip_start_steps = None,
|
| 157 |
+
teacache_offload = None,
|
| 158 |
+
cfg_skip_ratio = None,
|
| 159 |
+
enable_riflex = None,
|
| 160 |
+
riflex_k = None,
|
| 161 |
+
base_model_2_dropdown=None,
|
| 162 |
+
lora_model_2_dropdown=None,
|
| 163 |
+
fps = None,
|
| 164 |
+
is_api = False,
|
| 165 |
+
):
|
| 166 |
+
self.clear_cache()
|
| 167 |
+
|
| 168 |
+
print(f"Input checking.")
|
| 169 |
+
_, comment = self.input_check(
|
| 170 |
+
resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api
|
| 171 |
+
)
|
| 172 |
+
print(f"Input checking down")
|
| 173 |
+
if comment != "OK":
|
| 174 |
+
return "", comment
|
| 175 |
+
is_image = True if generation_method == "Image Generation" else False
|
| 176 |
+
|
| 177 |
+
if self.base_model_path != base_model_dropdown:
|
| 178 |
+
self.update_base_model(base_model_dropdown)
|
| 179 |
+
|
| 180 |
+
if self.lora_model_path != lora_model_dropdown:
|
| 181 |
+
self.update_lora_model(lora_model_dropdown)
|
| 182 |
+
|
| 183 |
+
print(f"Load scheduler.")
|
| 184 |
+
self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
|
| 185 |
+
print(f"Load scheduler down.")
|
| 186 |
+
|
| 187 |
+
if resize_method == "Resize according to Reference":
|
| 188 |
+
print(f"Calculate height and width according to Reference.")
|
| 189 |
+
height_slider, width_slider = self.get_height_width_from_reference(
|
| 190 |
+
base_resolution, start_image, validation_video, control_video,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
if self.lora_model_path != "none":
|
| 194 |
+
print(f"Merge Lora.")
|
| 195 |
+
self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
| 196 |
+
print(f"Merge Lora done.")
|
| 197 |
+
|
| 198 |
+
if fps is None:
|
| 199 |
+
fps = 8
|
| 200 |
+
|
| 201 |
+
print(f"Generate seed.")
|
| 202 |
+
if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
|
| 203 |
+
else: seed_textbox = np.random.randint(0, 1e10)
|
| 204 |
+
generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))
|
| 205 |
+
print(f"Generate seed done.")
|
| 206 |
+
|
| 207 |
+
try:
|
| 208 |
+
print(f"Generation.")
|
| 209 |
+
if self.model_type == "Inpaint":
|
| 210 |
+
if self.transformer.config.in_channels != self.vae.config.latent_channels:
|
| 211 |
+
if generation_method == "Long Video Generation":
|
| 212 |
+
if validation_video is not None:
|
| 213 |
+
raise gr.Error(f"Video to Video is not Support Long Video Generation now.")
|
| 214 |
+
init_frames = 0
|
| 215 |
+
last_frames = init_frames + partial_video_length
|
| 216 |
+
while init_frames < length_slider:
|
| 217 |
+
if last_frames >= length_slider:
|
| 218 |
+
_partial_video_length = length_slider - init_frames
|
| 219 |
+
_partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
|
| 220 |
+
|
| 221 |
+
if _partial_video_length <= 0:
|
| 222 |
+
break
|
| 223 |
+
else:
|
| 224 |
+
_partial_video_length = partial_video_length
|
| 225 |
+
|
| 226 |
+
if last_frames >= length_slider:
|
| 227 |
+
input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
|
| 228 |
+
else:
|
| 229 |
+
input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
|
| 230 |
+
|
| 231 |
+
with torch.no_grad():
|
| 232 |
+
sample = self.pipeline(
|
| 233 |
+
prompt_textbox,
|
| 234 |
+
negative_prompt = negative_prompt_textbox,
|
| 235 |
+
num_inference_steps = sample_step_slider,
|
| 236 |
+
guidance_scale = cfg_scale_slider,
|
| 237 |
+
width = width_slider,
|
| 238 |
+
height = height_slider,
|
| 239 |
+
num_frames = _partial_video_length,
|
| 240 |
+
generator = generator,
|
| 241 |
+
|
| 242 |
+
video = input_video,
|
| 243 |
+
mask_video = input_video_mask,
|
| 244 |
+
strength = 1,
|
| 245 |
+
).videos
|
| 246 |
+
|
| 247 |
+
if init_frames != 0:
|
| 248 |
+
mix_ratio = torch.from_numpy(
|
| 249 |
+
np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
|
| 250 |
+
).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
| 251 |
+
|
| 252 |
+
new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
|
| 253 |
+
sample[:, :, :overlap_video_length] * mix_ratio
|
| 254 |
+
new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
|
| 255 |
+
|
| 256 |
+
sample = new_sample
|
| 257 |
+
else:
|
| 258 |
+
new_sample = sample
|
| 259 |
+
|
| 260 |
+
if last_frames >= length_slider:
|
| 261 |
+
break
|
| 262 |
+
|
| 263 |
+
start_image = [
|
| 264 |
+
Image.fromarray(
|
| 265 |
+
(sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
|
| 266 |
+
) for _index in range(-overlap_video_length, 0)
|
| 267 |
+
]
|
| 268 |
+
|
| 269 |
+
init_frames = init_frames + _partial_video_length - overlap_video_length
|
| 270 |
+
last_frames = init_frames + _partial_video_length
|
| 271 |
+
else:
|
| 272 |
+
if validation_video is not None:
|
| 273 |
+
input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps)
|
| 274 |
+
strength = denoise_strength
|
| 275 |
+
else:
|
| 276 |
+
input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
|
| 277 |
+
strength = 1
|
| 278 |
+
|
| 279 |
+
sample = self.pipeline(
|
| 280 |
+
prompt_textbox,
|
| 281 |
+
negative_prompt = negative_prompt_textbox,
|
| 282 |
+
num_inference_steps = sample_step_slider,
|
| 283 |
+
guidance_scale = cfg_scale_slider,
|
| 284 |
+
width = width_slider,
|
| 285 |
+
height = height_slider,
|
| 286 |
+
num_frames = length_slider if not is_image else 1,
|
| 287 |
+
generator = generator,
|
| 288 |
+
|
| 289 |
+
video = input_video,
|
| 290 |
+
mask_video = input_video_mask,
|
| 291 |
+
strength = strength,
|
| 292 |
+
).videos
|
| 293 |
+
else:
|
| 294 |
+
sample = self.pipeline(
|
| 295 |
+
prompt_textbox,
|
| 296 |
+
negative_prompt = negative_prompt_textbox,
|
| 297 |
+
num_inference_steps = sample_step_slider,
|
| 298 |
+
guidance_scale = cfg_scale_slider,
|
| 299 |
+
width = width_slider,
|
| 300 |
+
height = height_slider,
|
| 301 |
+
num_frames = length_slider if not is_image else 1,
|
| 302 |
+
generator = generator
|
| 303 |
+
).videos
|
| 304 |
+
else:
|
| 305 |
+
input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps)
|
| 306 |
+
|
| 307 |
+
sample = self.pipeline(
|
| 308 |
+
prompt_textbox,
|
| 309 |
+
negative_prompt = negative_prompt_textbox,
|
| 310 |
+
num_inference_steps = sample_step_slider,
|
| 311 |
+
guidance_scale = cfg_scale_slider,
|
| 312 |
+
width = width_slider,
|
| 313 |
+
height = height_slider,
|
| 314 |
+
num_frames = length_slider if not is_image else 1,
|
| 315 |
+
generator = generator,
|
| 316 |
+
|
| 317 |
+
control_video = input_video,
|
| 318 |
+
).videos
|
| 319 |
+
except Exception as e:
|
| 320 |
+
self.auto_model_clear_cache(self.pipeline.transformer)
|
| 321 |
+
self.auto_model_clear_cache(self.pipeline.text_encoder)
|
| 322 |
+
self.auto_model_clear_cache(self.pipeline.vae)
|
| 323 |
+
self.clear_cache()
|
| 324 |
+
|
| 325 |
+
print(f"Error. error information is {str(e)}")
|
| 326 |
+
if self.lora_model_path != "none":
|
| 327 |
+
self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
| 328 |
+
if is_api:
|
| 329 |
+
return "", f"Error. error information is {str(e)}"
|
| 330 |
+
else:
|
| 331 |
+
return gr.update(), gr.update(), f"Error. error information is {str(e)}"
|
| 332 |
+
|
| 333 |
+
self.clear_cache()
|
| 334 |
+
# lora part
|
| 335 |
+
if self.lora_model_path != "none":
|
| 336 |
+
print(f"Unmerge Lora.")
|
| 337 |
+
self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
| 338 |
+
print(f"Unmerge Lora done.")
|
| 339 |
+
|
| 340 |
+
print(f"Saving outputs.")
|
| 341 |
+
save_sample_path = self.save_outputs(
|
| 342 |
+
is_image, length_slider, sample, fps=fps
|
| 343 |
+
)
|
| 344 |
+
print(f"Saving outputs done.")
|
| 345 |
+
|
| 346 |
+
if is_image or length_slider == 1:
|
| 347 |
+
if is_api:
|
| 348 |
+
return save_sample_path, "Success"
|
| 349 |
+
else:
|
| 350 |
+
if gradio_version_is_above_4:
|
| 351 |
+
return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
|
| 352 |
+
else:
|
| 353 |
+
return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
|
| 354 |
+
else:
|
| 355 |
+
if is_api:
|
| 356 |
+
return save_sample_path, "Success"
|
| 357 |
+
else:
|
| 358 |
+
if gradio_version_is_above_4:
|
| 359 |
+
return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
|
| 360 |
+
else:
|
| 361 |
+
return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
|
| 362 |
+
|
| 363 |
+
CogVideoXFunController_Host = CogVideoXFunController
|
| 364 |
+
CogVideoXFunController_Client = Fun_Controller_Client
|
| 365 |
+
|
| 366 |
+
def ui(GPU_memory_mode, scheduler_dict, compile_dit, weight_dtype, savedir_sample=None):
|
| 367 |
+
controller = CogVideoXFunController(
|
| 368 |
+
GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
|
| 369 |
+
compile_dit=compile_dit,
|
| 370 |
+
weight_dtype=weight_dtype, savedir_sample=savedir_sample,
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
with gr.Blocks(css=css) as demo:
|
| 374 |
+
gr.Markdown(
|
| 375 |
+
"""
|
| 376 |
+
# CogVideoX-Fun:
|
| 377 |
+
|
| 378 |
+
A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.
|
| 379 |
+
|
| 380 |
+
[Github](https://github.com/aigc-apps/CogVideoX-Fun/)
|
| 381 |
+
"""
|
| 382 |
+
)
|
| 383 |
+
with gr.Column(variant="panel"):
|
| 384 |
+
model_type = create_model_type(visible=True)
|
| 385 |
+
diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \
|
| 386 |
+
create_model_checkpoints(controller, visible=True)
|
| 387 |
+
base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \
|
| 388 |
+
create_finetune_models_checkpoints(controller, visible=True)
|
| 389 |
+
|
| 390 |
+
with gr.Column(variant="panel"):
|
| 391 |
+
prompt_textbox, negative_prompt_textbox = create_prompts()
|
| 392 |
+
|
| 393 |
+
with gr.Row():
|
| 394 |
+
with gr.Column():
|
| 395 |
+
sampler_dropdown, sample_step_slider = create_samplers(controller)
|
| 396 |
+
|
| 397 |
+
resize_method, width_slider, height_slider, base_resolution = create_height_width(
|
| 398 |
+
default_height = 384, default_width = 672, maximum_height = 1344,
|
| 399 |
+
maximum_width = 1344,
|
| 400 |
+
)
|
| 401 |
+
gr.Markdown(
|
| 402 |
+
"""
|
| 403 |
+
V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
|
| 404 |
+
(V1.0和V1.1支持最大49��视频生成,V1.5支持最大85帧视频生成。)
|
| 405 |
+
"""
|
| 406 |
+
)
|
| 407 |
+
generation_method, length_slider, overlap_video_length, partial_video_length = \
|
| 408 |
+
create_generation_methods_and_video_length(
|
| 409 |
+
["Video Generation", "Image Generation", "Long Video Generation"],
|
| 410 |
+
default_video_length=49,
|
| 411 |
+
maximum_video_length=85,
|
| 412 |
+
)
|
| 413 |
+
image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
|
| 414 |
+
["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox
|
| 415 |
+
)
|
| 416 |
+
cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
|
| 417 |
+
|
| 418 |
+
generate_button = gr.Button(value="Generate (生成)", variant='primary')
|
| 419 |
+
|
| 420 |
+
result_image, result_video, infer_progress = create_ui_outputs()
|
| 421 |
+
|
| 422 |
+
model_type.change(
|
| 423 |
+
fn=controller.update_model_type,
|
| 424 |
+
inputs=[model_type],
|
| 425 |
+
outputs=[]
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
def upload_generation_method(generation_method):
|
| 429 |
+
if generation_method == "Video Generation":
|
| 430 |
+
return [gr.update(visible=True, maximum=85, value=49, interactive=True), gr.update(visible=False), gr.update(visible=False)]
|
| 431 |
+
elif generation_method == "Image Generation":
|
| 432 |
+
return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)]
|
| 433 |
+
else:
|
| 434 |
+
return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
|
| 435 |
+
generation_method.change(
|
| 436 |
+
upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
def upload_source_method(source_method):
|
| 440 |
+
if source_method == "Text to Video (文本到视频)":
|
| 441 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
| 442 |
+
elif source_method == "Image to Video (图片到视频)":
|
| 443 |
+
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
| 444 |
+
elif source_method == "Video to Video (视频到视频)":
|
| 445 |
+
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
|
| 446 |
+
else:
|
| 447 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
|
| 448 |
+
source_method.change(
|
| 449 |
+
upload_source_method, source_method, [
|
| 450 |
+
image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
|
| 451 |
+
validation_video, validation_video_mask, control_video
|
| 452 |
+
]
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
def upload_resize_method(resize_method):
|
| 456 |
+
if resize_method == "Generate by":
|
| 457 |
+
return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
|
| 458 |
+
else:
|
| 459 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
|
| 460 |
+
resize_method.change(
|
| 461 |
+
upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
generate_button.click(
|
| 465 |
+
fn=controller.generate,
|
| 466 |
+
inputs=[
|
| 467 |
+
diffusion_transformer_dropdown,
|
| 468 |
+
base_model_dropdown,
|
| 469 |
+
lora_model_dropdown,
|
| 470 |
+
lora_alpha_slider,
|
| 471 |
+
prompt_textbox,
|
| 472 |
+
negative_prompt_textbox,
|
| 473 |
+
sampler_dropdown,
|
| 474 |
+
sample_step_slider,
|
| 475 |
+
resize_method,
|
| 476 |
+
width_slider,
|
| 477 |
+
height_slider,
|
| 478 |
+
base_resolution,
|
| 479 |
+
generation_method,
|
| 480 |
+
length_slider,
|
| 481 |
+
overlap_video_length,
|
| 482 |
+
partial_video_length,
|
| 483 |
+
cfg_scale_slider,
|
| 484 |
+
start_image,
|
| 485 |
+
end_image,
|
| 486 |
+
validation_video,
|
| 487 |
+
validation_video_mask,
|
| 488 |
+
control_video,
|
| 489 |
+
denoise_strength,
|
| 490 |
+
seed_textbox,
|
| 491 |
+
],
|
| 492 |
+
outputs=[result_image, result_video, infer_progress]
|
| 493 |
+
)
|
| 494 |
+
return demo, controller
|
| 495 |
+
|
| 496 |
+
def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, compile_dit, weight_dtype, savedir_sample=None):
|
| 497 |
+
controller = CogVideoXFunController_Host(
|
| 498 |
+
GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type,
|
| 499 |
+
compile_dit=compile_dit,
|
| 500 |
+
weight_dtype=weight_dtype, savedir_sample=savedir_sample,
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
with gr.Blocks(css=css) as demo:
|
| 504 |
+
gr.Markdown(
|
| 505 |
+
"""
|
| 506 |
+
# CogVideoX-Fun
|
| 507 |
+
|
| 508 |
+
A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.
|
| 509 |
+
|
| 510 |
+
[Github](https://github.com/aigc-apps/CogVideoX-Fun/)
|
| 511 |
+
"""
|
| 512 |
+
)
|
| 513 |
+
with gr.Column(variant="panel"):
|
| 514 |
+
model_type = create_fake_model_type(visible=False)
|
| 515 |
+
diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
|
| 516 |
+
base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)
|
| 517 |
+
|
| 518 |
+
with gr.Column(variant="panel"):
|
| 519 |
+
prompt_textbox, negative_prompt_textbox = create_prompts()
|
| 520 |
+
|
| 521 |
+
with gr.Row():
|
| 522 |
+
with gr.Column():
|
| 523 |
+
sampler_dropdown, sample_step_slider = create_samplers(controller)
|
| 524 |
+
|
| 525 |
+
resize_method, width_slider, height_slider, base_resolution = create_height_width(
|
| 526 |
+
default_height = 384, default_width = 672, maximum_height = 1344,
|
| 527 |
+
maximum_width = 1344,
|
| 528 |
+
)
|
| 529 |
+
gr.Markdown(
|
| 530 |
+
"""
|
| 531 |
+
V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
|
| 532 |
+
(V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
|
| 533 |
+
"""
|
| 534 |
+
)
|
| 535 |
+
generation_method, length_slider, overlap_video_length, partial_video_length = \
|
| 536 |
+
create_generation_methods_and_video_length(
|
| 537 |
+
["Video Generation", "Image Generation"],
|
| 538 |
+
default_video_length=49,
|
| 539 |
+
maximum_video_length=85,
|
| 540 |
+
)
|
| 541 |
+
image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
|
| 542 |
+
["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox
|
| 543 |
+
)
|
| 544 |
+
cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
|
| 545 |
+
|
| 546 |
+
generate_button = gr.Button(value="Generate (生成)", variant='primary')
|
| 547 |
+
|
| 548 |
+
result_image, result_video, infer_progress = create_ui_outputs()
|
| 549 |
+
|
| 550 |
+
def upload_generation_method(generation_method):
|
| 551 |
+
if generation_method == "Video Generation":
|
| 552 |
+
return gr.update(visible=True, minimum=8, maximum=85, value=49, interactive=True)
|
| 553 |
+
elif generation_method == "Image Generation":
|
| 554 |
+
return gr.update(minimum=1, maximum=1, value=1, interactive=False)
|
| 555 |
+
generation_method.change(
|
| 556 |
+
upload_generation_method, generation_method, [length_slider]
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
def upload_source_method(source_method):
|
| 560 |
+
if source_method == "Text to Video (文本到视频)":
|
| 561 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
| 562 |
+
elif source_method == "Image to Video (图片到视频)":
|
| 563 |
+
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
| 564 |
+
elif source_method == "Video to Video (视频到视频)":
|
| 565 |
+
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
|
| 566 |
+
else:
|
| 567 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
|
| 568 |
+
source_method.change(
|
| 569 |
+
upload_source_method, source_method, [
|
| 570 |
+
image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
|
| 571 |
+
validation_video, validation_video_mask, control_video
|
| 572 |
+
]
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
def upload_resize_method(resize_method):
|
| 576 |
+
if resize_method == "Generate by":
|
| 577 |
+
return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
|
| 578 |
+
else:
|
| 579 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
|
| 580 |
+
resize_method.change(
|
| 581 |
+
upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
generate_button.click(
|
| 585 |
+
fn=controller.generate,
|
| 586 |
+
inputs=[
|
| 587 |
+
diffusion_transformer_dropdown,
|
| 588 |
+
base_model_dropdown,
|
| 589 |
+
lora_model_dropdown,
|
| 590 |
+
lora_alpha_slider,
|
| 591 |
+
prompt_textbox,
|
| 592 |
+
negative_prompt_textbox,
|
| 593 |
+
sampler_dropdown,
|
| 594 |
+
sample_step_slider,
|
| 595 |
+
resize_method,
|
| 596 |
+
width_slider,
|
| 597 |
+
height_slider,
|
| 598 |
+
base_resolution,
|
| 599 |
+
generation_method,
|
| 600 |
+
length_slider,
|
| 601 |
+
overlap_video_length,
|
| 602 |
+
partial_video_length,
|
| 603 |
+
cfg_scale_slider,
|
| 604 |
+
start_image,
|
| 605 |
+
end_image,
|
| 606 |
+
validation_video,
|
| 607 |
+
validation_video_mask,
|
| 608 |
+
control_video,
|
| 609 |
+
denoise_strength,
|
| 610 |
+
seed_textbox,
|
| 611 |
+
],
|
| 612 |
+
outputs=[result_image, result_video, infer_progress]
|
| 613 |
+
)
|
| 614 |
+
return demo, controller
|
| 615 |
+
|
| 616 |
+
def ui_client(scheduler_dict, model_name, savedir_sample=None):
|
| 617 |
+
controller = CogVideoXFunController_Client(scheduler_dict, savedir_sample)
|
| 618 |
+
|
| 619 |
+
with gr.Blocks(css=css) as demo:
|
| 620 |
+
gr.Markdown(
|
| 621 |
+
"""
|
| 622 |
+
# CogVideoX-Fun
|
| 623 |
+
|
| 624 |
+
A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.
|
| 625 |
+
|
| 626 |
+
[Github](https://github.com/aigc-apps/CogVideoX-Fun/)
|
| 627 |
+
"""
|
| 628 |
+
)
|
| 629 |
+
with gr.Column(variant="panel"):
|
| 630 |
+
diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
|
| 631 |
+
base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)
|
| 632 |
+
|
| 633 |
+
with gr.Column(variant="panel"):
|
| 634 |
+
prompt_textbox, negative_prompt_textbox = create_prompts()
|
| 635 |
+
|
| 636 |
+
with gr.Row():
|
| 637 |
+
with gr.Column():
|
| 638 |
+
sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50)
|
| 639 |
+
|
| 640 |
+
resize_method, width_slider, height_slider, base_resolution = create_fake_height_width(
|
| 641 |
+
default_height = 384, default_width = 672, maximum_height = 1344,
|
| 642 |
+
maximum_width = 1344,
|
| 643 |
+
)
|
| 644 |
+
gr.Markdown(
|
| 645 |
+
"""
|
| 646 |
+
V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
|
| 647 |
+
(V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
|
| 648 |
+
"""
|
| 649 |
+
)
|
| 650 |
+
generation_method, length_slider, overlap_video_length, partial_video_length = \
|
| 651 |
+
create_generation_methods_and_video_length(
|
| 652 |
+
["Video Generation", "Image Generation"],
|
| 653 |
+
default_video_length=49,
|
| 654 |
+
maximum_video_length=85,
|
| 655 |
+
)
|
| 656 |
+
image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
|
| 657 |
+
["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"], prompt_textbox
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
|
| 661 |
+
|
| 662 |
+
generate_button = gr.Button(value="Generate (生成)", variant='primary')
|
| 663 |
+
|
| 664 |
+
result_image, result_video, infer_progress = create_ui_outputs()
|
| 665 |
+
|
| 666 |
+
def upload_generation_method(generation_method):
|
| 667 |
+
if generation_method == "Video Generation":
|
| 668 |
+
return gr.update(visible=True, minimum=5, maximum=85, value=49, interactive=True)
|
| 669 |
+
elif generation_method == "Image Generation":
|
| 670 |
+
return gr.update(minimum=1, maximum=1, value=1, interactive=False)
|
| 671 |
+
generation_method.change(
|
| 672 |
+
upload_generation_method, generation_method, [length_slider]
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
def upload_source_method(source_method):
|
| 676 |
+
if source_method == "Text to Video (文本到视频)":
|
| 677 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
| 678 |
+
elif source_method == "Image to Video (图片到视频)":
|
| 679 |
+
return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
|
| 680 |
+
else:
|
| 681 |
+
return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
|
| 682 |
+
source_method.change(
|
| 683 |
+
upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
def upload_resize_method(resize_method):
|
| 687 |
+
if resize_method == "Generate by":
|
| 688 |
+
return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
|
| 689 |
+
else:
|
| 690 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
|
| 691 |
+
resize_method.change(
|
| 692 |
+
upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
generate_button.click(
|
| 696 |
+
fn=controller.generate,
|
| 697 |
+
inputs=[
|
| 698 |
+
diffusion_transformer_dropdown,
|
| 699 |
+
base_model_dropdown,
|
| 700 |
+
lora_model_dropdown,
|
| 701 |
+
lora_alpha_slider,
|
| 702 |
+
prompt_textbox,
|
| 703 |
+
negative_prompt_textbox,
|
| 704 |
+
sampler_dropdown,
|
| 705 |
+
sample_step_slider,
|
| 706 |
+
resize_method,
|
| 707 |
+
width_slider,
|
| 708 |
+
height_slider,
|
| 709 |
+
base_resolution,
|
| 710 |
+
generation_method,
|
| 711 |
+
length_slider,
|
| 712 |
+
cfg_scale_slider,
|
| 713 |
+
start_image,
|
| 714 |
+
end_image,
|
| 715 |
+
validation_video,
|
| 716 |
+
validation_video_mask,
|
| 717 |
+
denoise_strength,
|
| 718 |
+
seed_textbox,
|
| 719 |
+
],
|
| 720 |
+
outputs=[result_image, result_video, infer_progress]
|
| 721 |
+
)
|
| 722 |
+
return demo, controller
|
videox_fun/ui/controller.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
|
| 2 |
+
"""
|
| 3 |
+
import base64
|
| 4 |
+
import gc
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import hashlib
|
| 8 |
+
import random
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from glob import glob
|
| 11 |
+
|
| 12 |
+
import cv2
|
| 13 |
+
import gradio as gr
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pkg_resources
|
| 16 |
+
import requests
|
| 17 |
+
import torch
|
| 18 |
+
from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler,
|
| 19 |
+
DPMSolverMultistepScheduler,
|
| 20 |
+
EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
|
| 21 |
+
FlowMatchEulerDiscreteScheduler, PNDMScheduler)
|
| 22 |
+
from omegaconf import OmegaConf
|
| 23 |
+
from PIL import Image
|
| 24 |
+
from safetensors import safe_open
|
| 25 |
+
|
| 26 |
+
from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
|
| 27 |
+
from ..utils.utils import save_videos_grid
|
| 28 |
+
from ..utils.fm_solvers import FlowDPMSolverMultistepScheduler
|
| 29 |
+
from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 30 |
+
from ..dist import set_multi_gpus_devices
|
| 31 |
+
|
| 32 |
+
gradio_version = pkg_resources.get_distribution("gradio").version
|
| 33 |
+
gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False
|
| 34 |
+
|
| 35 |
+
css = """
|
| 36 |
+
.toolbutton {
|
| 37 |
+
margin-buttom: 0em 0em 0em 0em;
|
| 38 |
+
max-width: 2.5em;
|
| 39 |
+
min-width: 2.5em !important;
|
| 40 |
+
height: 2.5em;
|
| 41 |
+
}
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
ddpm_scheduler_dict = {
|
| 45 |
+
"Euler": EulerDiscreteScheduler,
|
| 46 |
+
"Euler A": EulerAncestralDiscreteScheduler,
|
| 47 |
+
"DPM++": DPMSolverMultistepScheduler,
|
| 48 |
+
"PNDM": PNDMScheduler,
|
| 49 |
+
"DDIM": DDIMScheduler,
|
| 50 |
+
"DDIM_Origin": DDIMScheduler,
|
| 51 |
+
"DDIM_Cog": CogVideoXDDIMScheduler,
|
| 52 |
+
}
|
| 53 |
+
flow_scheduler_dict = {
|
| 54 |
+
"Flow": FlowMatchEulerDiscreteScheduler,
|
| 55 |
+
"Flow_Unipc": FlowUniPCMultistepScheduler,
|
| 56 |
+
"Flow_DPM++": FlowDPMSolverMultistepScheduler,
|
| 57 |
+
}
|
| 58 |
+
all_cheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
|
| 59 |
+
|
| 60 |
+
class Fun_Controller:
|
| 61 |
+
def __init__(
|
| 62 |
+
self, GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
|
| 63 |
+
config_path=None, ulysses_degree=1, ring_degree=1,
|
| 64 |
+
fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
|
| 65 |
+
weight_dtype=None, savedir_sample=None,
|
| 66 |
+
):
|
| 67 |
+
# config dirs
|
| 68 |
+
self.basedir = os.getcwd()
|
| 69 |
+
self.config_dir = os.path.join(self.basedir, "config")
|
| 70 |
+
self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
|
| 71 |
+
self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
|
| 72 |
+
self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
|
| 73 |
+
if savedir_sample is None:
|
| 74 |
+
self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
|
| 75 |
+
else:
|
| 76 |
+
self.savedir_sample = savedir_sample
|
| 77 |
+
os.makedirs(self.savedir_sample, exist_ok=True)
|
| 78 |
+
|
| 79 |
+
self.GPU_memory_mode = GPU_memory_mode
|
| 80 |
+
self.model_name = model_name
|
| 81 |
+
self.diffusion_transformer_dropdown = model_name
|
| 82 |
+
self.scheduler_dict = scheduler_dict
|
| 83 |
+
self.model_type = model_type
|
| 84 |
+
if config_path is not None:
|
| 85 |
+
self.config_path = os.path.realpath(config_path)
|
| 86 |
+
self.config = OmegaConf.load(config_path)
|
| 87 |
+
else:
|
| 88 |
+
self.config_path = None
|
| 89 |
+
self.ulysses_degree = ulysses_degree
|
| 90 |
+
self.ring_degree = ring_degree
|
| 91 |
+
self.fsdp_dit = fsdp_dit
|
| 92 |
+
self.fsdp_text_encoder = fsdp_text_encoder
|
| 93 |
+
self.compile_dit = compile_dit
|
| 94 |
+
self.weight_dtype = weight_dtype
|
| 95 |
+
self.device = set_multi_gpus_devices(self.ulysses_degree, self.ring_degree)
|
| 96 |
+
|
| 97 |
+
self.diffusion_transformer_list = []
|
| 98 |
+
self.motion_module_list = []
|
| 99 |
+
self.personalized_model_list = []
|
| 100 |
+
self.config_list = []
|
| 101 |
+
|
| 102 |
+
# config models
|
| 103 |
+
self.tokenizer = None
|
| 104 |
+
self.text_encoder = None
|
| 105 |
+
self.vae = None
|
| 106 |
+
self.transformer = None
|
| 107 |
+
self.transformer_2 = None
|
| 108 |
+
self.pipeline = None
|
| 109 |
+
self.base_model_path = "none"
|
| 110 |
+
self.base_model_2_path = "none"
|
| 111 |
+
self.lora_model_path = "none"
|
| 112 |
+
self.lora_model_2_path = "none"
|
| 113 |
+
|
| 114 |
+
self.refresh_config()
|
| 115 |
+
self.refresh_diffusion_transformer()
|
| 116 |
+
self.refresh_personalized_model()
|
| 117 |
+
if model_name != None:
|
| 118 |
+
self.update_diffusion_transformer(model_name)
|
| 119 |
+
|
| 120 |
+
def refresh_config(self):
|
| 121 |
+
config_list = []
|
| 122 |
+
for root, dirs, files in os.walk(self.config_dir):
|
| 123 |
+
for file in files:
|
| 124 |
+
if file.endswith(('.yaml', '.yml')):
|
| 125 |
+
full_path = os.path.join(root, file)
|
| 126 |
+
config_list.append(full_path)
|
| 127 |
+
self.config_list = config_list
|
| 128 |
+
|
| 129 |
+
def refresh_diffusion_transformer(self):
|
| 130 |
+
self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
|
| 131 |
+
|
| 132 |
+
def refresh_personalized_model(self):
|
| 133 |
+
personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
|
| 134 |
+
self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
|
| 135 |
+
|
| 136 |
+
def update_model_type(self, model_type):
|
| 137 |
+
self.model_type = model_type
|
| 138 |
+
|
| 139 |
+
def update_config(self, config_dropdown):
|
| 140 |
+
self.config_path = config_dropdown
|
| 141 |
+
self.config = OmegaConf.load(config_dropdown)
|
| 142 |
+
print(f"Update config: {config_dropdown}")
|
| 143 |
+
|
| 144 |
+
def update_diffusion_transformer(self, diffusion_transformer_dropdown):
|
| 145 |
+
pass
|
| 146 |
+
|
| 147 |
+
def update_base_model(self, base_model_dropdown, is_checkpoint_2=False):
|
| 148 |
+
if not is_checkpoint_2:
|
| 149 |
+
self.base_model_path = base_model_dropdown
|
| 150 |
+
else:
|
| 151 |
+
self.base_model_2_path = base_model_dropdown
|
| 152 |
+
print(f"Update base model: {base_model_dropdown}")
|
| 153 |
+
if base_model_dropdown == "none":
|
| 154 |
+
return gr.update()
|
| 155 |
+
if self.transformer is None and not is_checkpoint_2:
|
| 156 |
+
gr.Info(f"Please select a pretrained model path.")
|
| 157 |
+
print(f"Please select a pretrained model path.")
|
| 158 |
+
return gr.update(value=None)
|
| 159 |
+
elif self.transformer_2 is None and is_checkpoint_2:
|
| 160 |
+
gr.Info(f"Please select a pretrained model path.")
|
| 161 |
+
print(f"Please select a pretrained model path.")
|
| 162 |
+
return gr.update(value=None)
|
| 163 |
+
else:
|
| 164 |
+
base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
|
| 165 |
+
base_model_state_dict = {}
|
| 166 |
+
with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
|
| 167 |
+
for key in f.keys():
|
| 168 |
+
base_model_state_dict[key] = f.get_tensor(key)
|
| 169 |
+
if not is_checkpoint_2:
|
| 170 |
+
self.transformer.load_state_dict(base_model_state_dict, strict=False)
|
| 171 |
+
else:
|
| 172 |
+
self.transformer_2.load_state_dict(base_model_state_dict, strict=False)
|
| 173 |
+
print("Update base model done")
|
| 174 |
+
return gr.update()
|
| 175 |
+
|
| 176 |
+
def update_lora_model(self, lora_model_dropdown, is_checkpoint_2=False):
|
| 177 |
+
print(f"Update lora model: {lora_model_dropdown}")
|
| 178 |
+
if lora_model_dropdown == "none":
|
| 179 |
+
self.lora_model_path = "none"
|
| 180 |
+
return gr.update()
|
| 181 |
+
lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
|
| 182 |
+
if not is_checkpoint_2:
|
| 183 |
+
self.lora_model_path = lora_model_dropdown
|
| 184 |
+
else:
|
| 185 |
+
self.lora_model_2_path = lora_model_dropdown
|
| 186 |
+
return gr.update()
|
| 187 |
+
|
| 188 |
+
def clear_cache(self,):
|
| 189 |
+
gc.collect()
|
| 190 |
+
torch.cuda.empty_cache()
|
| 191 |
+
torch.cuda.ipc_collect()
|
| 192 |
+
|
| 193 |
+
def auto_model_clear_cache(self, model):
|
| 194 |
+
origin_device = model.device
|
| 195 |
+
model = model.to("cpu")
|
| 196 |
+
gc.collect()
|
| 197 |
+
torch.cuda.empty_cache()
|
| 198 |
+
torch.cuda.ipc_collect()
|
| 199 |
+
model = model.to(origin_device)
|
| 200 |
+
|
| 201 |
+
def input_check(self,
|
| 202 |
+
resize_method,
|
| 203 |
+
generation_method,
|
| 204 |
+
start_image,
|
| 205 |
+
end_image,
|
| 206 |
+
validation_video,
|
| 207 |
+
control_video,
|
| 208 |
+
is_api = False,
|
| 209 |
+
):
|
| 210 |
+
if self.transformer is None:
|
| 211 |
+
if is_api:
|
| 212 |
+
return "", f"Please select a pretrained model path."
|
| 213 |
+
else:
|
| 214 |
+
raise gr.Error(f"Please select a pretrained model path.")
|
| 215 |
+
|
| 216 |
+
if control_video is not None and self.model_type == "Inpaint":
|
| 217 |
+
if is_api:
|
| 218 |
+
return "", f"If specifying the control video, please set the model_type == \"Control\". "
|
| 219 |
+
else:
|
| 220 |
+
raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ")
|
| 221 |
+
|
| 222 |
+
if control_video is None and self.model_type == "Control":
|
| 223 |
+
if is_api:
|
| 224 |
+
return "", f"If set the model_type == \"Control\", please specifying the control video. "
|
| 225 |
+
else:
|
| 226 |
+
raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ")
|
| 227 |
+
|
| 228 |
+
if resize_method == "Resize according to Reference":
|
| 229 |
+
if start_image is None and validation_video is None and control_video is None:
|
| 230 |
+
if is_api:
|
| 231 |
+
return "", f"Please upload an image when using \"Resize according to Reference\"."
|
| 232 |
+
else:
|
| 233 |
+
raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
|
| 234 |
+
|
| 235 |
+
if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
|
| 236 |
+
if is_api:
|
| 237 |
+
return "", f"Please select an image to video pretrained model while using image to video."
|
| 238 |
+
else:
|
| 239 |
+
raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
|
| 240 |
+
|
| 241 |
+
if self.transformer.config.in_channels == self.vae.config.latent_channels and generation_method == "Long Video Generation":
|
| 242 |
+
if is_api:
|
| 243 |
+
return "", f"Please select an image to video pretrained model while using long video generation."
|
| 244 |
+
else:
|
| 245 |
+
raise gr.Error(f"Please select an image to video pretrained model while using long video generation.")
|
| 246 |
+
|
| 247 |
+
if start_image is None and end_image is not None:
|
| 248 |
+
if is_api:
|
| 249 |
+
return "", f"If specifying the ending image of the video, please specify a starting image of the video."
|
| 250 |
+
else:
|
| 251 |
+
raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
|
| 252 |
+
return "", "OK"
|
| 253 |
+
|
| 254 |
+
def get_height_width_from_reference(
|
| 255 |
+
self,
|
| 256 |
+
base_resolution,
|
| 257 |
+
start_image,
|
| 258 |
+
validation_video,
|
| 259 |
+
control_video,
|
| 260 |
+
):
|
| 261 |
+
spatial_compression_ratio = self.vae.config.spatial_compression_ratio if hasattr(self.vae.config, "spatial_compression_ratio") else 8
|
| 262 |
+
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
|
| 263 |
+
if self.model_type == "Inpaint":
|
| 264 |
+
if validation_video is not None:
|
| 265 |
+
original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
|
| 266 |
+
else:
|
| 267 |
+
original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
|
| 268 |
+
else:
|
| 269 |
+
original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
|
| 270 |
+
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
|
| 271 |
+
height_slider, width_slider = [int(x / spatial_compression_ratio / 2) * spatial_compression_ratio * 2 for x in closest_size]
|
| 272 |
+
return height_slider, width_slider
|
| 273 |
+
|
| 274 |
+
def save_outputs(self, is_image, length_slider, sample, fps):
|
| 275 |
+
def save_results():
|
| 276 |
+
if not os.path.exists(self.savedir_sample):
|
| 277 |
+
os.makedirs(self.savedir_sample, exist_ok=True)
|
| 278 |
+
index = len([path for path in os.listdir(self.savedir_sample)]) + 1
|
| 279 |
+
prefix = str(index).zfill(8)
|
| 280 |
+
|
| 281 |
+
md5_hash = hashlib.md5(sample.cpu().numpy().tobytes()).hexdigest()
|
| 282 |
+
|
| 283 |
+
if is_image or length_slider == 1:
|
| 284 |
+
save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png")
|
| 285 |
+
print(f"Saving to {save_sample_path}")
|
| 286 |
+
image = sample[0, :, 0]
|
| 287 |
+
image = image.transpose(0, 1).transpose(1, 2)
|
| 288 |
+
image = (image * 255).numpy().astype(np.uint8)
|
| 289 |
+
image = Image.fromarray(image)
|
| 290 |
+
image.save(save_sample_path)
|
| 291 |
+
|
| 292 |
+
else:
|
| 293 |
+
save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4")
|
| 294 |
+
print(f"Saving to {save_sample_path}")
|
| 295 |
+
save_videos_grid(sample, save_sample_path, fps=fps)
|
| 296 |
+
return save_sample_path
|
| 297 |
+
|
| 298 |
+
if self.ulysses_degree * self.ring_degree > 1:
|
| 299 |
+
import torch.distributed as dist
|
| 300 |
+
if dist.get_rank() == 0:
|
| 301 |
+
save_sample_path = save_results()
|
| 302 |
+
else:
|
| 303 |
+
save_sample_path = None
|
| 304 |
+
else:
|
| 305 |
+
save_sample_path = save_results()
|
| 306 |
+
return save_sample_path
|
| 307 |
+
|
| 308 |
+
def generate(
|
| 309 |
+
self,
|
| 310 |
+
diffusion_transformer_dropdown,
|
| 311 |
+
base_model_dropdown,
|
| 312 |
+
lora_model_dropdown,
|
| 313 |
+
lora_alpha_slider,
|
| 314 |
+
prompt_textbox,
|
| 315 |
+
negative_prompt_textbox,
|
| 316 |
+
sampler_dropdown,
|
| 317 |
+
sample_step_slider,
|
| 318 |
+
resize_method,
|
| 319 |
+
width_slider,
|
| 320 |
+
height_slider,
|
| 321 |
+
base_resolution,
|
| 322 |
+
generation_method,
|
| 323 |
+
length_slider,
|
| 324 |
+
overlap_video_length,
|
| 325 |
+
partial_video_length,
|
| 326 |
+
cfg_scale_slider,
|
| 327 |
+
start_image,
|
| 328 |
+
end_image,
|
| 329 |
+
validation_video,
|
| 330 |
+
validation_video_mask,
|
| 331 |
+
control_video,
|
| 332 |
+
denoise_strength,
|
| 333 |
+
seed_textbox,
|
| 334 |
+
enable_teacache = None,
|
| 335 |
+
teacache_threshold = None,
|
| 336 |
+
num_skip_start_steps = None,
|
| 337 |
+
teacache_offload = None,
|
| 338 |
+
cfg_skip_ratio = None,
|
| 339 |
+
enable_riflex = None,
|
| 340 |
+
riflex_k = None,
|
| 341 |
+
is_api = False,
|
| 342 |
+
):
|
| 343 |
+
pass
|
| 344 |
+
|
| 345 |
+
def post_to_host(
|
| 346 |
+
diffusion_transformer_dropdown,
|
| 347 |
+
base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
|
| 348 |
+
prompt_textbox, negative_prompt_textbox,
|
| 349 |
+
sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
|
| 350 |
+
base_resolution, generation_method, length_slider, cfg_scale_slider,
|
| 351 |
+
start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox,
|
| 352 |
+
ref_image = None, enable_teacache = None, teacache_threshold = None, num_skip_start_steps = None,
|
| 353 |
+
teacache_offload = None, cfg_skip_ratio = None,enable_riflex = None, riflex_k = None,
|
| 354 |
+
):
|
| 355 |
+
if start_image is not None:
|
| 356 |
+
with open(start_image, 'rb') as file:
|
| 357 |
+
file_content = file.read()
|
| 358 |
+
start_image_encoded_content = base64.b64encode(file_content)
|
| 359 |
+
start_image = start_image_encoded_content.decode('utf-8')
|
| 360 |
+
|
| 361 |
+
if end_image is not None:
|
| 362 |
+
with open(end_image, 'rb') as file:
|
| 363 |
+
file_content = file.read()
|
| 364 |
+
end_image_encoded_content = base64.b64encode(file_content)
|
| 365 |
+
end_image = end_image_encoded_content.decode('utf-8')
|
| 366 |
+
|
| 367 |
+
if validation_video is not None:
|
| 368 |
+
with open(validation_video, 'rb') as file:
|
| 369 |
+
file_content = file.read()
|
| 370 |
+
validation_video_encoded_content = base64.b64encode(file_content)
|
| 371 |
+
validation_video = validation_video_encoded_content.decode('utf-8')
|
| 372 |
+
|
| 373 |
+
if validation_video_mask is not None:
|
| 374 |
+
with open(validation_video_mask, 'rb') as file:
|
| 375 |
+
file_content = file.read()
|
| 376 |
+
validation_video_mask_encoded_content = base64.b64encode(file_content)
|
| 377 |
+
validation_video_mask = validation_video_mask_encoded_content.decode('utf-8')
|
| 378 |
+
|
| 379 |
+
if ref_image is not None:
|
| 380 |
+
with open(ref_image, 'rb') as file:
|
| 381 |
+
file_content = file.read()
|
| 382 |
+
ref_image_encoded_content = base64.b64encode(file_content)
|
| 383 |
+
ref_image = ref_image_encoded_content.decode('utf-8')
|
| 384 |
+
|
| 385 |
+
datas = {
|
| 386 |
+
"base_model_path": base_model_dropdown,
|
| 387 |
+
"lora_model_path": lora_model_dropdown,
|
| 388 |
+
"lora_alpha_slider": lora_alpha_slider,
|
| 389 |
+
"prompt_textbox": prompt_textbox,
|
| 390 |
+
"negative_prompt_textbox": negative_prompt_textbox,
|
| 391 |
+
"sampler_dropdown": sampler_dropdown,
|
| 392 |
+
"sample_step_slider": sample_step_slider,
|
| 393 |
+
"resize_method": resize_method,
|
| 394 |
+
"width_slider": width_slider,
|
| 395 |
+
"height_slider": height_slider,
|
| 396 |
+
"base_resolution": base_resolution,
|
| 397 |
+
"generation_method": generation_method,
|
| 398 |
+
"length_slider": length_slider,
|
| 399 |
+
"cfg_scale_slider": cfg_scale_slider,
|
| 400 |
+
"start_image": start_image,
|
| 401 |
+
"end_image": end_image,
|
| 402 |
+
"validation_video": validation_video,
|
| 403 |
+
"validation_video_mask": validation_video_mask,
|
| 404 |
+
"denoise_strength": denoise_strength,
|
| 405 |
+
"seed_textbox": seed_textbox,
|
| 406 |
+
|
| 407 |
+
"ref_image": ref_image,
|
| 408 |
+
"enable_teacache": enable_teacache,
|
| 409 |
+
"teacache_threshold": teacache_threshold,
|
| 410 |
+
"num_skip_start_steps": num_skip_start_steps,
|
| 411 |
+
"teacache_offload": teacache_offload,
|
| 412 |
+
"cfg_skip_ratio": cfg_skip_ratio,
|
| 413 |
+
"enable_riflex": enable_riflex,
|
| 414 |
+
"riflex_k": riflex_k,
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
session = requests.session()
|
| 418 |
+
session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
|
| 419 |
+
|
| 420 |
+
response = session.post(url=f'{os.environ.get("EAS_URL")}/videox_fun/infer_forward', json=datas, timeout=300)
|
| 421 |
+
|
| 422 |
+
outputs = response.json()
|
| 423 |
+
return outputs
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class Fun_Controller_Client:
|
| 427 |
+
def __init__(self, scheduler_dict, savedir_sample):
|
| 428 |
+
self.basedir = os.getcwd()
|
| 429 |
+
if savedir_sample is None:
|
| 430 |
+
self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
|
| 431 |
+
else:
|
| 432 |
+
self.savedir_sample = savedir_sample
|
| 433 |
+
os.makedirs(self.savedir_sample, exist_ok=True)
|
| 434 |
+
|
| 435 |
+
self.scheduler_dict = scheduler_dict
|
| 436 |
+
|
| 437 |
+
def generate(
|
| 438 |
+
self,
|
| 439 |
+
diffusion_transformer_dropdown,
|
| 440 |
+
base_model_dropdown,
|
| 441 |
+
lora_model_dropdown,
|
| 442 |
+
lora_alpha_slider,
|
| 443 |
+
prompt_textbox,
|
| 444 |
+
negative_prompt_textbox,
|
| 445 |
+
sampler_dropdown,
|
| 446 |
+
sample_step_slider,
|
| 447 |
+
resize_method,
|
| 448 |
+
width_slider,
|
| 449 |
+
height_slider,
|
| 450 |
+
base_resolution,
|
| 451 |
+
generation_method,
|
| 452 |
+
length_slider,
|
| 453 |
+
cfg_scale_slider,
|
| 454 |
+
start_image,
|
| 455 |
+
end_image,
|
| 456 |
+
validation_video,
|
| 457 |
+
validation_video_mask,
|
| 458 |
+
denoise_strength,
|
| 459 |
+
seed_textbox,
|
| 460 |
+
ref_image = None,
|
| 461 |
+
enable_teacache = None,
|
| 462 |
+
teacache_threshold = None,
|
| 463 |
+
num_skip_start_steps = None,
|
| 464 |
+
teacache_offload = None,
|
| 465 |
+
cfg_skip_ratio = None,
|
| 466 |
+
enable_riflex = None,
|
| 467 |
+
riflex_k = None,
|
| 468 |
+
):
|
| 469 |
+
is_image = True if generation_method == "Image Generation" else False
|
| 470 |
+
|
| 471 |
+
outputs = post_to_host(
|
| 472 |
+
diffusion_transformer_dropdown,
|
| 473 |
+
base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
|
| 474 |
+
prompt_textbox, negative_prompt_textbox,
|
| 475 |
+
sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
|
| 476 |
+
base_resolution, generation_method, length_slider, cfg_scale_slider,
|
| 477 |
+
start_image, end_image, validation_video, validation_video_mask, denoise_strength,
|
| 478 |
+
seed_textbox, ref_image = ref_image, enable_teacache = enable_teacache, teacache_threshold = teacache_threshold,
|
| 479 |
+
num_skip_start_steps = num_skip_start_steps, teacache_offload = teacache_offload,
|
| 480 |
+
cfg_skip_ratio = cfg_skip_ratio, enable_riflex = enable_riflex, riflex_k = riflex_k,
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
try:
|
| 484 |
+
base64_encoding = outputs["base64_encoding"]
|
| 485 |
+
except:
|
| 486 |
+
return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"]
|
| 487 |
+
|
| 488 |
+
decoded_data = base64.b64decode(base64_encoding)
|
| 489 |
+
|
| 490 |
+
if not os.path.exists(self.savedir_sample):
|
| 491 |
+
os.makedirs(self.savedir_sample, exist_ok=True)
|
| 492 |
+
md5_hash = hashlib.md5(decoded_data).hexdigest()
|
| 493 |
+
|
| 494 |
+
index = len([path for path in os.listdir(self.savedir_sample)]) + 1
|
| 495 |
+
prefix = str(index).zfill(8)
|
| 496 |
+
|
| 497 |
+
if is_image or length_slider == 1:
|
| 498 |
+
save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png")
|
| 499 |
+
print(f"Saving to {save_sample_path}")
|
| 500 |
+
with open(save_sample_path, "wb") as file:
|
| 501 |
+
file.write(decoded_data)
|
| 502 |
+
if gradio_version_is_above_4:
|
| 503 |
+
return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
|
| 504 |
+
else:
|
| 505 |
+
return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
|
| 506 |
+
else:
|
| 507 |
+
save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4")
|
| 508 |
+
print(f"Saving to {save_sample_path}")
|
| 509 |
+
with open(save_sample_path, "wb") as file:
|
| 510 |
+
file.write(decoded_data)
|
| 511 |
+
if gradio_version_is_above_4:
|
| 512 |
+
return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
|
| 513 |
+
else:
|
| 514 |
+
return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
|
videox_fun/ui/ui.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def create_model_type(visible):
|
| 7 |
+
gr.Markdown(
|
| 8 |
+
"""
|
| 9 |
+
### Model Type.
|
| 10 |
+
""",
|
| 11 |
+
visible=visible,
|
| 12 |
+
)
|
| 13 |
+
with gr.Row():
|
| 14 |
+
model_type = gr.Dropdown(
|
| 15 |
+
label="The model type of the model",
|
| 16 |
+
choices=["Inpaint", "Control"],
|
| 17 |
+
value="Inpaint",
|
| 18 |
+
visible=visible,
|
| 19 |
+
interactive=True,
|
| 20 |
+
)
|
| 21 |
+
return model_type
|
| 22 |
+
|
| 23 |
+
def create_fake_model_type(visible):
|
| 24 |
+
gr.Markdown(
|
| 25 |
+
"""
|
| 26 |
+
### Model Type.
|
| 27 |
+
""",
|
| 28 |
+
visible=visible,
|
| 29 |
+
)
|
| 30 |
+
with gr.Row():
|
| 31 |
+
model_type = gr.Dropdown(
|
| 32 |
+
label="The model type of the model",
|
| 33 |
+
choices=["Inpaint", "Control"],
|
| 34 |
+
value="Inpaint",
|
| 35 |
+
interactive=False,
|
| 36 |
+
visible=visible,
|
| 37 |
+
)
|
| 38 |
+
return model_type
|
| 39 |
+
|
| 40 |
+
def create_model_checkpoints(controller, visible, default_model="none"):
|
| 41 |
+
gr.Markdown(
|
| 42 |
+
"""
|
| 43 |
+
### Model checkpoints.
|
| 44 |
+
"""
|
| 45 |
+
)
|
| 46 |
+
with gr.Row(visible=visible):
|
| 47 |
+
diffusion_transformer_dropdown = gr.Dropdown(
|
| 48 |
+
label="Pretrained Model Path",
|
| 49 |
+
choices=list(set(controller.diffusion_transformer_list + [default_model])),
|
| 50 |
+
value=default_model,
|
| 51 |
+
interactive=True,
|
| 52 |
+
)
|
| 53 |
+
diffusion_transformer_dropdown.change(
|
| 54 |
+
fn=controller.update_diffusion_transformer,
|
| 55 |
+
inputs=[diffusion_transformer_dropdown],
|
| 56 |
+
outputs=[diffusion_transformer_dropdown]
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
|
| 60 |
+
def refresh_diffusion_transformer():
|
| 61 |
+
controller.refresh_diffusion_transformer()
|
| 62 |
+
return gr.update(choices=controller.diffusion_transformer_list)
|
| 63 |
+
diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
|
| 64 |
+
|
| 65 |
+
return diffusion_transformer_dropdown, diffusion_transformer_refresh_button
|
| 66 |
+
|
| 67 |
+
def create_fake_model_checkpoints(model_name, visible):
|
| 68 |
+
gr.Markdown(
|
| 69 |
+
"""
|
| 70 |
+
### Model checkpoints.
|
| 71 |
+
"""
|
| 72 |
+
)
|
| 73 |
+
with gr.Row(visible=visible):
|
| 74 |
+
diffusion_transformer_dropdown = gr.Dropdown(
|
| 75 |
+
label="Pretrained Model Path",
|
| 76 |
+
choices=[model_name],
|
| 77 |
+
value=model_name,
|
| 78 |
+
interactive=False,
|
| 79 |
+
)
|
| 80 |
+
return diffusion_transformer_dropdown
|
| 81 |
+
|
| 82 |
+
def create_finetune_models_checkpoints(controller, visible, add_checkpoint_2=False, default_lora="none"):
|
| 83 |
+
with gr.Row(visible=visible):
|
| 84 |
+
base_model_dropdown = gr.Dropdown(
|
| 85 |
+
label="Select base Dreambooth model",
|
| 86 |
+
choices=["none"] + controller.personalized_model_list,
|
| 87 |
+
value="none",
|
| 88 |
+
interactive=True,
|
| 89 |
+
)
|
| 90 |
+
if add_checkpoint_2:
|
| 91 |
+
base_model_2_dropdown = gr.Dropdown(
|
| 92 |
+
label="Select base Dreambooth model",
|
| 93 |
+
choices=["none"] + controller.personalized_model_list,
|
| 94 |
+
value="none",
|
| 95 |
+
interactive=True,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
lora_model_dropdown = gr.Dropdown(
|
| 99 |
+
label="Select LoRA model",
|
| 100 |
+
choices=list(set(["none"] + controller.personalized_model_list + [default_lora])),
|
| 101 |
+
value=default_lora,
|
| 102 |
+
interactive=True,
|
| 103 |
+
)
|
| 104 |
+
if add_checkpoint_2:
|
| 105 |
+
lora_model_2_dropdown = gr.Dropdown(
|
| 106 |
+
label="Select LoRA model",
|
| 107 |
+
choices=["none"] + controller.personalized_model_list,
|
| 108 |
+
value="none",
|
| 109 |
+
interactive=True,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.55, minimum=0, maximum=2, interactive=True)
|
| 113 |
+
|
| 114 |
+
personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
|
| 115 |
+
def update_personalized_model():
|
| 116 |
+
controller.refresh_personalized_model()
|
| 117 |
+
return [
|
| 118 |
+
gr.update(choices=controller.personalized_model_list),
|
| 119 |
+
gr.update(choices=["none"] + controller.personalized_model_list)
|
| 120 |
+
]
|
| 121 |
+
personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
|
| 122 |
+
|
| 123 |
+
if not add_checkpoint_2:
|
| 124 |
+
return base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button
|
| 125 |
+
else:
|
| 126 |
+
return [base_model_dropdown, base_model_2_dropdown], [lora_model_dropdown, lora_model_2_dropdown], \
|
| 127 |
+
lora_alpha_slider, personalized_refresh_button
|
| 128 |
+
|
| 129 |
+
def create_fake_finetune_models_checkpoints(visible):
|
| 130 |
+
with gr.Row():
|
| 131 |
+
base_model_dropdown = gr.Dropdown(
|
| 132 |
+
label="Select base Dreambooth model",
|
| 133 |
+
choices=["none"],
|
| 134 |
+
value="none",
|
| 135 |
+
interactive=False,
|
| 136 |
+
visible=False
|
| 137 |
+
)
|
| 138 |
+
with gr.Column(visible=False):
|
| 139 |
+
gr.Markdown(
|
| 140 |
+
"""
|
| 141 |
+
### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora).
|
| 142 |
+
"""
|
| 143 |
+
)
|
| 144 |
+
with gr.Row():
|
| 145 |
+
lora_model_dropdown = gr.Dropdown(
|
| 146 |
+
label="Select LoRA model",
|
| 147 |
+
choices=["none"],
|
| 148 |
+
value="none",
|
| 149 |
+
interactive=True,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.55, minimum=0, maximum=2, interactive=True)
|
| 153 |
+
|
| 154 |
+
return base_model_dropdown, lora_model_dropdown, lora_alpha_slider
|
| 155 |
+
|
| 156 |
+
def create_teacache_params(
|
| 157 |
+
enable_teacache = True,
|
| 158 |
+
teacache_threshold = 0.10,
|
| 159 |
+
num_skip_start_steps = 1,
|
| 160 |
+
teacache_offload = False,
|
| 161 |
+
):
|
| 162 |
+
enable_teacache = gr.Checkbox(label="Enable TeaCache", value=enable_teacache)
|
| 163 |
+
teacache_threshold = gr.Slider(0.00, 0.25, value=teacache_threshold, step=0.01, label="TeaCache Threshold")
|
| 164 |
+
num_skip_start_steps = gr.Slider(0, 10, value=num_skip_start_steps, step=5, label="Number of Skip Start Steps")
|
| 165 |
+
teacache_offload = gr.Checkbox(label="Offload TeaCache to CPU", value=teacache_offload)
|
| 166 |
+
return enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload
|
| 167 |
+
|
| 168 |
+
def create_cfg_skip_params(
|
| 169 |
+
cfg_skip_ratio = 0
|
| 170 |
+
):
|
| 171 |
+
cfg_skip_ratio = gr.Slider(0.00, 0.50, value=cfg_skip_ratio, step=0.01, label="CFG Skip Ratio", visible=False)
|
| 172 |
+
return cfg_skip_ratio
|
| 173 |
+
|
| 174 |
+
def create_cfg_riflex_k(
|
| 175 |
+
enable_riflex = False,
|
| 176 |
+
riflex_k = 6
|
| 177 |
+
):
|
| 178 |
+
enable_riflex = gr.Checkbox(label="Enable Riflex", value=enable_riflex, visible=False)
|
| 179 |
+
riflex_k = gr.Slider(0, 10, value=riflex_k, step=1, label="Riflex Intrinsic Frequency Index", visible=False)
|
| 180 |
+
return enable_riflex, riflex_k
|
| 181 |
+
|
| 182 |
+
def create_prompts(
|
| 183 |
+
prompt="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
|
| 184 |
+
negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
| 185 |
+
):
|
| 186 |
+
gr.Markdown(
|
| 187 |
+
"""
|
| 188 |
+
### Configs for Generation.
|
| 189 |
+
"""
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
prompt_textbox = gr.Textbox(label="Prompt", lines=2, value=prompt)
|
| 193 |
+
negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value=negative_prompt)
|
| 194 |
+
return prompt_textbox, negative_prompt_textbox
|
| 195 |
+
|
| 196 |
+
def create_samplers(controller, maximum_step=100):
|
| 197 |
+
with gr.Row():
|
| 198 |
+
sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(controller.scheduler_dict.keys()), value=list(controller.scheduler_dict.keys())[0])
|
| 199 |
+
sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=maximum_step, step=1)
|
| 200 |
+
|
| 201 |
+
return sampler_dropdown, sample_step_slider
|
| 202 |
+
|
| 203 |
+
def create_height_width(default_height, default_width, maximum_height, maximum_width):
|
| 204 |
+
resize_method = gr.Radio(
|
| 205 |
+
["Generate by", "Resize according to Reference"],
|
| 206 |
+
value="Generate by",
|
| 207 |
+
show_label=False,
|
| 208 |
+
)
|
| 209 |
+
width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16)
|
| 210 |
+
height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16)
|
| 211 |
+
base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False)
|
| 212 |
+
|
| 213 |
+
return resize_method, width_slider, height_slider, base_resolution
|
| 214 |
+
|
| 215 |
+
def create_fake_height_width(default_height, default_width, maximum_height, maximum_width):
|
| 216 |
+
resize_method = gr.Radio(
|
| 217 |
+
["Generate by", "Resize according to Reference"],
|
| 218 |
+
value="Generate by",
|
| 219 |
+
show_label=False,
|
| 220 |
+
)
|
| 221 |
+
width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16, interactive=False)
|
| 222 |
+
height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16, interactive=False)
|
| 223 |
+
base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 640, 768, 896, 960, 1024], interactive=False, visible=False)
|
| 224 |
+
|
| 225 |
+
return resize_method, width_slider, height_slider, base_resolution
|
| 226 |
+
|
| 227 |
+
def create_generation_methods_and_video_length(
|
| 228 |
+
generation_method_options,
|
| 229 |
+
default_video_length,
|
| 230 |
+
maximum_video_length
|
| 231 |
+
):
|
| 232 |
+
with gr.Group():
|
| 233 |
+
generation_method = gr.Radio(
|
| 234 |
+
generation_method_options,
|
| 235 |
+
value="Video Generation",
|
| 236 |
+
show_label=False,
|
| 237 |
+
visible=False
|
| 238 |
+
)
|
| 239 |
+
with gr.Row():
|
| 240 |
+
length_slider = gr.Slider(label="Animation length", value=default_video_length, minimum=1, maximum=maximum_video_length, step=4, visible=False)
|
| 241 |
+
overlap_video_length = gr.Slider(label="Overlap length", value=4, minimum=1, maximum=4, step=1, visible=False)
|
| 242 |
+
partial_video_length = gr.Slider(label="Partial video generation length", value=25, minimum=5, maximum=maximum_video_length, step=4, visible=False)
|
| 243 |
+
|
| 244 |
+
return generation_method, length_slider, overlap_video_length, partial_video_length
|
| 245 |
+
|
| 246 |
+
def create_generation_method(source_method_options, prompt_textbox, support_end_image=True, support_ref_image=False, default_video=None, video_examples=None):
|
| 247 |
+
default_method = source_method_options[0] if source_method_options else "Text to Video"
|
| 248 |
+
source_method = gr.Radio(
|
| 249 |
+
source_method_options,
|
| 250 |
+
value=default_method,
|
| 251 |
+
show_label=False,
|
| 252 |
+
)
|
| 253 |
+
with gr.Column(visible = (default_method == "Image to Video")) as image_to_video_col:
|
| 254 |
+
start_image = gr.Image(
|
| 255 |
+
label="The image at the beginning of the video", show_label=True,
|
| 256 |
+
elem_id="i2v_start", sources="upload", type="filepath",
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
|
| 260 |
+
def select_template(evt: gr.SelectData):
|
| 261 |
+
text = {
|
| 262 |
+
"asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.",
|
| 263 |
+
"asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.",
|
| 264 |
+
"asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.",
|
| 265 |
+
"asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.",
|
| 266 |
+
"asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.",
|
| 267 |
+
}[template_gallery_path[evt.index]]
|
| 268 |
+
return template_gallery_path[evt.index], text
|
| 269 |
+
|
| 270 |
+
template_gallery = gr.Gallery(
|
| 271 |
+
template_gallery_path,
|
| 272 |
+
columns=5, rows=1,
|
| 273 |
+
height=140,
|
| 274 |
+
allow_preview=False,
|
| 275 |
+
container=False,
|
| 276 |
+
label="Template Examples",
|
| 277 |
+
)
|
| 278 |
+
template_gallery.select(select_template, None, [start_image, prompt_textbox])
|
| 279 |
+
|
| 280 |
+
with gr.Accordion("The image at the ending of the video", open=False, visible=support_end_image):
|
| 281 |
+
end_image = gr.Image(label="The image at the ending of the video", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
|
| 282 |
+
|
| 283 |
+
with gr.Column(visible = (default_method == "Video to Video")) as video_to_video_col:
|
| 284 |
+
with gr.Row():
|
| 285 |
+
validation_video = gr.Video(
|
| 286 |
+
label="The video to convert", show_label=True,
|
| 287 |
+
elem_id="v2v", sources=["upload"], value=default_video,
|
| 288 |
+
)
|
| 289 |
+
if video_examples:
|
| 290 |
+
gr.Examples(
|
| 291 |
+
examples=video_examples,
|
| 292 |
+
inputs=[validation_video, prompt_textbox] if len(video_examples[0]) > 1 else validation_video,
|
| 293 |
+
label="Video Examples"
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# Removed Mask Accordion entirely per request or hidden. User said "mask这个不需要"
|
| 297 |
+
# validation_video_mask = gr.Image(
|
| 298 |
+
# label="The mask of the video to inpaint",
|
| 299 |
+
# show_label=False, elem_id="v2v_mask", sources="upload", type="filepath",
|
| 300 |
+
# visible=False
|
| 301 |
+
# )
|
| 302 |
+
validation_video_mask = gr.Image(visible=False, value=None)
|
| 303 |
+
|
| 304 |
+
# Denoise strength default 1.0, hidden
|
| 305 |
+
denoise_strength = gr.Slider(label="Denoise strength", value=1.00, minimum=0.10, maximum=1.00, step=0.01, visible=False)
|
| 306 |
+
|
| 307 |
+
with gr.Column(visible = False) as control_video_col:
|
| 308 |
+
gr.Markdown(
|
| 309 |
+
"""
|
| 310 |
+
Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
|
| 311 |
+
"""
|
| 312 |
+
)
|
| 313 |
+
control_video = gr.Video(
|
| 314 |
+
label="The control video", show_label=True,
|
| 315 |
+
elem_id="v2v_control", sources="upload",
|
| 316 |
+
)
|
| 317 |
+
ref_image = gr.Image(
|
| 318 |
+
label="The reference image for control video", show_label=True,
|
| 319 |
+
elem_id="ref_image", sources="upload", type="filepath", visible=support_ref_image
|
| 320 |
+
)
|
| 321 |
+
return image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image
|
| 322 |
+
|
| 323 |
+
def create_cfg_and_seedbox(gradio_version_is_above_4):
|
| 324 |
+
# cfg default 6, hidden
|
| 325 |
+
cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20, visible=False)
|
| 326 |
+
|
| 327 |
+
with gr.Row():
|
| 328 |
+
seed_textbox = gr.Textbox(label="Seed", value=43)
|
| 329 |
+
seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
|
| 330 |
+
seed_button.click(
|
| 331 |
+
fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)),
|
| 332 |
+
inputs=[],
|
| 333 |
+
outputs=[seed_textbox]
|
| 334 |
+
)
|
| 335 |
+
return cfg_scale_slider, seed_textbox, seed_button
|
| 336 |
+
|
| 337 |
+
def create_ui_outputs():
|
| 338 |
+
with gr.Column():
|
| 339 |
+
result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
|
| 340 |
+
result_video = gr.Video(label="Generated Animation", interactive=False)
|
| 341 |
+
infer_progress = gr.Textbox(
|
| 342 |
+
label="Generation Info",
|
| 343 |
+
value="No task currently",
|
| 344 |
+
interactive=False
|
| 345 |
+
)
|
| 346 |
+
return result_image, result_video, infer_progress
|
| 347 |
+
|
| 348 |
+
def create_config(controller):
|
| 349 |
+
gr.Markdown(
|
| 350 |
+
"""
|
| 351 |
+
### Config Path (配置文件路径)
|
| 352 |
+
"""
|
| 353 |
+
)
|
| 354 |
+
with gr.Row():
|
| 355 |
+
config_dropdown = gr.Dropdown(
|
| 356 |
+
label="Config Path",
|
| 357 |
+
choices=controller.config_list,
|
| 358 |
+
value=controller.config_path,
|
| 359 |
+
interactive=True,
|
| 360 |
+
)
|
| 361 |
+
config_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
|
| 362 |
+
def refresh_config():
|
| 363 |
+
controller.refresh_config()
|
| 364 |
+
return gr.update(choices=controller.config_list)
|
| 365 |
+
config_refresh_button.click(fn=refresh_config, inputs=[], outputs=[config_dropdown])
|
| 366 |
+
return config_dropdown, config_refresh_button
|
videox_fun/ui/wan2_2_fun_ui.py
ADDED
|
@@ -0,0 +1,803 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
|
| 2 |
+
"""
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from safetensors import safe_open
|
| 13 |
+
|
| 14 |
+
from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
|
| 15 |
+
from ..dist import set_multi_gpus_devices, shard_model
|
| 16 |
+
from ..models import (AutoencoderKLWan, AutoencoderKLWan3_8, AutoTokenizer,
|
| 17 |
+
CLIPModel, Wan2_2Transformer3DModel, WanT5EncoderModel)
|
| 18 |
+
from ..models.cache_utils import get_teacache_coefficients
|
| 19 |
+
from ..pipeline import Wan2_2FunControlPipeline, Wan2_2FunPipeline, Wan2_2FunInpaintPipeline
|
| 20 |
+
from ..utils.fp8_optimization import (convert_model_weight_to_float8,
|
| 21 |
+
convert_weight_dtype_wrapper,
|
| 22 |
+
replace_parameters_by_name)
|
| 23 |
+
from ..utils.lora_utils import merge_lora, unmerge_lora
|
| 24 |
+
from ..utils.utils import (filter_kwargs, get_image_latent,
|
| 25 |
+
get_image_to_video_latent,
|
| 26 |
+
get_video_to_video_latent, save_videos_grid, timer)
|
| 27 |
+
from .controller import (Fun_Controller, Fun_Controller_Client,
|
| 28 |
+
all_cheduler_dict, css, ddpm_scheduler_dict,
|
| 29 |
+
flow_scheduler_dict, gradio_version,
|
| 30 |
+
gradio_version_is_above_4)
|
| 31 |
+
from .ui import (create_cfg_and_seedbox, create_cfg_riflex_k,
|
| 32 |
+
create_cfg_skip_params, create_config,
|
| 33 |
+
create_fake_finetune_models_checkpoints,
|
| 34 |
+
create_fake_height_width, create_fake_model_checkpoints,
|
| 35 |
+
create_fake_model_type, create_finetune_models_checkpoints,
|
| 36 |
+
create_generation_method,
|
| 37 |
+
create_generation_methods_and_video_length,
|
| 38 |
+
create_height_width, create_model_checkpoints,
|
| 39 |
+
create_model_type, create_prompts, create_samplers,
|
| 40 |
+
create_teacache_params, create_ui_outputs)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Wan2_2_Fun_Controller(Fun_Controller):
|
| 44 |
+
def update_diffusion_transformer(self, diffusion_transformer_dropdown):
|
| 45 |
+
print(f"Update diffusion transformer: {diffusion_transformer_dropdown}")
|
| 46 |
+
self.model_name = diffusion_transformer_dropdown
|
| 47 |
+
self.diffusion_transformer_dropdown = diffusion_transformer_dropdown
|
| 48 |
+
if diffusion_transformer_dropdown == "none":
|
| 49 |
+
return gr.update()
|
| 50 |
+
Chosen_AutoencoderKL = {
|
| 51 |
+
"AutoencoderKLWan": AutoencoderKLWan,
|
| 52 |
+
"AutoencoderKLWan3_8": AutoencoderKLWan3_8
|
| 53 |
+
}[self.config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')]
|
| 54 |
+
self.vae = Chosen_AutoencoderKL.from_pretrained(
|
| 55 |
+
os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')),
|
| 56 |
+
additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']),
|
| 57 |
+
).to(self.weight_dtype)
|
| 58 |
+
|
| 59 |
+
# Get Transformer
|
| 60 |
+
self.transformer = Wan2_2Transformer3DModel.from_pretrained(
|
| 61 |
+
os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer')),
|
| 62 |
+
transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']),
|
| 63 |
+
low_cpu_mem_usage=True,
|
| 64 |
+
torch_dtype=self.weight_dtype,
|
| 65 |
+
)
|
| 66 |
+
if self.config['transformer_additional_kwargs'].get('transformer_combination_type', 'single') == "moe":
|
| 67 |
+
self.transformer_2 = Wan2_2Transformer3DModel.from_pretrained(
|
| 68 |
+
os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer')),
|
| 69 |
+
transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']),
|
| 70 |
+
low_cpu_mem_usage=True,
|
| 71 |
+
torch_dtype=self.weight_dtype,
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
self.transformer_2 = None
|
| 75 |
+
|
| 76 |
+
# Get Tokenizer
|
| 77 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 78 |
+
os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Get Text encoder
|
| 82 |
+
self.text_encoder = WanT5EncoderModel.from_pretrained(
|
| 83 |
+
os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
|
| 84 |
+
additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']),
|
| 85 |
+
low_cpu_mem_usage=True,
|
| 86 |
+
torch_dtype=self.weight_dtype,
|
| 87 |
+
)
|
| 88 |
+
self.text_encoder = self.text_encoder.eval()
|
| 89 |
+
|
| 90 |
+
Chosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]]
|
| 91 |
+
self.scheduler = Chosen_Scheduler(
|
| 92 |
+
**filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs']))
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# Get pipeline
|
| 96 |
+
if self.model_type == "Inpaint":
|
| 97 |
+
if self.transformer.config.in_channels != self.vae.config.latent_channels:
|
| 98 |
+
self.pipeline = Wan2_2FunInpaintPipeline(
|
| 99 |
+
vae=self.vae,
|
| 100 |
+
tokenizer=self.tokenizer,
|
| 101 |
+
text_encoder=self.text_encoder,
|
| 102 |
+
transformer=self.transformer,
|
| 103 |
+
transformer_2=self.transformer_2,
|
| 104 |
+
scheduler=self.scheduler,
|
| 105 |
+
)
|
| 106 |
+
else:
|
| 107 |
+
self.pipeline = Wan2_2FunPipeline(
|
| 108 |
+
vae=self.vae,
|
| 109 |
+
tokenizer=self.tokenizer,
|
| 110 |
+
text_encoder=self.text_encoder,
|
| 111 |
+
transformer=self.transformer,
|
| 112 |
+
transformer_2=self.transformer_2,
|
| 113 |
+
scheduler=self.scheduler,
|
| 114 |
+
)
|
| 115 |
+
else:
|
| 116 |
+
self.pipeline = Wan2_2FunControlPipeline(
|
| 117 |
+
vae=self.vae,
|
| 118 |
+
tokenizer=self.tokenizer,
|
| 119 |
+
text_encoder=self.text_encoder,
|
| 120 |
+
transformer=self.transformer,
|
| 121 |
+
transformer_2=self.transformer_2,
|
| 122 |
+
scheduler=self.scheduler,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if self.ulysses_degree > 1 or self.ring_degree > 1:
|
| 126 |
+
from functools import partial
|
| 127 |
+
self.transformer.enable_multi_gpus_inference()
|
| 128 |
+
if self.transformer_2 is not None:
|
| 129 |
+
self.transformer_2.enable_multi_gpus_inference()
|
| 130 |
+
if self.fsdp_dit:
|
| 131 |
+
shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype)
|
| 132 |
+
self.pipeline.transformer = shard_fn(self.pipeline.transformer)
|
| 133 |
+
if self.transformer_2 is not None:
|
| 134 |
+
self.pipeline.transformer_2 = shard_fn(self.pipeline.transformer_2)
|
| 135 |
+
print("Add FSDP DIT")
|
| 136 |
+
if self.fsdp_text_encoder:
|
| 137 |
+
shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype)
|
| 138 |
+
self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder)
|
| 139 |
+
print("Add FSDP TEXT ENCODER")
|
| 140 |
+
|
| 141 |
+
if self.compile_dit:
|
| 142 |
+
for i in range(len(self.pipeline.transformer.blocks)):
|
| 143 |
+
self.pipeline.transformer.blocks[i] = torch.compile(self.pipeline.transformer.blocks[i])
|
| 144 |
+
if self.transformer_2 is not None:
|
| 145 |
+
for i in range(len(self.pipeline.transformer_2.blocks)):
|
| 146 |
+
self.pipeline.transformer_2.blocks[i] = torch.compile(self.pipeline.transformer_2.blocks[i])
|
| 147 |
+
print("Add Compile")
|
| 148 |
+
|
| 149 |
+
if self.GPU_memory_mode == "sequential_cpu_offload":
|
| 150 |
+
replace_parameters_by_name(self.transformer, ["modulation",], device=self.device)
|
| 151 |
+
self.transformer.freqs = self.transformer.freqs.to(device=self.device)
|
| 152 |
+
if self.transformer_2 is not None:
|
| 153 |
+
replace_parameters_by_name(self.transformer_2, ["modulation",], device=self.device)
|
| 154 |
+
self.transformer_2.freqs = self.transformer_2.freqs.to(device=self.device)
|
| 155 |
+
self.pipeline.enable_sequential_cpu_offload(device=self.device)
|
| 156 |
+
elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
|
| 157 |
+
convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device)
|
| 158 |
+
convert_weight_dtype_wrapper(self.transformer, self.weight_dtype)
|
| 159 |
+
if self.transformer_2 is not None:
|
| 160 |
+
convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device)
|
| 161 |
+
convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype)
|
| 162 |
+
self.pipeline.enable_model_cpu_offload(device=self.device)
|
| 163 |
+
elif self.GPU_memory_mode == "model_cpu_offload":
|
| 164 |
+
self.pipeline.enable_model_cpu_offload(device=self.device)
|
| 165 |
+
elif self.GPU_memory_mode == "model_full_load_and_qfloat8":
|
| 166 |
+
convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device)
|
| 167 |
+
convert_weight_dtype_wrapper(self.transformer, self.weight_dtype)
|
| 168 |
+
if self.transformer_2 is not None:
|
| 169 |
+
convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device)
|
| 170 |
+
convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype)
|
| 171 |
+
self.pipeline.to(self.device)
|
| 172 |
+
else:
|
| 173 |
+
self.pipeline.to(self.device)
|
| 174 |
+
print("Update diffusion transformer done")
|
| 175 |
+
return gr.update()
|
| 176 |
+
|
| 177 |
+
@timer
|
| 178 |
+
def generate(
|
| 179 |
+
self,
|
| 180 |
+
diffusion_transformer_dropdown,
|
| 181 |
+
base_model_dropdown,
|
| 182 |
+
lora_model_dropdown,
|
| 183 |
+
lora_alpha_slider,
|
| 184 |
+
prompt_textbox,
|
| 185 |
+
negative_prompt_textbox,
|
| 186 |
+
sampler_dropdown,
|
| 187 |
+
sample_step_slider,
|
| 188 |
+
resize_method,
|
| 189 |
+
width_slider,
|
| 190 |
+
height_slider,
|
| 191 |
+
base_resolution,
|
| 192 |
+
generation_method,
|
| 193 |
+
length_slider,
|
| 194 |
+
overlap_video_length,
|
| 195 |
+
partial_video_length,
|
| 196 |
+
cfg_scale_slider,
|
| 197 |
+
start_image,
|
| 198 |
+
end_image,
|
| 199 |
+
validation_video,
|
| 200 |
+
validation_video_mask,
|
| 201 |
+
control_video,
|
| 202 |
+
denoise_strength,
|
| 203 |
+
seed_textbox,
|
| 204 |
+
ref_image = None,
|
| 205 |
+
enable_teacache = None,
|
| 206 |
+
teacache_threshold = None,
|
| 207 |
+
num_skip_start_steps = None,
|
| 208 |
+
teacache_offload = None,
|
| 209 |
+
cfg_skip_ratio = None,
|
| 210 |
+
enable_riflex = None,
|
| 211 |
+
riflex_k = None,
|
| 212 |
+
base_model_2_dropdown=None,
|
| 213 |
+
lora_model_2_dropdown=None,
|
| 214 |
+
fps = None,
|
| 215 |
+
is_api = False,
|
| 216 |
+
):
|
| 217 |
+
self.clear_cache()
|
| 218 |
+
|
| 219 |
+
print(f"Input checking.")
|
| 220 |
+
_, comment = self.input_check(
|
| 221 |
+
resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api
|
| 222 |
+
)
|
| 223 |
+
print(f"Input checking down")
|
| 224 |
+
if comment != "OK":
|
| 225 |
+
return "", comment
|
| 226 |
+
is_image = True if generation_method == "Image Generation" else False
|
| 227 |
+
|
| 228 |
+
if self.base_model_path != base_model_dropdown:
|
| 229 |
+
self.update_base_model(base_model_dropdown)
|
| 230 |
+
if self.base_model_2_path != base_model_2_dropdown:
|
| 231 |
+
self.update_lora_model(base_model_2_dropdown, is_checkpoint_2=True)
|
| 232 |
+
|
| 233 |
+
if self.lora_model_path != lora_model_dropdown:
|
| 234 |
+
self.update_lora_model(lora_model_dropdown)
|
| 235 |
+
if self.lora_model_2_path != lora_model_2_dropdown:
|
| 236 |
+
self.update_lora_model(lora_model_2_dropdown, is_checkpoint_2=True)
|
| 237 |
+
|
| 238 |
+
print(f"Load scheduler.")
|
| 239 |
+
scheduler_config = self.pipeline.scheduler.config
|
| 240 |
+
if sampler_dropdown == "Flow_Unipc" or sampler_dropdown == "Flow_DPM++":
|
| 241 |
+
scheduler_config['shift'] = 1
|
| 242 |
+
self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)
|
| 243 |
+
print(f"Load scheduler down.")
|
| 244 |
+
|
| 245 |
+
if resize_method == "Resize according to Reference":
|
| 246 |
+
print(f"Calculate height and width according to Reference.")
|
| 247 |
+
height_slider, width_slider = self.get_height_width_from_reference(
|
| 248 |
+
base_resolution, start_image, validation_video, control_video,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
if self.lora_model_path != "none":
|
| 252 |
+
print(f"Merge Lora.")
|
| 253 |
+
self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
| 254 |
+
if self.transformer_2 is not None:
|
| 255 |
+
self.pipeline = merge_lora(self.pipeline, self.lora_model_2_path, multiplier=lora_alpha_slider, sub_transformer_name="transformer_2")
|
| 256 |
+
print(f"Merge Lora done.")
|
| 257 |
+
|
| 258 |
+
coefficients = get_teacache_coefficients(self.diffusion_transformer_dropdown) if enable_teacache else None
|
| 259 |
+
if coefficients is not None:
|
| 260 |
+
print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.")
|
| 261 |
+
self.pipeline.transformer.enable_teacache(
|
| 262 |
+
coefficients, sample_step_slider, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload
|
| 263 |
+
)
|
| 264 |
+
if self.transformer_2 is not None:
|
| 265 |
+
self.pipeline.transformer_2.share_teacache(self.pipeline.transformer)
|
| 266 |
+
else:
|
| 267 |
+
print(f"Disable TeaCache.")
|
| 268 |
+
self.pipeline.transformer.disable_teacache()
|
| 269 |
+
if self.transformer_2 is not None:
|
| 270 |
+
self.pipeline.transformer_2.disable_teacache()
|
| 271 |
+
|
| 272 |
+
if cfg_skip_ratio is not None and cfg_skip_ratio >= 0:
|
| 273 |
+
print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.")
|
| 274 |
+
self.pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, sample_step_slider)
|
| 275 |
+
if self.transformer_2 is not None:
|
| 276 |
+
self.pipeline.transformer_2.share_cfg_skip(self.pipeline.transformer)
|
| 277 |
+
|
| 278 |
+
print(f"Generate seed.")
|
| 279 |
+
if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
|
| 280 |
+
else: seed_textbox = np.random.randint(0, 1e10)
|
| 281 |
+
generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))
|
| 282 |
+
print(f"Generate seed done.")
|
| 283 |
+
|
| 284 |
+
if fps is None:
|
| 285 |
+
fps = 16
|
| 286 |
+
boundary = self.config['transformer_additional_kwargs'].get('boundary', 0.875)
|
| 287 |
+
|
| 288 |
+
if enable_riflex:
|
| 289 |
+
print(f"Enable riflex")
|
| 290 |
+
latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1
|
| 291 |
+
self.pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1)
|
| 292 |
+
if self.transformer_2 is not None:
|
| 293 |
+
self.pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1)
|
| 294 |
+
|
| 295 |
+
try:
|
| 296 |
+
print(f"Generation.")
|
| 297 |
+
if self.model_type == "Inpaint":
|
| 298 |
+
if self.transformer.config.in_channels != self.vae.config.latent_channels:
|
| 299 |
+
if validation_video is not None:
|
| 300 |
+
input_video, input_video_mask, _, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps)
|
| 301 |
+
else:
|
| 302 |
+
input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
|
| 303 |
+
|
| 304 |
+
sample = self.pipeline(
|
| 305 |
+
prompt_textbox,
|
| 306 |
+
negative_prompt = negative_prompt_textbox,
|
| 307 |
+
num_inference_steps = sample_step_slider,
|
| 308 |
+
guidance_scale = cfg_scale_slider,
|
| 309 |
+
width = width_slider,
|
| 310 |
+
height = height_slider,
|
| 311 |
+
num_frames = length_slider if not is_image else 1,
|
| 312 |
+
generator = generator,
|
| 313 |
+
|
| 314 |
+
video = input_video,
|
| 315 |
+
mask_video = input_video_mask,
|
| 316 |
+
boundary = boundary
|
| 317 |
+
).videos
|
| 318 |
+
else:
|
| 319 |
+
sample = self.pipeline(
|
| 320 |
+
prompt_textbox,
|
| 321 |
+
negative_prompt = negative_prompt_textbox,
|
| 322 |
+
num_inference_steps = sample_step_slider,
|
| 323 |
+
guidance_scale = cfg_scale_slider,
|
| 324 |
+
width = width_slider,
|
| 325 |
+
height = height_slider,
|
| 326 |
+
num_frames = length_slider if not is_image else 1,
|
| 327 |
+
generator = generator,
|
| 328 |
+
boundary = boundary
|
| 329 |
+
).videos
|
| 330 |
+
else:
|
| 331 |
+
inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
|
| 332 |
+
|
| 333 |
+
if ref_image is not None:
|
| 334 |
+
ref_image = get_image_latent(ref_image, sample_size=(height_slider, width_slider))
|
| 335 |
+
|
| 336 |
+
input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps, ref_image=None)
|
| 337 |
+
|
| 338 |
+
sample = self.pipeline(
|
| 339 |
+
prompt_textbox,
|
| 340 |
+
negative_prompt = negative_prompt_textbox,
|
| 341 |
+
num_inference_steps = sample_step_slider,
|
| 342 |
+
guidance_scale = cfg_scale_slider,
|
| 343 |
+
width = width_slider,
|
| 344 |
+
height = height_slider,
|
| 345 |
+
num_frames = length_slider if not is_image else 1,
|
| 346 |
+
generator = generator,
|
| 347 |
+
|
| 348 |
+
video = inpaint_video,
|
| 349 |
+
mask_video = inpaint_video_mask,
|
| 350 |
+
control_video = input_video,
|
| 351 |
+
ref_image = ref_image,
|
| 352 |
+
boundary = boundary,
|
| 353 |
+
).videos
|
| 354 |
+
print(f"Generation done.")
|
| 355 |
+
except Exception as e:
|
| 356 |
+
self.auto_model_clear_cache(self.pipeline.transformer)
|
| 357 |
+
self.auto_model_clear_cache(self.pipeline.text_encoder)
|
| 358 |
+
self.auto_model_clear_cache(self.pipeline.vae)
|
| 359 |
+
self.clear_cache()
|
| 360 |
+
|
| 361 |
+
print(f"Error. error information is {str(e)}")
|
| 362 |
+
if self.lora_model_path != "none":
|
| 363 |
+
self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
| 364 |
+
if is_api:
|
| 365 |
+
return "", f"Error. error information is {str(e)}"
|
| 366 |
+
else:
|
| 367 |
+
return gr.update(), gr.update(), f"Error. error information is {str(e)}"
|
| 368 |
+
|
| 369 |
+
self.clear_cache()
|
| 370 |
+
# lora part
|
| 371 |
+
if self.lora_model_path != "none":
|
| 372 |
+
print(f"Unmerge Lora.")
|
| 373 |
+
self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
|
| 374 |
+
print(f"Unmerge Lora done.")
|
| 375 |
+
|
| 376 |
+
print(f"Saving outputs.")
|
| 377 |
+
save_sample_path = self.save_outputs(
|
| 378 |
+
is_image, length_slider, sample, fps=fps
|
| 379 |
+
)
|
| 380 |
+
print(f"Saving outputs done.")
|
| 381 |
+
|
| 382 |
+
if is_image or length_slider == 1:
|
| 383 |
+
if is_api:
|
| 384 |
+
return save_sample_path, "Success"
|
| 385 |
+
else:
|
| 386 |
+
if gradio_version_is_above_4:
|
| 387 |
+
return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
|
| 388 |
+
else:
|
| 389 |
+
return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
|
| 390 |
+
else:
|
| 391 |
+
if is_api:
|
| 392 |
+
return save_sample_path, "Success"
|
| 393 |
+
else:
|
| 394 |
+
if gradio_version_is_above_4:
|
| 395 |
+
return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
|
| 396 |
+
else:
|
| 397 |
+
return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
|
| 398 |
+
|
| 399 |
+
Wan2_2_Fun_Controller_Host = Wan2_2_Fun_Controller
|
| 400 |
+
Wan2_2_Fun_Controller_Client = Fun_Controller_Client
|
| 401 |
+
|
| 402 |
+
def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype, savedir_sample=None):
|
| 403 |
+
controller = Wan2_2_Fun_Controller(
|
| 404 |
+
GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
|
| 405 |
+
config_path=config_path, compile_dit=compile_dit,
|
| 406 |
+
weight_dtype=weight_dtype, savedir_sample=savedir_sample,
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
with gr.Blocks(css=css) as demo:
|
| 410 |
+
gr.Markdown(
|
| 411 |
+
"""
|
| 412 |
+
# Wan2.2-Fun:
|
| 413 |
+
|
| 414 |
+
A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos.
|
| 415 |
+
|
| 416 |
+
[Github](https://github.com/aigc-apps/VideoX-Fun/)
|
| 417 |
+
"""
|
| 418 |
+
)
|
| 419 |
+
with gr.Column(variant="panel"):
|
| 420 |
+
config_dropdown, config_refresh_button = create_config(controller)
|
| 421 |
+
model_type = create_model_type(visible=True)
|
| 422 |
+
diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \
|
| 423 |
+
create_model_checkpoints(controller, visible=True)
|
| 424 |
+
base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \
|
| 425 |
+
create_finetune_models_checkpoints(controller, visible=True, add_checkpoint_2=True)
|
| 426 |
+
base_model_dropdown, base_model_2_dropdown = base_model_dropdown
|
| 427 |
+
lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown
|
| 428 |
+
|
| 429 |
+
with gr.Row():
|
| 430 |
+
enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \
|
| 431 |
+
create_teacache_params(True, 0.10, 1, False)
|
| 432 |
+
cfg_skip_ratio = create_cfg_skip_params(0)
|
| 433 |
+
enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)
|
| 434 |
+
|
| 435 |
+
with gr.Column(variant="panel"):
|
| 436 |
+
prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
|
| 437 |
+
|
| 438 |
+
with gr.Row():
|
| 439 |
+
with gr.Column():
|
| 440 |
+
sampler_dropdown, sample_step_slider = create_samplers(controller)
|
| 441 |
+
|
| 442 |
+
resize_method, width_slider, height_slider, base_resolution = create_height_width(
|
| 443 |
+
default_height = 480, default_width = 832, maximum_height = 1344,
|
| 444 |
+
maximum_width = 1344,
|
| 445 |
+
)
|
| 446 |
+
generation_method, length_slider, overlap_video_length, partial_video_length = \
|
| 447 |
+
create_generation_methods_and_video_length(
|
| 448 |
+
["Video Generation", "Image Generation"],
|
| 449 |
+
default_video_length=81,
|
| 450 |
+
maximum_video_length=161,
|
| 451 |
+
)
|
| 452 |
+
image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
|
| 453 |
+
["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True
|
| 454 |
+
)
|
| 455 |
+
cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
|
| 456 |
+
|
| 457 |
+
generate_button = gr.Button(value="Generate (生成)", variant='primary')
|
| 458 |
+
|
| 459 |
+
result_image, result_video, infer_progress = create_ui_outputs()
|
| 460 |
+
|
| 461 |
+
config_dropdown.change(
|
| 462 |
+
fn=controller.update_config,
|
| 463 |
+
inputs=[config_dropdown],
|
| 464 |
+
outputs=[]
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
model_type.change(
|
| 468 |
+
fn=controller.update_model_type,
|
| 469 |
+
inputs=[model_type],
|
| 470 |
+
outputs=[]
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
def upload_generation_method(generation_method):
|
| 474 |
+
if generation_method == "Video Generation":
|
| 475 |
+
return [gr.update(visible=True, maximum=161, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)]
|
| 476 |
+
elif generation_method == "Image Generation":
|
| 477 |
+
return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)]
|
| 478 |
+
else:
|
| 479 |
+
return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
|
| 480 |
+
generation_method.change(
|
| 481 |
+
upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
def upload_source_method(source_method):
|
| 485 |
+
if source_method == "Text to Video (文本到视频)":
|
| 486 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
| 487 |
+
elif source_method == "Image to Video (图片到视频)":
|
| 488 |
+
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
| 489 |
+
elif source_method == "Video to Video (视频到视频)":
|
| 490 |
+
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
|
| 491 |
+
else:
|
| 492 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
|
| 493 |
+
source_method.change(
|
| 494 |
+
upload_source_method, source_method, [
|
| 495 |
+
image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
|
| 496 |
+
validation_video, validation_video_mask, control_video
|
| 497 |
+
]
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
def upload_resize_method(resize_method):
|
| 501 |
+
if resize_method == "Generate by":
|
| 502 |
+
return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
|
| 503 |
+
else:
|
| 504 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
|
| 505 |
+
resize_method.change(
|
| 506 |
+
upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
generate_button.click(
|
| 510 |
+
fn=controller.generate,
|
| 511 |
+
inputs=[
|
| 512 |
+
diffusion_transformer_dropdown,
|
| 513 |
+
base_model_dropdown,
|
| 514 |
+
lora_model_dropdown,
|
| 515 |
+
lora_alpha_slider,
|
| 516 |
+
prompt_textbox,
|
| 517 |
+
negative_prompt_textbox,
|
| 518 |
+
sampler_dropdown,
|
| 519 |
+
sample_step_slider,
|
| 520 |
+
resize_method,
|
| 521 |
+
width_slider,
|
| 522 |
+
height_slider,
|
| 523 |
+
base_resolution,
|
| 524 |
+
generation_method,
|
| 525 |
+
length_slider,
|
| 526 |
+
overlap_video_length,
|
| 527 |
+
partial_video_length,
|
| 528 |
+
cfg_scale_slider,
|
| 529 |
+
start_image,
|
| 530 |
+
end_image,
|
| 531 |
+
validation_video,
|
| 532 |
+
validation_video_mask,
|
| 533 |
+
control_video,
|
| 534 |
+
denoise_strength,
|
| 535 |
+
seed_textbox,
|
| 536 |
+
ref_image,
|
| 537 |
+
enable_teacache,
|
| 538 |
+
teacache_threshold,
|
| 539 |
+
num_skip_start_steps,
|
| 540 |
+
teacache_offload,
|
| 541 |
+
cfg_skip_ratio,
|
| 542 |
+
enable_riflex,
|
| 543 |
+
riflex_k,
|
| 544 |
+
base_model_2_dropdown,
|
| 545 |
+
lora_model_2_dropdown
|
| 546 |
+
],
|
| 547 |
+
outputs=[result_image, result_video, infer_progress]
|
| 548 |
+
)
|
| 549 |
+
return demo, controller
|
| 550 |
+
|
| 551 |
+
def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, compile_dit, weight_dtype, savedir_sample=None):
|
| 552 |
+
controller = Wan2_2_Fun_Controller_Host(
|
| 553 |
+
GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type,
|
| 554 |
+
config_path=config_path, compile_dit=compile_dit,
|
| 555 |
+
weight_dtype=weight_dtype, savedir_sample=savedir_sample,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
with gr.Blocks(css=css) as demo:
|
| 559 |
+
gr.Markdown(
|
| 560 |
+
"""
|
| 561 |
+
# Wan2.2-Fun:
|
| 562 |
+
|
| 563 |
+
A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos.
|
| 564 |
+
|
| 565 |
+
[Github](https://github.com/aigc-apps/VideoX-Fun/)
|
| 566 |
+
"""
|
| 567 |
+
)
|
| 568 |
+
with gr.Column(variant="panel"):
|
| 569 |
+
model_type = create_fake_model_type(visible=False)
|
| 570 |
+
diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
|
| 571 |
+
base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \
|
| 572 |
+
create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True)
|
| 573 |
+
base_model_dropdown, base_model_2_dropdown = base_model_dropdown
|
| 574 |
+
lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown
|
| 575 |
+
|
| 576 |
+
with gr.Row():
|
| 577 |
+
enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \
|
| 578 |
+
create_teacache_params(True, 0.10, 1, False)
|
| 579 |
+
cfg_skip_ratio = create_cfg_skip_params(0)
|
| 580 |
+
enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)
|
| 581 |
+
|
| 582 |
+
with gr.Column(variant="panel"):
|
| 583 |
+
prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
|
| 584 |
+
|
| 585 |
+
with gr.Row():
|
| 586 |
+
with gr.Column():
|
| 587 |
+
sampler_dropdown, sample_step_slider = create_samplers(controller)
|
| 588 |
+
|
| 589 |
+
resize_method, width_slider, height_slider, base_resolution = create_height_width(
|
| 590 |
+
default_height = 480, default_width = 832, maximum_height = 1344,
|
| 591 |
+
maximum_width = 1344,
|
| 592 |
+
)
|
| 593 |
+
generation_method, length_slider, overlap_video_length, partial_video_length = \
|
| 594 |
+
create_generation_methods_and_video_length(
|
| 595 |
+
["Video Generation", "Image Generation"],
|
| 596 |
+
default_video_length=81,
|
| 597 |
+
maximum_video_length=161,
|
| 598 |
+
)
|
| 599 |
+
image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
|
| 600 |
+
["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True
|
| 601 |
+
)
|
| 602 |
+
cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
|
| 603 |
+
|
| 604 |
+
generate_button = gr.Button(value="Generate (生成)", variant='primary')
|
| 605 |
+
|
| 606 |
+
result_image, result_video, infer_progress = create_ui_outputs()
|
| 607 |
+
|
| 608 |
+
def upload_generation_method(generation_method):
|
| 609 |
+
if generation_method == "Video Generation":
|
| 610 |
+
return gr.update(visible=True, minimum=1, maximum=161, value=81, interactive=True)
|
| 611 |
+
elif generation_method == "Image Generation":
|
| 612 |
+
return gr.update(minimum=1, maximum=1, value=1, interactive=False)
|
| 613 |
+
generation_method.change(
|
| 614 |
+
upload_generation_method, generation_method, [length_slider]
|
| 615 |
+
)
|
| 616 |
+
|
| 617 |
+
def upload_source_method(source_method):
|
| 618 |
+
if source_method == "Text to Video (文本到视频)":
|
| 619 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
| 620 |
+
elif source_method == "Image to Video (图片到视频)":
|
| 621 |
+
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
| 622 |
+
elif source_method == "Video to Video (视频到视频)":
|
| 623 |
+
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
|
| 624 |
+
else:
|
| 625 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
|
| 626 |
+
source_method.change(
|
| 627 |
+
upload_source_method, source_method, [
|
| 628 |
+
image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
|
| 629 |
+
validation_video, validation_video_mask, control_video
|
| 630 |
+
]
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
def upload_resize_method(resize_method):
|
| 634 |
+
if resize_method == "Generate by":
|
| 635 |
+
return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
|
| 636 |
+
else:
|
| 637 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
|
| 638 |
+
resize_method.change(
|
| 639 |
+
upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
generate_button.click(
|
| 643 |
+
fn=controller.generate,
|
| 644 |
+
inputs=[
|
| 645 |
+
diffusion_transformer_dropdown,
|
| 646 |
+
base_model_dropdown,
|
| 647 |
+
lora_model_dropdown,
|
| 648 |
+
lora_alpha_slider,
|
| 649 |
+
prompt_textbox,
|
| 650 |
+
negative_prompt_textbox,
|
| 651 |
+
sampler_dropdown,
|
| 652 |
+
sample_step_slider,
|
| 653 |
+
resize_method,
|
| 654 |
+
width_slider,
|
| 655 |
+
height_slider,
|
| 656 |
+
base_resolution,
|
| 657 |
+
generation_method,
|
| 658 |
+
length_slider,
|
| 659 |
+
overlap_video_length,
|
| 660 |
+
partial_video_length,
|
| 661 |
+
cfg_scale_slider,
|
| 662 |
+
start_image,
|
| 663 |
+
end_image,
|
| 664 |
+
validation_video,
|
| 665 |
+
validation_video_mask,
|
| 666 |
+
control_video,
|
| 667 |
+
denoise_strength,
|
| 668 |
+
seed_textbox,
|
| 669 |
+
ref_image,
|
| 670 |
+
enable_teacache,
|
| 671 |
+
teacache_threshold,
|
| 672 |
+
num_skip_start_steps,
|
| 673 |
+
teacache_offload,
|
| 674 |
+
cfg_skip_ratio,
|
| 675 |
+
enable_riflex,
|
| 676 |
+
riflex_k,
|
| 677 |
+
base_model_2_dropdown,
|
| 678 |
+
lora_model_2_dropdown
|
| 679 |
+
],
|
| 680 |
+
outputs=[result_image, result_video, infer_progress]
|
| 681 |
+
)
|
| 682 |
+
return demo, controller
|
| 683 |
+
|
| 684 |
+
def ui_client(scheduler_dict, model_name, savedir_sample=None):
|
| 685 |
+
controller = Wan2_2_Fun_Controller_Client(scheduler_dict, savedir_sample)
|
| 686 |
+
|
| 687 |
+
with gr.Blocks(css=css) as demo:
|
| 688 |
+
gr.Markdown(
|
| 689 |
+
"""
|
| 690 |
+
# Wan2.2-Fun:
|
| 691 |
+
|
| 692 |
+
A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos.
|
| 693 |
+
|
| 694 |
+
[Github](https://github.com/aigc-apps/VideoX-Fun/)
|
| 695 |
+
"""
|
| 696 |
+
)
|
| 697 |
+
with gr.Column(variant="panel"):
|
| 698 |
+
diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
|
| 699 |
+
base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \
|
| 700 |
+
create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True)
|
| 701 |
+
base_model_dropdown, base_model_2_dropdown = base_model_dropdown
|
| 702 |
+
lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown
|
| 703 |
+
|
| 704 |
+
with gr.Row():
|
| 705 |
+
enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \
|
| 706 |
+
create_teacache_params(True, 0.10, 1, False)
|
| 707 |
+
cfg_skip_ratio = create_cfg_skip_params(0)
|
| 708 |
+
enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)
|
| 709 |
+
|
| 710 |
+
with gr.Column(variant="panel"):
|
| 711 |
+
prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
|
| 712 |
+
|
| 713 |
+
with gr.Row():
|
| 714 |
+
with gr.Column():
|
| 715 |
+
sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50)
|
| 716 |
+
|
| 717 |
+
resize_method, width_slider, height_slider, base_resolution = create_fake_height_width(
|
| 718 |
+
default_height = 480, default_width = 832, maximum_height = 1344,
|
| 719 |
+
maximum_width = 1344,
|
| 720 |
+
)
|
| 721 |
+
generation_method, length_slider, overlap_video_length, partial_video_length = \
|
| 722 |
+
create_generation_methods_and_video_length(
|
| 723 |
+
["Video Generation", "Image Generation"],
|
| 724 |
+
default_video_length=81,
|
| 725 |
+
maximum_video_length=161,
|
| 726 |
+
)
|
| 727 |
+
image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
|
| 728 |
+
["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
|
| 732 |
+
|
| 733 |
+
generate_button = gr.Button(value="Generate (生成)", variant='primary')
|
| 734 |
+
|
| 735 |
+
result_image, result_video, infer_progress = create_ui_outputs()
|
| 736 |
+
|
| 737 |
+
def upload_generation_method(generation_method):
|
| 738 |
+
if generation_method == "Video Generation":
|
| 739 |
+
return gr.update(visible=True, minimum=5, maximum=161, value=49, interactive=True)
|
| 740 |
+
elif generation_method == "Image Generation":
|
| 741 |
+
return gr.update(minimum=1, maximum=1, value=1, interactive=False)
|
| 742 |
+
generation_method.change(
|
| 743 |
+
upload_generation_method, generation_method, [length_slider]
|
| 744 |
+
)
|
| 745 |
+
|
| 746 |
+
def upload_source_method(source_method):
|
| 747 |
+
if source_method == "Text to Video (文本到视频)":
|
| 748 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
|
| 749 |
+
elif source_method == "Image to Video (图片到视频)":
|
| 750 |
+
return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
|
| 751 |
+
else:
|
| 752 |
+
return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
|
| 753 |
+
source_method.change(
|
| 754 |
+
upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
def upload_resize_method(resize_method):
|
| 758 |
+
if resize_method == "Generate by":
|
| 759 |
+
return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
|
| 760 |
+
else:
|
| 761 |
+
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
|
| 762 |
+
resize_method.change(
|
| 763 |
+
upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
generate_button.click(
|
| 767 |
+
fn=controller.generate,
|
| 768 |
+
inputs=[
|
| 769 |
+
diffusion_transformer_dropdown,
|
| 770 |
+
base_model_dropdown,
|
| 771 |
+
lora_model_dropdown,
|
| 772 |
+
lora_alpha_slider,
|
| 773 |
+
prompt_textbox,
|
| 774 |
+
negative_prompt_textbox,
|
| 775 |
+
sampler_dropdown,
|
| 776 |
+
sample_step_slider,
|
| 777 |
+
resize_method,
|
| 778 |
+
width_slider,
|
| 779 |
+
height_slider,
|
| 780 |
+
base_resolution,
|
| 781 |
+
generation_method,
|
| 782 |
+
length_slider,
|
| 783 |
+
cfg_scale_slider,
|
| 784 |
+
start_image,
|
| 785 |
+
end_image,
|
| 786 |
+
validation_video,
|
| 787 |
+
validation_video_mask,
|
| 788 |
+
denoise_strength,
|
| 789 |
+
seed_textbox,
|
| 790 |
+
ref_image,
|
| 791 |
+
enable_teacache,
|
| 792 |
+
teacache_threshold,
|
| 793 |
+
num_skip_start_steps,
|
| 794 |
+
teacache_offload,
|
| 795 |
+
cfg_skip_ratio,
|
| 796 |
+
enable_riflex,
|
| 797 |
+
riflex_k,
|
| 798 |
+
base_model_2_dropdown,
|
| 799 |
+
lora_model_2_dropdown
|
| 800 |
+
],
|
| 801 |
+
outputs=[result_image, result_video, infer_progress]
|
| 802 |
+
)
|
| 803 |
+
return demo, controller
|