XiangpengYang committed
Commit 42a2bfa · 1 Parent(s): 64a0f40

first commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +4 -0
  3. LICENSE +201 -0
  4. README.md +209 -14
  5. __init__.py +3 -0
  6. app.py +391 -0
  7. assets/dough.mp4 +3 -0
  8. assets/sign.mp4 +3 -0
  9. assets/teaser_test.json +20 -0
  10. assets/two_man.mp4 +3 -0
  11. assets/woman_ballon.mp4 +3 -0
  12. config/1.3b_lora_zero_stage2_config.json +24 -0
  13. config/14b_lora_zero2_bf16_config.json +24 -0
  14. config/wan2.1/wan_civitai.yaml +39 -0
  15. config/wan2.2/wan_civitai_5b.yaml +41 -0
  16. config/wan2.2/wan_civitai_i2v.yaml +43 -0
  17. config/wan2.2/wan_civitai_s2v.yaml +44 -0
  18. config/wan2.2/wan_civitai_t2v.yaml +43 -0
  19. config/zero_stage2_config.json +16 -0
  20. config/zero_stage3_config.json +27 -0
  21. config/zero_stage3_config_cpu_offload.json +28 -0
  22. inference.py +400 -0
  23. install.py +45 -0
  24. pyproject.toml +15 -0
  25. requirements.txt +26 -0
  26. scripts/local_style.sh +13 -0
  27. scripts/obj_add.sh +13 -0
  28. scripts/obj_rem.sh +13 -0
  29. scripts/parallel_infer.sh +12 -0
  30. videox_fun/__init__.py +0 -0
  31. videox_fun/api/api.py +226 -0
  32. videox_fun/api/api_multi_nodes.py +320 -0
  33. videox_fun/data/bucket_sampler.py +392 -0
  34. videox_fun/data/dataset_image.py +76 -0
  35. videox_fun/data/dataset_image_video.py +1939 -0
  36. videox_fun/data/dataset_video.py +262 -0
  37. videox_fun/dist/__init__.py +66 -0
  38. videox_fun/dist/cogvideox_xfuser.py +105 -0
  39. videox_fun/dist/flux_xfuser.py +168 -0
  40. videox_fun/dist/fsdp.py +44 -0
  41. videox_fun/dist/fuser.py +55 -0
  42. videox_fun/dist/qwen_xfuser.py +176 -0
  43. videox_fun/dist/wan_xfuser.py +180 -0
  44. videox_fun/pipeline/__init__.py +21 -0
  45. videox_fun/pipeline/pipeline_wan.py +799 -0
  46. videox_fun/pipeline/pipeline_wan2_2.py +591 -0
  47. videox_fun/ui/cogvideox_fun_ui.py +722 -0
  48. videox_fun/ui/controller.py +514 -0
  49. videox_fun/ui/ui.py +366 -0
  50. videox_fun/ui/wan2_2_fun_ui.py +803 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ samples/
+ models/
+ __pycache__/
+ *.pyc
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,14 +1,209 @@
- ---
- title: VideoCoF
- emoji: 📉
- colorFrom: gray
- colorTo: green
- sdk: gradio
- sdk_version: 6.1.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- short_description: Unified Video Editing with Temporal Reasoner
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+
+ <h1 style="margin: 0; font-size: 2.4em;">
+ Unified Video Editing with Temporal Reasoner
+ </h1>
+
+ <h4 style="margin: 15px 0; color: #2c3e50;">
+ 👁️ See &rarr; 🧠 Reason &rarr; ✏️ Edit
+ </h4>
+
+ <h4 style="margin: 15px 0; color: #2c3e50;">
+ 🚀 A Chain-of-Frames video editing method enabling temporal reasoning and 4x video length extrapolation with just 50k training pairs!
+ </h4>
+
+ [![Hugging Face Daily Paper](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Daily%20Paper-yellow)](https://huggingface.co/papers/2512.07469)
+ [![arXiv](https://img.shields.io/badge/arXiv-2512.07469-b31b1b.svg)](https://arxiv.org/abs/2512.07469)
+ [![Project Page](https://img.shields.io/badge/Project-Page-green)](https://videocof.github.io)
+ [![Hugging Face Model](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-yellow)](https://huggingface.co/XiangpengYang/VideoCoF)
+ ![visitors](https://visitor-badge.laobi.icu/badge?page_id=videocof.VideoCoF&left_color=green&right_color=red)
+
+ </div>
+
+ <div align="center">
+ <b>
+ <a href="https://scholar.google.com/citations?user=reiIeYMAAAAJ">Xiangpeng Yang</a><sup>1</sup>,
+ <a href="https://horizonwind2004.github.io/">Ji Xie</a><sup>2</sup>,
+ <a href="https://scholar.google.com/citations?user=OvfI_HMAAAAJ">Yiyuan Yang</a><sup>1</sup>,
+ <a href="https://scholar.google.com/citations?user=zfeWd6gAAAAJ">Yan Huang</a><sup>1</sup>,
+ <a href="https://scholar.google.com/citations?user=sCuACdkAAAAJ">Min Xu</a><sup>1</sup>,
+ <a href="https://scholar.google.com/citations?user=sCuACdkAAAAJ">Qiang Wu</a><sup>1</sup>
+ </b>
+ <br>
+ <span style="font-size: 1em; color: #555;"><sup>1</sup>University of Technology Sydney, <sup>2</sup>Zhejiang University</span>
+ </div>
+
+ <br>
+
+ ## 💿 Introduction
+
+ https://github.com/user-attachments/assets/26f7d347-3d6c-43cf-9645-6eb5906f6ad6
+
+ ## 🔥 News
+
+ - **2025.12.09**: Paper available on arXiv.
+ - **2025.12.08**: Released the inference code and the videocof-50k weights.
+ - **2025.12.06**: 🔥 Project Page and README updated!
+
+
+ ## 📑 Table of Contents
+
+ - [🔧 Quick Start](#-quick-start)
+ - [🏆 Model Zoo](#-model-zoo)
+ - [🍭 Results](#-results)
+ - [🎨 Edit Comparison](#-edit-comparison)
+ - [🚧 TODO](#-todo)
+ - [🙏 Acknowledgments](#-acknowledgments)
+ - [📜 License](#-license)
+ - [📮 Contact](#-contact)
+ - [📄 Citation](#-citation)
+
+ ## 🔧 Quick Start
+
+ 1. **Clone the repository:**
+
+ ```bash
+ git clone https://github.com/videocof/VideoCoF.git
+ cd VideoCoF
+ ```
+
+ 2. **Install dependencies:**
+
+ ```bash
+ # 1. Create and activate a conda environment
+ conda create -n videocof python=3.10
+ conda activate videocof
+
+ # 2. Install PyTorch (choose the version compatible with your CUDA)
+ # For standard GPUs (CUDA 12.1):
+ pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
+
+ # For Hopper GPUs (e.g., H100/H800) requiring fast inference:
+ # pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu128
+
+ # 3. Install other dependencies
+ pip install -r requirements.txt
+ ```
+
+ **Note on Flash Attention:**
+ We recommend **FlashAttention-3** (currently in beta) for optimal performance, especially on NVIDIA H100/H800 GPUs.
+ If you are using these GPUs, please follow the [official FlashAttention-3 installation guide](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release) after installing the compatible PyTorch version (e.g., PyTorch 2.8 + CUDA 12.8).
+
+
+ 3. **Download Models:**
+
+ **Wan-2.1-T2V-14B Pretrained Weights:**
+
+ ```bash
+ git lfs install
+ git clone https://huggingface.co/Wan-AI/Wan2.1-T2V-14B
+
+ # Or using huggingface-cli:
+ # hf download Wan-AI/Wan2.1-T2V-14B --local-dir Wan2.1-T2V-14B
+ ```
+
+ **VideoCoF Checkpoint:**
+
+ ```bash
+ git lfs install
+ git clone https://huggingface.co/XiangpengYang/VideoCoF videocof_weight
+
+ # Or using huggingface-cli:
+ # hf download XiangpengYang/VideoCoF --local-dir videocof_weight
+ ```
+
+ 4. **Inference:**
+
+ For single inference tasks:
+
+ ```bash
+ # Object Removal
+ sh scripts/obj_rem.sh
+
+ # Object Addition
+ sh scripts/obj_add.sh
+
+ # Local Style Transfer
+ sh scripts/local_style.sh
+ ```
+
+ For parallel inference:
+
+ ```bash
+ sh scripts/parallel_infer.sh
+ ```
+
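Editor's note: the scripts above are thin wrappers around `inference.py` in this commit. The sketch below summarizes the pipeline call they ultimately issue; argument names and defaults are taken from `app.py` and `inference.py` later in this diff, while model loading, LoRA merging, and video decoding are omitted, so treat it as an illustration rather than a drop-in entry point.

```python
import torch

# `pipeline` is the WanPipeline that inference.py assembles from
# config/wan2.1/wan_civitai.yaml, the Wan2.1-T2V-14B weights, and the
# VideoCoF LoRA (merged with lora_weight=1.0); its construction is omitted here.
# `source_video` is the (B, C, T, H, W) tensor produced by load_video_frames():
# source frames sampled from the input clip and scaled to [-1, 1].
# `prompt` is the chain-of-frames prompt; see the sketch in the Results section.
sample = pipeline(
    video=source_video,
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_frames=65,             # default --num_frames in inference.py
    source_frames=33,          # default --source_frames
    reasoning_frames=4,        # default --reasoning_frames
    repeat_rope=True,          # --repeat_rope: repeat temporal RoPE for source/target segments
    cot=True,                  # predict reasoning (grounding) frames before the edited frames
    height=h, width=w,         # taken from the input video resolution
    guidance_scale=5.0,
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
).videos
```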
+ ## 🏆 Model Zoo
+
+ Our models are available on Hugging Face:
+
+ | Model Name | Description | Link |
+ |------------|-------------|------|
+ | VideoCoF-Base | Base model trained on 50k video pairs | [Hugging Face](https://huggingface.co/XiangpengYang/VideoCoF) |
+
+ ## 🍭 Results
+
+ ### Why Do We Need Reasoning Before Editing?
+ ![](assets/motivation_v2.gif)
+
+ Current video editing methods typically follow two paths:
+ 1. **Expert models**: Rely on external masks for precision but sacrifice unification.
+ 2. **Unified in-context learning models**: Mask-free but often struggle with spatial accuracy due to the lack of explicit cues.
+
+ **VideoCoF** bridges this gap by predicting reasoning tokens before generating the target video tokens.
+
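Editor's note: as a concrete illustration of the reasoning step, the snippet below reconstructs how the chain-of-frames prompt is assembled in `app.py` and `inference.py` in this commit; the helper `derive_ground_object_from_instruction` lives in `videox_fun/data/dataset_image_video.py`, and its exact output format is an implementation detail of that file.

```python
from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction

edit_text = "Remove the young man with short black hair wearing black shirt on the left."

# The model first "sees" the source clip, then emits grounding (reasoning)
# frames for the referenced object, and only then generates the edited clip.
ground_instr = derive_ground_object_from_instruction(edit_text)
prompt = (
    "A video sequence showing three parts: first the original scene, "
    f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
)
```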
+ ### Key Capabilities
+
+ 1. **Seeing, Reasoning, Editing**: VideoCoF adopts a "seeing, reasoning, editing" approach, ensuring edits are applied accurately to the intended targets.
+ 2. **Length Extrapolation**: Trained on only **50k** pairs (33 frames each), VideoCoF demonstrates robust multi-shot editing and length generalization (e.g., 4&times; length extrapolation).
+ 3. **Diverse Editing Tasks**: Supports fine-grained (instance- and part-level, spatially aware) Object Removal, Object Addition, Object Swap, and Local Style Transfer.
+
+ ### Gallery Highlights
+
+ > Please refer to our [Project Page](https://videocof.github.io) for the full gallery.
+
+ * **Object Removal**: Remove people or objects based on text prompts.
+ * **Object Addition**: Add elements like animals, objects, or people.
+ * **Object Swap**: Change specific attributes or objects.
+ * **Local Style Transfer**: Modify textures, materials, or colors.
+
+ ## 🚧 TODO
+
+ - [x] Release paper.
+ - [x] Release inference code and weights.
+ - [ ] Release training code.
+ - [ ] Release training data.
+ - [ ] Add Hugging Face demo.
+
+ ## 🙏 Acknowledgments
+
+ We thank the authors of related works and the open-source communities behind [VideoX-Fun](https://github.com/aigc-apps/VideoX-Fun) and [Wan](https://github.com/Wan-Video/Wan2.1) for their contributions.
+
+ ## 📜 License
+
+ This project is licensed under the [Apache License 2.0](LICENSE).
+
+ ## 📮 Contact
+
+ For any questions, please feel free to reach out to the author Xiangpeng Yang [@knightyxp](https://github.com/knightyxp), email: knightyxp@gmail.com or Xiangpeng.Yang@student.uts.edu.au.
+
+ ## 📄 Citation
+
+ If you find this work useful for your research, please consider citing:
+
+ ```bibtex
+ @article{yang2025videocof,
+   title={Unified Video Editing with Temporal Reasoner},
+   author={Yang, Xiangpeng and Xie, Ji and Yang, Yiyuan and Huang, Yan and Xu, Min and Wu, Qiang},
+   journal={arXiv preprint arXiv:2512.07469},
+   year={2025}
+ }
+ ```
+
+ <div align="center">
+ ⭐ **If you find this project helpful, please consider giving it a star!** ⭐
+ </div>
+
+ ## ⭐️ Star History
+
+ [![Star History Chart](https://api.star-history.com/svg?repos=knightyxp/VideoCoF&type=Date&legend=top-left)](https://star-history.com/#knightyxp/VideoCoF&Date)
__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .comfyui.comfyui_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+
+ __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
app.py ADDED
@@ -0,0 +1,391 @@
+ import os
+ import sys
+ import time
+ import torch
+ import gradio as gr
+ import numpy as np
+ import imageio
+ from PIL import Image
+
+ # Add project root to path (not needed when launched from the repo root)
+ # current_file_path = os.path.abspath(__file__)
+ # project_root = os.path.dirname(os.path.dirname(current_file_path))
+ # if project_root not in sys.path:
+ #     sys.path.insert(0, project_root)
+
+ from videox_fun.ui.wan_ui import Wan_Controller, css
+ from videox_fun.ui.ui import (
+     create_model_type, create_model_checkpoints, create_finetune_models_checkpoints,
+     create_teacache_params, create_cfg_skip_params, create_cfg_riflex_k,
+     create_prompts, create_samplers, create_height_width,
+     create_generation_methods_and_video_length, create_generation_method,
+     create_cfg_and_seedbox, create_ui_outputs
+ )
+ from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction
+ from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
+ from videox_fun.utils.utils import save_videos_grid, timer
+
+ # Custom replacement for create_height_width: VideoCoF always infers at the
+ # input video's resolution, so the resize method and sliders are created but hidden.
+ def create_height_width_english(default_height, default_width, maximum_height, maximum_width):
+     resize_method = gr.Radio(
+         ["Generate by", "Resize according to Reference"],
+         value="Generate by",
+         show_label=False,
+         visible=False  # hidden: the input video's resolution is always used
+     )
+     # The sliders only act as fallbacks when no video is provided, which should
+     # not happen in the VideoCoF workflow, so they are hidden as well.
+     width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16, visible=False)
+     height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16, visible=False)
+     base_resolution = gr.Radio(label="Base Resolution", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False)
+
+     return resize_method, width_slider, height_slider, base_resolution
+
+ def load_video_frames(video_path: str, source_frames: int):
+     assert source_frames is not None, "source_frames is required"
+
+     reader = imageio.get_reader(video_path)
+     try:
+         total_frames = reader.count_frames()
+     except Exception:
+         total_frames = sum(1 for _ in reader)
+         reader = imageio.get_reader(video_path)
+
+     stride = max(1, total_frames // source_frames)
+     # Use a random start frame, as in inference.py
+     start_frame = torch.randint(0, max(1, total_frames - stride * source_frames), (1,))[0].item()
+
+     frames = []
+     original_height, original_width = None, None
+
+     for i in range(source_frames):
+         idx = start_frame + i * stride
+         if idx >= total_frames:
+             break
+         try:
+             frame = reader.get_data(idx)
+             pil_frame = Image.fromarray(frame)
+             if original_height is None:
+                 original_width, original_height = pil_frame.size
+             frames.append(pil_frame)
+         except IndexError:
+             break
+
+     reader.close()
+
+     while len(frames) < source_frames:
+         if frames:
+             frames.append(frames[-1].copy())
+         else:
+             w, h = (original_width, original_height) if original_width else (832, 480)
+             frames.append(Image.new('RGB', (w, h), (0, 0, 0)))
+
+     input_video = torch.from_numpy(np.array(frames))
+     input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
+     input_video = input_video * (2.0 / 255.0) - 1.0
+
+     return input_video, original_height, original_width
+
+ class VideoCoF_Controller(Wan_Controller):
+     @timer
+     def generate(
+         self,
+         diffusion_transformer_dropdown,
+         base_model_dropdown,
+         lora_model_dropdown,
+         lora_alpha_slider,
+         prompt_textbox,
+         negative_prompt_textbox,
+         sampler_dropdown,
+         sample_step_slider,
+         resize_method,
+         width_slider,
+         height_slider,
+         base_resolution,
+         generation_method,
+         length_slider,
+         overlap_video_length,
+         partial_video_length,
+         cfg_scale_slider,
+         start_image,
+         end_image,
+         validation_video,
+         validation_video_mask,
+         control_video,
+         denoise_strength,
+         seed_textbox,
+         ref_image=None,
+         enable_teacache=None,
+         teacache_threshold=None,
+         num_skip_start_steps=None,
+         teacache_offload=None,
+         cfg_skip_ratio=None,
+         enable_riflex=None,
+         riflex_k=None,
+         # Custom args
+         source_frames_slider=33,
+         reasoning_frames_slider=4,
+         repeat_rope_checkbox=True,
+         fps=10,
+         is_api=False,
+     ):
+         self.clear_cache()
+         print("VideoCoF generation started.")
+
+         if self.diffusion_transformer_dropdown != diffusion_transformer_dropdown:
+             self.update_diffusion_transformer(diffusion_transformer_dropdown)
+
+         if self.base_model_path != base_model_dropdown:
+             self.update_base_model(base_model_dropdown)
+
+         if self.lora_model_path != lora_model_dropdown:
+             self.update_lora_model(lora_model_dropdown)
+
+         # Scheduler setup
+         scheduler_config = self.pipeline.scheduler.config
+         if sampler_dropdown in ["Flow_Unipc", "Flow_DPM++"]:
+             scheduler_config['shift'] = 1
+         self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)
+
+         # LoRA merging
+         if self.lora_model_path != "none":
+             print("Merge LoRA.")
+             self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+
+         # Seed
+         if int(seed_textbox) != -1 and seed_textbox != "":
+             torch.manual_seed(int(seed_textbox))
+         else:
+             seed_textbox = np.random.randint(0, 1e10)
+         generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))
+
+         try:
+             # VideoCoF logic: use validation_video as the source if provided
+             # (the standard UI component for video-to-video), falling back to control_video.
+             input_video_path = validation_video
+
+             if input_video_path is None:
+                 input_video_path = control_video
+
+             if input_video_path is None:
+                 raise ValueError("Please upload a video for VideoCoF generation.")
+
+             # CoT prompt construction
+             edit_text = prompt_textbox
+             ground_instr = derive_ground_object_from_instruction(edit_text)
+             prompt = (
+                 "A video sequence showing three parts: first the original scene, "
+                 f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
+             )
+             print(f"Constructed prompt: {prompt}")
+
+             # Load video frames
+             input_video_tensor, video_height, video_width = load_video_frames(
+                 input_video_path,
+                 source_frames=source_frames_slider
+             )
+
+             # Use the loaded video's dimensions
+             h, w = video_height, video_width
+             print(f"Input video dimensions: {w}x{h}")
+
+             print(f"Running pipeline with frames={length_slider}, source={source_frames_slider}, reasoning={reasoning_frames_slider}")
+
+             sample = self.pipeline(
+                 video=input_video_tensor,
+                 prompt=prompt,
+                 num_frames=length_slider,
+                 source_frames=source_frames_slider,
+                 reasoning_frames=reasoning_frames_slider,
+                 negative_prompt=negative_prompt_textbox,
+                 height=h,
+                 width=w,
+                 generator=generator,
+                 guidance_scale=cfg_scale_slider,
+                 num_inference_steps=sample_step_slider,
+                 repeat_rope=repeat_rope_checkbox,
+                 cot=True,
+             ).videos
+
+             final_video = sample
+
+         except Exception as e:
+             print(f"Error: {e}")
+             if self.lora_model_path != "none":
+                 self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+             return gr.update(), gr.update(), f"Error: {str(e)}"
+
+         # Unmerge LoRA
+         if self.lora_model_path != "none":
+             self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+
+         # Save output
+         save_sample_path = self.save_outputs(
+             False, length_slider, final_video, fps=fps
+         )
+
+         # The handler returns (result_image, result_video, infer_progress); the
+         # uploaded source video stays in its input component, so only the edited
+         # result is shown in the output area.
+         return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
+
+ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
+     controller = VideoCoF_Controller(
+         GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
+         config_path=config_path, compile_dit=compile_dit,
+         weight_dtype=weight_dtype
+     )
+
+     with gr.Blocks() as demo:
+         gr.Markdown("# VideoCoF Demo")
+
+         with gr.Column(variant="panel"):
+             # Hide model selection
+             diffusion_transformer_dropdown, _ = create_model_checkpoints(controller, visible=False, default_model="Wan-AI/Wan2.1-T2V-14B")
+             base_model_dropdown, lora_model_dropdown, lora_alpha_slider, _ = create_finetune_models_checkpoints(controller, visible=False, default_lora="XiangpengYang/VideoCoF")
+
+             # Set default LoRA alpha to 1.0 (matching inference.py)
+             lora_alpha_slider.value = 1.0
+
+             with gr.Row():
+                 # Disable TeaCache by default
+                 enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = create_teacache_params(False, 0.10, 5, False)
+                 cfg_skip_ratio = create_cfg_skip_params(0)
+                 enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)
+
+         with gr.Column(variant="panel"):
+             prompt_textbox, negative_prompt_textbox = create_prompts(prompt="Remove the young man with short black hair wearing black shirt on the left.")
+
+             with gr.Row():
+                 with gr.Column():
+                     sampler_dropdown, sample_step_slider = create_samplers(controller)
+
+                     # Custom VideoCoF parameters
+                     with gr.Group():
+                         gr.Markdown("### VideoCoF Parameters")
+                         source_frames_slider = gr.Slider(label="Source Frames", minimum=1, maximum=100, value=33, step=1)
+                         reasoning_frames_slider = gr.Slider(label="Reasoning Frames", minimum=1, maximum=20, value=4, step=1)
+                         repeat_rope_checkbox = gr.Checkbox(label="Repeat RoPE", value=True)
+
+                     # Use the custom height/width controls (hidden; the input resolution is used)
+                     resize_method, width_slider, height_slider, base_resolution = create_height_width_english(
+                         default_height=480, default_width=832, maximum_height=1344, maximum_width=1344
+                     )
+
+                     # Default video length 65
+                     generation_method, length_slider, overlap_video_length, partial_video_length = \
+                         create_generation_methods_and_video_length(
+                             ["Video Generation"],
+                             default_video_length=65,
+                             maximum_video_length=161
+                         )
+
+                     # Simplified input for VideoCoF: mainly video-to-video.
+                     image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
+                         ["Video to Video"], prompt_textbox, support_end_image=False, default_video="assets/two_man.mp4",
+                         video_examples=[
+                             ["assets/two_man.mp4", "Remove the young man with short black hair wearing black shirt on the left."],
+                             ["assets/sign.mp4", "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below."]
+                         ]
+                     )
+
+                     # Ensure validation_video is visible and interactive
+                     validation_video.visible = True
+                     validation_video.interactive = True
+
+                     # Set default seed to 0
+                     cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(True)
+                     seed_textbox.value = "0"
+
+                     generate_button = gr.Button(value="Generate", variant='primary')
+
+                 result_image, result_video, infer_progress = create_ui_outputs()
+
+         # Event handlers
+         generate_button.click(
+             fn=controller.generate,
+             inputs=[
+                 diffusion_transformer_dropdown,
+                 base_model_dropdown,
+                 lora_model_dropdown,
+                 lora_alpha_slider,
+                 prompt_textbox,
+                 negative_prompt_textbox,
+                 sampler_dropdown,
+                 sample_step_slider,
+                 resize_method,
+                 width_slider,
+                 height_slider,
+                 base_resolution,
+                 generation_method,
+                 length_slider,
+                 overlap_video_length,
+                 partial_video_length,
+                 cfg_scale_slider,
+                 start_image,
+                 end_image,
+                 validation_video,
+                 validation_video_mask,
+                 control_video,
+                 denoise_strength,
+                 seed_textbox,
+                 ref_image,
+                 enable_teacache,
+                 teacache_threshold,
+                 num_skip_start_steps,
+                 teacache_offload,
+                 cfg_skip_ratio,
+                 enable_riflex,
+                 riflex_k,
+                 # New inputs
+                 source_frames_slider,
+                 reasoning_frames_slider,
+                 repeat_rope_checkbox
+             ],
+             outputs=[result_image, result_video, infer_progress]
+         )
+
+     return demo, controller
+
+ if __name__ == "__main__":
+     from videox_fun.ui.controller import flow_scheduler_dict
+
+     GPU_memory_mode = "sequential_cpu_offload"
+     compile_dit = False
+     weight_dtype = torch.bfloat16
+     server_name = "0.0.0.0"
+     server_port = 7860
+     config_path = "config/wan2.1/wan_civitai.yaml"
+
+     demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, config_path, compile_dit, weight_dtype)
+
+     demo.queue(status_update_rate=1).launch(
+         server_name=server_name,
+         server_port=server_port,
+         prevent_thread_lock=True,
+         share=False
+     )
+
+     while True:
+         time.sleep(5)
assets/dough.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5262cf58ffa08dcd79d6346abec46bc0234aebfc65905b5ea2ca4ab905ca9dac
+ size 185700
assets/sign.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4e94f03a7d5738a001ce2e1302a8ae65596431a647dbfed83cdb6876322175a7
+ size 100798
assets/teaser_test.json ADDED
@@ -0,0 +1,20 @@
+ [
+     {
+         "task_type": "obj_add",
+         "sample_id": "001",
+         "source_video_path": "assets/woman_ballon.mp4",
+         "qwen_vl_72b_refined_instruction": "Add the woman in a floral dress pointing at the balloon on the left."
+     },
+     {
+         "task_type": "obj_rem",
+         "sample_id": "001",
+         "source_video_path": "assets/two_man.mp4",
+         "qwen_vl_72b_refined_instruction": "Remove the young man with short black hair wearing black shirt on the left."
+     },
+     {
+         "task_type": "local_style",
+         "sample_id": "001",
+         "source_video_path": "assets/sign.mp4",
+         "qwen_vl_72b_refined_instruction": "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below."
+     }
+ ]
assets/two_man.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd9c0f6523207bbcf0d5159beb7f7eaf37811e6e5b7a53585dda50491a573cd9
+ size 303233
assets/woman_ballon.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:575b37abda414161179bc00e0e7b6893f28feb967e875c8f9676275d2cc32572
+ size 89085
config/1.3b_lora_zero_stage2_config.json ADDED
@@ -0,0 +1,24 @@
+ {
+     "bf16": {
+         "enabled": true
+     },
+     "train_micro_batch_size_per_gpu": 4,
+     "train_batch_size": 64,
+     "gradient_accumulation_steps": 1,
+     "gradient_clipping": 0.05,
+     "zero_optimization": {
+         "stage": 2,
+         "offload_optimizer": {
+             "device": "none"
+         },
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "sub_group_size": 1e9,
+         "reduce_bucket_size": 5e8,
+         "allgather_partitions": true,
+         "allgather_bucket_size": 2e8,
+         "reduce_scatter": true
+     },
+     "steps_per_print": 100,
+     "wall_clock_breakdown": false
+ }
config/14b_lora_zero2_bf16_config.json ADDED
@@ -0,0 +1,24 @@
+ {
+     "bf16": {
+         "enabled": true
+     },
+     "train_micro_batch_size_per_gpu": 1,
+     "train_batch_size": "auto",
+     "gradient_accumulation_steps": 1,
+     "gradient_clipping": 0.05,
+     "zero_optimization": {
+         "stage": 2,
+         "offload_optimizer": {
+             "device": "none"
+         },
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "sub_group_size": 1e9,
+         "reduce_bucket_size": 5e8,
+         "allgather_partitions": true,
+         "allgather_bucket_size": 2e8,
+         "reduce_scatter": true
+     },
+     "steps_per_print": 100,
+     "wall_clock_breakdown": false
+ }
config/wan2.1/wan_civitai.yaml ADDED
@@ -0,0 +1,39 @@
+ format: civitai
+ pipeline: Wan
+ transformer_additional_kwargs:
+   transformer_subpath: ./
+   dict_mapping:
+     in_dim: in_channels
+     dim: hidden_size
+
+ vae_kwargs:
+   vae_subpath: Wan2.1_VAE.pth
+   temporal_compression_ratio: 4
+   spatial_compression_ratio: 8
+
+ text_encoder_kwargs:
+   text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
+   tokenizer_subpath: google/umt5-xxl
+   text_length: 512
+   vocab: 256384
+   dim: 4096
+   dim_attn: 4096
+   dim_ffn: 10240
+   num_heads: 64
+   num_layers: 24
+   num_buckets: 32
+   shared_pos: False
+   dropout: 0.0
+
+ scheduler_kwargs:
+   scheduler_subpath: null
+   num_train_timesteps: 1000
+   shift: 5.0
+   use_dynamic_shifting: false
+   base_shift: 0.5
+   max_shift: 1.15
+   base_image_seq_len: 256
+   max_image_seq_len: 4096
+
+ image_encoder_kwargs:
+   image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ format: civitai
2
+ pipeline: Wan
3
+ transformer_additional_kwargs:
4
+ transformer_low_noise_model_subpath: ./
5
+ transformer_combination_type: "single"
6
+ dict_mapping:
7
+ in_dim: in_channels
8
+ dim: hidden_size
9
+
10
+ vae_kwargs:
11
+ vae_type: "AutoencoderKLWan3_8"
12
+ vae_subpath: Wan2.2_VAE.pth
13
+ temporal_compression_ratio: 4
14
+ spatial_compression_ratio: 16
15
+
16
+ text_encoder_kwargs:
17
+ text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
18
+ tokenizer_subpath: google/umt5-xxl
19
+ text_length: 512
20
+ vocab: 256384
21
+ dim: 4096
22
+ dim_attn: 4096
23
+ dim_ffn: 10240
24
+ num_heads: 64
25
+ num_layers: 24
26
+ num_buckets: 32
27
+ shared_pos: False
28
+ dropout: 0.0
29
+
30
+ scheduler_kwargs:
31
+ scheduler_subpath: null
32
+ num_train_timesteps: 1000
33
+ shift: 5.0
34
+ use_dynamic_shifting: false
35
+ base_shift: 0.5
36
+ max_shift: 1.15
37
+ base_image_seq_len: 256
38
+ max_image_seq_len: 4096
39
+
40
+ image_encoder_kwargs:
41
+ image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
config/wan2.2/wan_civitai_i2v.yaml ADDED
@@ -0,0 +1,43 @@
+ format: civitai
+ pipeline: Wan
+ transformer_additional_kwargs:
+   transformer_low_noise_model_subpath: ./low_noise_model
+   transformer_high_noise_model_subpath: ./high_noise_model
+   transformer_combination_type: "moe"
+   boundary: 0.900
+   dict_mapping:
+     in_dim: in_channels
+     dim: hidden_size
+
+ vae_kwargs:
+   vae_type: "AutoencoderKLWan"
+   vae_subpath: Wan2.1_VAE.pth
+   temporal_compression_ratio: 4
+   spatial_compression_ratio: 8
+
+ text_encoder_kwargs:
+   text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
+   tokenizer_subpath: google/umt5-xxl
+   text_length: 512
+   vocab: 256384
+   dim: 4096
+   dim_attn: 4096
+   dim_ffn: 10240
+   num_heads: 64
+   num_layers: 24
+   num_buckets: 32
+   shared_pos: False
+   dropout: 0.0
+
+ scheduler_kwargs:
+   scheduler_subpath: null
+   num_train_timesteps: 1000
+   shift: 5.0
+   use_dynamic_shifting: false
+   base_shift: 0.5
+   max_shift: 1.15
+   base_image_seq_len: 256
+   max_image_seq_len: 4096
+
+ image_encoder_kwargs:
+   image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
config/wan2.2/wan_civitai_s2v.yaml ADDED
@@ -0,0 +1,44 @@
+ format: civitai
+ pipeline: Wan
+ transformer_additional_kwargs:
+   transformer_low_noise_model_subpath: ./
+   transformer_combination_type: "single"
+   dict_mapping:
+     in_dim: in_channels
+     dim: hidden_size
+
+ vae_kwargs:
+   vae_type: "AutoencoderKLWan"
+   vae_subpath: Wan2.1_VAE.pth
+   temporal_compression_ratio: 4
+   spatial_compression_ratio: 8
+
+ text_encoder_kwargs:
+   text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
+   tokenizer_subpath: google/umt5-xxl
+   text_length: 512
+   vocab: 256384
+   dim: 4096
+   dim_attn: 4096
+   dim_ffn: 10240
+   num_heads: 64
+   num_layers: 24
+   num_buckets: 32
+   shared_pos: False
+   dropout: 0.0
+
+ audio_encoder_kwargs:
+   audio_encoder_subpath: wav2vec2-large-xlsr-53-english
+
+ scheduler_kwargs:
+   scheduler_subpath: null
+   num_train_timesteps: 1000
+   shift: 3.0
+   use_dynamic_shifting: false
+   base_shift: 0.5
+   max_shift: 1.15
+   base_image_seq_len: 256
+   max_image_seq_len: 4096
+
+ image_encoder_kwargs:
+   image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
config/wan2.2/wan_civitai_t2v.yaml ADDED
@@ -0,0 +1,43 @@
+ format: civitai
+ pipeline: Wan
+ transformer_additional_kwargs:
+   transformer_low_noise_model_subpath: ./low_noise_model
+   transformer_high_noise_model_subpath: ./high_noise_model
+   transformer_combination_type: "moe"
+   boundary: 0.875
+   dict_mapping:
+     in_dim: in_channels
+     dim: hidden_size
+
+ vae_kwargs:
+   vae_type: "AutoencoderKLWan"
+   vae_subpath: Wan2.1_VAE.pth
+   temporal_compression_ratio: 4
+   spatial_compression_ratio: 8
+
+ text_encoder_kwargs:
+   text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
+   tokenizer_subpath: google/umt5-xxl
+   text_length: 512
+   vocab: 256384
+   dim: 4096
+   dim_attn: 4096
+   dim_ffn: 10240
+   num_heads: 64
+   num_layers: 24
+   num_buckets: 32
+   shared_pos: False
+   dropout: 0.0
+
+ scheduler_kwargs:
+   scheduler_subpath: null
+   num_train_timesteps: 1000
+   shift: 12.0
+   use_dynamic_shifting: false
+   base_shift: 0.5
+   max_shift: 1.15
+   base_image_seq_len: 256
+   max_image_seq_len: 4096
+
+ image_encoder_kwargs:
+   image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
config/zero_stage2_config.json ADDED
@@ -0,0 +1,16 @@
+ {
+     "bf16": {
+         "enabled": true
+     },
+     "train_micro_batch_size_per_gpu": 1,
+     "train_batch_size": "auto",
+     "gradient_accumulation_steps": "auto",
+     "dump_state": true,
+     "zero_optimization": {
+         "stage": 2,
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "sub_group_size": 1e9,
+         "reduce_bucket_size": 5e8
+     }
+ }
config/zero_stage3_config.json ADDED
@@ -0,0 +1,27 @@
+ {
+     "bf16": {
+         "enabled": true
+     },
+     "train_micro_batch_size_per_gpu": 1,
+     "train_batch_size": "auto",
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 2000,
+     "wall_clock_breakdown": false,
+     "zero_optimization": {
+         "stage": 3,
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "reduce_bucket_size": 5e8,
+         "sub_group_size": 1e9,
+         "stage3_max_live_parameters": 1e9,
+         "stage3_max_reuse_distance": 1e9,
+         "stage3_gather_16bit_weights_on_model_save": "auto",
+         "offload_optimizer": {
+             "device": "none"
+         },
+         "offload_param": {
+             "device": "none"
+         }
+     }
+ }
config/zero_stage3_config_cpu_offload.json ADDED
@@ -0,0 +1,28 @@
+ {
+     "bf16": {
+         "enabled": true
+     },
+     "train_micro_batch_size_per_gpu": 1,
+     "train_batch_size": "auto",
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 2000,
+     "wall_clock_breakdown": false,
+     "zero_optimization": {
+         "stage": 3,
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "reduce_bucket_size": 5e8,
+         "sub_group_size": 1e9,
+         "stage3_max_live_parameters": 1e9,
+         "stage3_max_reuse_distance": 1e9,
+         "stage3_gather_16bit_weights_on_model_save": "auto",
+         "offload_optimizer": {
+             "device": "cpu"
+         },
+         "offload_param": {
+             "device": "cpu"
+         }
+     }
+ }
+
inference.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import argparse
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.distributed as dist
9
+ from diffusers import FlowMatchEulerDiscreteScheduler
10
+ from omegaconf import OmegaConf
11
+ from PIL import Image
12
+ import imageio
13
+
14
+ current_file_path = os.path.abspath(__file__)
15
+ project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
16
+ for project_root in project_roots:
17
+ sys.path.insert(0, project_root) if project_root not in sys.path else None
18
+
19
+ from videox_fun.models import (AutoencoderKLWan, WanT5EncoderModel, AutoTokenizer,
20
+ WanTransformer3DModel)
21
+ from videox_fun.pipeline import WanPipeline
22
+ from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, replace_parameters_by_name,
23
+ convert_weight_dtype_wrapper)
24
+ from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
25
+ from videox_fun.utils.utils import (filter_kwargs, save_videos_grid)
26
+ from videox_fun.data.dataset_image_video import derive_ground_object_from_instruction
27
+ from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
28
+ from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
29
+
30
+
31
+ def load_video_frames(
32
+ video_path: str,
33
+ source_frames: int = None,
34
+ ):
35
+ assert source_frames is not None, "请传入 source_frames"
36
+
37
+ reader = imageio.get_reader(video_path)
38
+ try:
39
+ total_frames = reader.count_frames()
40
+ except Exception:
41
+ total_frames = sum(1 for _ in reader)
42
+ reader = imageio.get_reader(video_path)
43
+
44
+ stride = max(1, total_frames // source_frames)
45
+ start_frame = torch.randint(0, max(1, total_frames - stride * source_frames), (1,))[0].item()
46
+
47
+ frames = []
48
+ original_height, original_width = None, None
49
+
50
+ for i in range(source_frames):
51
+ idx = start_frame + i * stride
52
+ if idx >= total_frames:
53
+ break
54
+ try:
55
+ frame = reader.get_data(idx)
56
+ pil_frame = Image.fromarray(frame)
57
+ if original_height is None:
58
+ original_width, original_height = pil_frame.size
59
+ print(f"Original video dimensions: {original_width}x{original_height}")
60
+ frames.append(pil_frame)
61
+ except IndexError:
62
+ break
63
+
64
+ reader.close()
65
+
66
+ while len(frames) < source_frames:
67
+ if frames:
68
+ frames.append(frames[-1].copy())
69
+ else:
70
+ w, h = (original_width, original_height) if original_width else (832, 480)
71
+ frames.append(Image.new('RGB', (w, h), (0, 0, 0)))
72
+
73
+ assert len(frames) == source_frames
74
+ print(f"Loaded {source_frames} source frames")
75
+
76
+ input_video = torch.from_numpy(np.array(frames))
77
+ input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
78
+ input_video = input_video * (2.0 / 255.0) - 1.0
79
+
80
+ return input_video, original_height, original_width
81
+
82
+
83
+ def parse_args():
84
+ parser = argparse.ArgumentParser(description="Video-to-video CoT reasoning generation from JSON task list with parallel inference")
85
+ parser.add_argument("--test_json", type=str, default=None, help="Path to test JSON file for batch inference")
86
+ parser.add_argument("--prompt", type=str, default=None, help="Text prompt for editing (single mode)")
87
+ parser.add_argument("--video_path", type=str, default=None, help="Path to input video (single mode)")
88
+ parser.add_argument("--model_name", type=str, default="/scratch3/yan204/models/Wan2.1-T2V-14B", help="Model checkpoint path")
89
+ parser.add_argument("--output_dir", type=str, required=True, help="Output directory for generated videos")
90
+ parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducible generation")
91
+ parser.add_argument("--videocof_path", type=str, default=None, help="Path to videocof weight checkpoint")
92
+ parser.add_argument("--num_frames", type=int, default=65, help="Total number of frames (input + generated)")
93
+ parser.add_argument("--source_frames", type=int, default=33, help="Number of source frames; default 33")
94
+ parser.add_argument("--reasoning_frames", type=int, default=4, help="Grounding frames in the middle segment (pixel-space)")
95
+ parser.add_argument("--repeat_rope", action="store_true", help="Enable repeat temporal RoPE for src and tgt segments")
96
+ return parser.parse_args()
97
+
98
+
99
+ # Defaults aligned with predict_v2v_json_new.py
100
+ GPU_memory_mode = "sequential_cpu_offload"
101
+ ulysses_degree = 1
102
+ ring_degree = 1
103
+ fsdp_dit = False
104
+ fsdp_text_encoder = True
105
+ compile_dit = False
106
+ enable_teacache = True
107
+ teacache_threshold = 0.10
108
+ num_skip_start_steps = 5
109
+ teacache_offload = False
110
+ cfg_skip_ratio = 0
111
+ enable_riflex = False
112
+ riflex_k = 6
113
+
114
+ config_path = "config/wan2.1/wan_civitai.yaml"
115
+ model_name = "/scratch3/yan204/models/Wan2.1-T2V-14B"
116
+ sampler_name = "Flow_Unipc"
117
+ shift = 3
118
+ transformer_path = None
119
+ vae_path = None
120
+
121
+ fps = 10
122
+ weight_dtype = torch.bfloat16
123
+ negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
124
+ guidance_scale = 5.0
125
+ num_inference_steps = 50
126
+ lora_weight = 1.0
127
+
128
+
129
+ def save_results(tensor: torch.Tensor, file_path: str, fps_out: int = 16):
130
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
131
+ B, C, T, H, W = tensor.shape
132
+ arr = tensor[0].cpu().numpy()
133
+ if T == 1:
134
+ img = arr[:, 0].transpose(1, 2, 0)
135
+ img = (img * 255).astype(np.uint8)
136
+ Image.fromarray(img).save(file_path)
137
+ else:
138
+ save_videos_grid(tensor, file_path, fps=fps_out)
139
+ print(f"Saved video → {file_path}")
140
+
141
+
142
+ def _normalize_to_01(video: torch.Tensor) -> torch.Tensor:
143
+ with torch.no_grad():
144
+ vmin = float(video.min())
145
+ vmax = float(video.max())
146
+ if vmin < 0.0 or vmax > 1.0:
147
+ video = (video + 1.0) / 2.0
148
+ return video.clamp(0.0, 1.0)
149
+
150
+
151
+ def save_side_by_side(input_tensor: torch.Tensor, sample_tensor: torch.Tensor, file_path: str, fps_out: int = 16):
152
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
153
+ a = _normalize_to_01(input_tensor.detach().cpu())
154
+ b = _normalize_to_01(sample_tensor.detach().cpu())
155
+
156
+ # Align dimensions by cropping to the minimum across T/H/W
157
+ T = min(a.shape[2], b.shape[2])
158
+ H = min(a.shape[3], b.shape[3])
159
+ W = min(a.shape[4], b.shape[4])
160
+ a = a[:, :, :T, :H, :W]
161
+ b = b[:, :, :T, :H, :W]
162
+
163
+ combined = torch.cat([a, b], dim=4)
164
+ save_videos_grid(combined, file_path, fps=fps_out)
165
+ print(f"Saved side-by-side video → {file_path}")
166
+
167
+
168
+ def derive_ground_instruction(edit_instruction_text: str) -> str:
169
+ # Keep wrapper for backward compatibility; reuse the same rule as the training dataset
170
+ return derive_ground_object_from_instruction(edit_instruction_text)
171
+
172
+
173
+ def main():
174
+ args = parse_args()
175
+
176
+ # Initialize DDP
177
+ dist.init_process_group(backend="nccl")
178
+ rank = dist.get_rank()
179
+ world_size = dist.get_world_size()
180
+ local_rank = int(os.environ.get("LOCAL_RANK", rank % max(1, torch.cuda.device_count())))
181
+ torch.cuda.set_device(local_rank)
182
+
183
+ if rank == 0:
184
+ print(f"Running parallel CoT inference with {world_size} GPUs")
185
+ print(f"Using seed: {args.seed}")
186
+
187
+ model_name = args.model_name
188
+
189
+ # Load tasks
190
+ if args.test_json:
191
+ if rank == 0:
192
+ print(f"Loading tasks from JSON: {args.test_json}")
193
+ with open(args.test_json, 'r', encoding='utf-8') as f:
194
+ eval_prompts_list = json.load(f)
195
+
196
+ eval_prompts = {}
197
+ for item in eval_prompts_list:
198
+ # Assume the item has a compatible structure; otherwise fall back to index-based naming
199
+ # task_type/sample_id is expected to uniquely identify the item; if absent, create a fallback name
200
+ if 'task_type' in item and 'sample_id' in item:
201
+ fname = f"{item['task_type']}_{item['sample_id']}.mp4"
202
+ else:
203
+ # Fallback naming if JSON structure is different
204
+ fname = f"sample_{len(eval_prompts)}.mp4"
205
+ eval_prompts[fname] = item
206
+ items = list(eval_prompts.items())
207
+
208
+ elif args.video_path and args.prompt:
209
+ if rank == 0:
210
+ print(f"Running in single video mode: {args.video_path}")
211
+ fname = os.path.basename(args.video_path)
212
+ item = {
213
+ "source_video_path": args.video_path,
214
+ "edit_instruction": args.prompt
215
+ }
216
+ items = [(fname, item)]
217
+ else:
218
+ raise ValueError("Must provide either --test_json or both --video_path and --prompt")
219
+
220
+ # Filter done
221
+ pending_items = []
222
+ for fname, item in items:
223
+ base = os.path.splitext(fname)[0]
224
+ output_video_path = os.path.join(args.output_dir, f"gen_{base}.mp4")
225
+ if not os.path.exists(output_video_path):
226
+ pending_items.append((fname, item))
227
+
228
+ if rank == 0:
229
+ print(f"Total items: {len(items)}, already generated: {len(items) - len(pending_items)}, pending: {len(pending_items)}")
230
+
231
+ # Shard across GPUs
232
+ subset_items = pending_items[rank::world_size] if world_size > 0 else pending_items
233
+
234
+ print(f"[GPU {rank} | local {local_rank}] Processing {len(subset_items)} items")
235
+
236
+ device = torch.device(f"cuda:{local_rank}")
237
+
238
+ # Load config and models
239
+ config = OmegaConf.load(config_path)
240
+
241
+ transformer = WanTransformer3DModel.from_pretrained(
242
+ os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
243
+ transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
244
+ low_cpu_mem_usage=True,
245
+ torch_dtype=weight_dtype,
246
+ )
247
+
248
+ if transformer_path is not None:
249
+ print(f"[GPU {rank}] Loading transformer from checkpoint: {transformer_path}")
250
+ if transformer_path.endswith("safetensors"):
251
+ from safetensors.torch import load_file
252
+ state_dict = load_file(transformer_path)
253
+ else:
254
+ state_dict = torch.load(transformer_path, map_location="cpu")
255
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
256
+ m, u = transformer.load_state_dict(state_dict, strict=False)
257
+ print(f"[GPU {rank}] Missing keys: {len(m)}, unexpected keys: {len(u)}")
258
+
259
+ vae = AutoencoderKLWan.from_pretrained(
260
+ os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')),
261
+ additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
262
+ ).to(weight_dtype)
263
+
264
+ if vae_path is not None:
265
+ print(f"[GPU {rank}] Loading VAE from checkpoint: {vae_path}")
266
+ if vae_path.endswith("safetensors"):
267
+ from safetensors.torch import load_file
268
+ state_dict = load_file(vae_path)
269
+ else:
270
+ state_dict = torch.load(vae_path, map_location="cpu")
271
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
272
+ m, u = vae.load_state_dict(state_dict, strict=False)
273
+ print(f"[GPU {rank}] Missing keys: {len(m)}, unexpected keys: {len(u)}")
274
+
275
+ tokenizer = AutoTokenizer.from_pretrained(
276
+ os.path.join(model_name, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
277
+ )
278
+
279
+ text_encoder = WanT5EncoderModel.from_pretrained(
280
+ os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
281
+ additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
282
+ low_cpu_mem_usage=True,
283
+ torch_dtype=weight_dtype,
284
+ )
285
+
286
+ Choosen_Scheduler = {
287
+ "Flow": FlowMatchEulerDiscreteScheduler,
288
+ "Flow_Unipc": FlowUniPCMultistepScheduler,
289
+ "Flow_DPM++": FlowDPMSolverMultistepScheduler,
290
+ }[sampler_name]
291
+ if sampler_name in ["Flow_Unipc", "Flow_DPM++"]:
292
+ config['scheduler_kwargs']['shift'] = 1
293
+ scheduler = Choosen_Scheduler(
294
+ **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
295
+ )
296
+
297
+ pipeline = WanPipeline(
298
+ transformer=transformer,
299
+ vae=vae,
300
+ tokenizer=tokenizer,
301
+ text_encoder=text_encoder,
302
+ scheduler=scheduler,
303
+ )
304
+
305
+ # Memory mode
306
+ if GPU_memory_mode == "sequential_cpu_offload":
307
+ replace_parameters_by_name(transformer, ["modulation",], device=device)
308
+ transformer.freqs = transformer.freqs.to(device=device)
309
+ pipeline.enable_sequential_cpu_offload(device=device)
310
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
311
+ convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
312
+ convert_weight_dtype_wrapper(transformer, weight_dtype)
313
+ pipeline.enable_model_cpu_offload(device=device)
314
+ elif GPU_memory_mode == "model_cpu_offload":
315
+ pipeline.enable_model_cpu_offload(device=device)
316
+ elif GPU_memory_mode == "model_full_load_and_qfloat8":
317
+ convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
318
+ convert_weight_dtype_wrapper(transformer, weight_dtype)
319
+ pipeline.to(device=device)
320
+ else:
321
+ pipeline.to(device=device)
322
+
323
+ # LoRA
324
+ if args.videocof_path is not None:
325
+ pipeline = merge_lora(pipeline, args.videocof_path, lora_weight, device=device)
326
+ print(f"[GPU {rank}] Loaded LoRA from {args.videocof_path}")
327
+
328
+ os.makedirs(args.output_dir, exist_ok=True)
329
+
330
+ generator = torch.Generator(device=device).manual_seed(args.seed + rank)
331
+
332
+ # Grounding indices are now handled inside the pipeline; no forward override needed.
333
+
334
+ for fname, item in subset_items:
335
+ base = os.path.splitext(fname)[0]
336
+ output_video_path = os.path.join(args.output_dir, f"gen_{base}.mp4")
337
+ info_path = os.path.join(args.output_dir, f"gen_{base}_info.txt")
338
+
339
+ print(f"[GPU {rank}] Processing {fname}...")
340
+
341
+ video_path = item["source_video_path"]
342
+
343
+ # Match training dataset (ImageVideoCoTDataset) prompt formatting
344
+ edit_text = item.get('text', item.get('qwen_vl_72b_refined_instruction', item.get('edit_instruction', '')))
345
+ ground_instr = derive_ground_instruction(edit_text)
346
+ prompt = (
347
+ "A video sequence showing three parts: first the original scene, "
348
+ f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
349
+ )
350
+
351
+
352
+ input_video, video_height, video_width = load_video_frames(
353
+ video_path,
354
+ source_frames=args.source_frames,
355
+ )
356
+
357
+ with torch.no_grad():
358
+ sample = pipeline(
359
+ video=input_video,
360
+ prompt=prompt,
361
+ num_frames=args.num_frames,
362
+ source_frames=args.source_frames,
363
+ reasoning_frames=args.reasoning_frames,
364
+ negative_prompt=negative_prompt,
365
+ height=video_height,
366
+ width=video_width,
367
+ generator=generator,
368
+ guidance_scale=guidance_scale,
369
+ num_inference_steps=num_inference_steps,
370
+ shift=shift,
371
+ repeat_rope=args.repeat_rope,
372
+ cot=True,
373
+ ).videos
374
+
375
+ reason_edit_path = os.path.join(args.output_dir, f"gen_{base}_reason_edit.mp4")
376
+ save_results(sample, reason_edit_path, fps)
377
+ print(f"[GPU {rank}] Saved reason+edit video shape: {sample.shape}")
378
+
379
+ edit_video = sample[:, :, -args.source_frames:, :, :]
380
+ save_results(edit_video, output_video_path, fps)
381
+ print(f"[GPU {rank}] Edit video shape: {edit_video.shape}")
382
+
383
+ compare_path = os.path.join(args.output_dir, f"gen_{base}_compare.mp4")
384
+ save_side_by_side(input_video, edit_video, compare_path, fps)
385
+
386
+ with open(info_path, "w", encoding="utf-8") as info_f:
387
+ info_f.write(prompt)
388
+
389
+ print(f"[GPU {rank}] Completed {fname}")
390
+
391
+ if args.videocof_path is not None:
392
+ pipeline = unmerge_lora(pipeline, args.videocof_path, lora_weight, device=device)
393
+
394
+ print(f"[GPU {rank}] Finished processing all assigned items")
395
+
396
+
397
+ if __name__ == "__main__":
398
+ main()
399
+
400
+
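The edited clip is always the trailing `source_frames` frames of the pipeline output, with the source and grounding (reasoning) segments in front of it. Below is a minimal sketch of that layout and of the three-part prompt built in main(); the tensor is a stand-in for WanPipeline(...).videos, and using the raw instruction as the grounding text is an assumption (the real rule lives in derive_ground_object_from_instruction).

    import torch

    source_frames = 33      # frames taken from the source video
    reasoning_frames = 4    # grounding frames in the middle segment
    num_frames = 65         # total frames the model sees (source + reasoning + edited)

    edit_text = "remove the young man with short black hair on the left"
    ground_instr = edit_text  # assumption: stands in for derive_ground_instruction(edit_text)
    prompt = (
        "A video sequence showing three parts: first the original scene, "
        f"then grounded {ground_instr}, and finally the same scene but {edit_text}"
    )

    # Stand-in for pipeline output: (B, C, T, H, W) in [0, 1]
    sample = torch.rand(1, 3, num_frames, 480, 832)

    # The last `source_frames` frames are the edited clip saved as gen_<name>.mp4;
    # everything before them is the source segment plus the grounding segment.
    edit_video = sample[:, :, -source_frames:, :, :]
    print(prompt)
    print(edit_video.shape)   # torch.Size([1, 3, 33, 480, 832])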
install.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import subprocess
3
+ import locale
4
+ import threading
5
+ import os
6
+
7
+ def handle_stream(stream, prefix):
8
+ stream.reconfigure(encoding=locale.getpreferredencoding(), errors='replace')
9
+ for msg in stream:
10
+ if prefix == '[!]' and ('it/s]' in msg or 's/it]' in msg) and ('%|' in msg or 'it [' in msg):
11
+ if msg.startswith('100%'):
12
+ print('\r' + msg, end="", file=sys.stderr),
13
+ else:
14
+ print('\r' + msg[:-1], end="", file=sys.stderr)
15
+ else:
16
+ if prefix == '[!]':
17
+ print(prefix, msg, end="", file=sys.stderr)
18
+ else:
19
+ print(prefix, msg, end="")
20
+
21
+ def process_wrap(cmd_str, cwd_path, handler=None):
22
+ process = subprocess.Popen(cmd_str, cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1)
23
+
24
+ if handler is None:
25
+ handler = handle_stream
26
+
27
+ stdout_thread = threading.Thread(target=handler, args=(process.stdout, ""))
28
+ stderr_thread = threading.Thread(target=handler, args=(process.stderr, "[!]"))
29
+
30
+ stdout_thread.start()
31
+ stderr_thread.start()
32
+
33
+ stdout_thread.join()
34
+ stderr_thread.join()
35
+
36
+ return process.wait()
37
+
38
+ assert process_wrap([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd_path=os.path.dirname(os.path.realpath(__file__))) == 0, "ERROR: Failed to install requirements.txt. Please install them manually, and restart ComfyUI."
39
+
40
+ nodep_packages = [
41
+ "kornia>=0.6.9",
42
+ "xformers>=0.0.20",
43
+ ]
44
+
45
+ assert process_wrap([sys.executable, "-m", "pip", "install", "--no-deps", *nodep_packages], cwd_path=os.path.dirname(os.path.realpath(__file__))) == 0, "ERROR: Failed to install last set of packages. Please install them manually, and restart ComfyUI."
pyproject.toml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "videox-fun"
3
+ description = "VideoX-Fun is a video generation pipeline that can be used to generate AI images and videos, as well as to train baseline and Lora models for Diffusion Transformer. We support direct prediction from pre-trained baseline models to generate videos with different resolutions, durations, and FPS. Additionally, we also support users in training their own baseline and Lora models to perform specific style transformations."
4
+ version = "1.0.0"
5
+ license = {file = "LICENSE"}
6
+ dependencies = ["Pillow", "einops", "safetensors", "timm", "tomesd", "torch>=2.1.2", "torchdiffeq", "torchsde", "decord", "datasets", "numpy", "scikit-image", "opencv-python", "omegaconf", "SentencePiece", "albumentations", "imageio[ffmpeg]", "imageio[pyav]", "tensorboard", "beautifulsoup4", "ftfy", "func_timeout", "accelerate>=0.25.0", "gradio>=3.41.2,<=3.48.0", "diffusers>=0.30.1,<=0.31.0", "transformers>=4.46.2"]
7
+
8
+ [project.urls]
9
+ Repository = "https://github.com/aigc-apps/VideoX-Fun"
10
+ # Used by Comfy Registry https://comfyregistry.org
11
+
12
+ [tool.comfy]
13
+ PublisherId = "bubbliiiing"
14
+ DisplayName = "VideoX-Fun"
15
+ Icon = ""
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Pillow
2
+ einops
3
+ safetensors
4
+ timm
5
+ tomesd
6
+ torchdiffeq
7
+ torchsde
8
+ decord
9
+ datasets
10
+ numpy
11
+ scikit-image
12
+ opencv-python
13
+ omegaconf
14
+ SentencePiece
15
+ albumentations
16
+ imageio[ffmpeg]
17
+ imageio[pyav]
18
+ tensorboard
19
+ beautifulsoup4
20
+ ftfy
21
+ func_timeout
22
+ onnxruntime
23
+ accelerate>=0.25.0
24
+ gradio>=3.41.2
25
+ diffusers>=0.30.1
26
+ transformers>=4.46.2
scripts/local_style.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export CUDA_VISIBLE_DEVICES=2
2
+
3
+ torchrun --nproc_per_node=1 inference.py \
4
+ --video_path assets/sign.mp4 \
5
+ --prompt "Replace the yellow \"SCHOOL\" sign with a red hospital sign, featuring a white hospital emblem on the top and the word \"HOSPITAL\" below." \
6
+ --output_dir results/local_style \
7
+ --model_name /scratch3/yan204/models/Wan2.1-T2V-14B \
8
+ --seed 0 \
9
+ --num_frames 33 \
10
+ --source_frames 33 \
11
+ --reasoning_frames 4 \
12
+ --repeat_rope \
13
+ --videocof_path videocof_weight/videocof.safetensors
scripts/obj_add.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export CUDA_VISIBLE_DEVICES=0
2
+
3
+ torchrun --nproc_per_node=1 inference.py \
4
+ --video_path assets/woman_ballon.mp4 \
5
+ --prompt "Add the woman in a floral dress pointing at the balloon on the left." \
6
+ --output_dir results/obj_add \
7
+ --model_name /scratch3/yan204/models/Wan2.1-T2V-14B \
8
+ --seed 0 \
9
+ --num_frames 33 \
10
+ --source_frames 33 \
11
+ --reasoning_frames 4 \
12
+ --repeat_rope \
13
+ --videocof_path videocof_weight/videocof.safetensors
scripts/obj_rem.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export CUDA_VISIBLE_DEVICES=1
2
+
3
+ torchrun --nproc_per_node=1 inference.py \
4
+ --video_path assets/two_man.mp4 \
5
+ --prompt "Remove the young man with short black hair wearing black shirt on the left." \
6
+ --output_dir results/obj_rem \
7
+ --model_name /scratch3/yan204/models/Wan2.1-T2V-14B \
8
+ --seed 0 \
9
+ --num_frames 33 \
10
+ --source_frames 33 \
11
+ --reasoning_frames 4 \
12
+ --repeat_rope \
13
+ --videocof_path videocof_weight/videocof.safetensors
scripts/parallel_infer.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export CUDA_VISIBLE_DEVICES=0,1,2,3
2
+
3
+ torchrun --nproc_per_node=4 inference.py \
4
+ --test_json assets/teaser_test.json \
5
+ --output_dir results/torch_2.5.1 \
6
+ --model_name /scratch3/yan204/models/Wan2.1-T2V-14B \
7
+ --seed 0 \
8
+ --num_frames 33 \
9
+ --source_frames 33 \
10
+ --reasoning_frames 4 \
11
+ --repeat_rope \
12
+ --videocof_path videocof_weight/videocof.safetensors
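For batch runs, --test_json points at a list of task dicts like the one sketched below; field names beyond task_type / sample_id / source_video_path / edit_instruction are assumptions (see assets/teaser_test.json for the real entries), and the output paths follow the naming used in inference.py.

    import json, os

    # Hypothetical task list in the shape inference.py expects.
    tasks = [
        {
            "task_type": "obj_rem",
            "sample_id": "0001",
            "source_video_path": "assets/two_man.mp4",
            "edit_instruction": "Remove the young man with short black hair wearing black shirt on the left.",
        },
    ]
    with open("my_test.json", "w", encoding="utf-8") as f:
        json.dump(tasks, f, indent=2)

    output_dir = "results/torch_2.5.1"
    for item in tasks:
        base = f"{item['task_type']}_{item['sample_id']}"
        print(os.path.join(output_dir, f"gen_{base}.mp4"))              # edited clip only
        print(os.path.join(output_dir, f"gen_{base}_reason_edit.mp4"))  # source + grounding + edit
        print(os.path.join(output_dir, f"gen_{base}_compare.mp4"))      # side-by-side with the input
        print(os.path.join(output_dir, f"gen_{base}_info.txt"))         # prompt that was used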
videox_fun/__init__.py ADDED
File without changes
videox_fun/api/api.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import gc
3
+ import hashlib
4
+ import io
5
+ import os
6
+ import tempfile
7
+ from io import BytesIO
8
+
9
+ import gradio as gr
10
+ import requests
11
+ import torch
12
+ from fastapi import FastAPI
13
+ from PIL import Image
14
+
15
+
16
+ # Function to encode a file to Base64
17
+ def encode_file_to_base64(file_path):
18
+ with open(file_path, "rb") as file:
19
+ # Encode the data to Base64
20
+ file_base64 = base64.b64encode(file.read())
21
+ return file_base64
22
+
23
+ def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
24
+ @app.post("/videox_fun/update_diffusion_transformer")
25
+ def _update_diffusion_transformer_api(
26
+ datas: dict,
27
+ ):
28
+ diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
29
+
30
+ try:
31
+ controller.update_diffusion_transformer(
32
+ diffusion_transformer_path
33
+ )
34
+ comment = "Success"
35
+ except Exception as e:
36
+ torch.cuda.empty_cache()
37
+ comment = f"Error. error information is {str(e)}"
38
+
39
+ return {"message": comment}
40
+
41
+ def download_from_url(url, timeout=10):
42
+ try:
43
+ response = requests.get(url, timeout=timeout)
44
+ response.raise_for_status() # 检查请求是否成功
45
+ return response.content
46
+ except requests.exceptions.RequestException as e:
47
+ print(f"Error downloading from {url}: {e}")
48
+ return None
49
+
50
+ def save_base64_video(base64_string):
51
+ video_data = base64.b64decode(base64_string)
52
+
53
+ md5_hash = hashlib.md5(video_data).hexdigest()
54
+ filename = f"{md5_hash}.mp4"
55
+
56
+ temp_dir = tempfile.gettempdir()
57
+ file_path = os.path.join(temp_dir, filename)
58
+
59
+ with open(file_path, 'wb') as video_file:
60
+ video_file.write(video_data)
61
+
62
+ return file_path
63
+
64
+ def save_base64_image(base64_string):
65
+ video_data = base64.b64decode(base64_string)
66
+
67
+ md5_hash = hashlib.md5(video_data).hexdigest()
68
+ filename = f"{md5_hash}.jpg"
69
+
70
+ temp_dir = tempfile.gettempdir()
71
+ file_path = os.path.join(temp_dir, filename)
72
+
73
+ with open(file_path, 'wb') as video_file:
74
+ video_file.write(video_data)
75
+
76
+ return file_path
77
+
78
+ def save_url_video(url):
79
+ video_data = download_from_url(url)
80
+ if video_data:
81
+ return save_base64_video(base64.b64encode(video_data))
82
+ return None
83
+
84
+ def save_url_image(url):
85
+ image_data = download_from_url(url)
86
+ if image_data:
87
+ return save_base64_image(base64.b64encode(image_data))
88
+ return None
89
+
90
+ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
91
+ @app.post("/videox_fun/infer_forward")
92
+ def _infer_forward_api(
93
+ datas: dict,
94
+ ):
95
+ base_model_path = datas.get('base_model_path', 'none')
96
+ base_model_2_path = datas.get('base_model_2_path', 'none')
97
+ lora_model_path = datas.get('lora_model_path', 'none')
98
+ lora_model_2_path = datas.get('lora_model_2_path', 'none')
99
+ lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
100
+ prompt_textbox = datas.get('prompt_textbox', None)
101
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
102
+ sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
103
+ sample_step_slider = datas.get('sample_step_slider', 30)
104
+ resize_method = datas.get('resize_method', "Generate by")
105
+ width_slider = datas.get('width_slider', 672)
106
+ height_slider = datas.get('height_slider', 384)
107
+ base_resolution = datas.get('base_resolution', 512)
108
+ is_image = datas.get('is_image', False)
109
+ generation_method = datas.get('generation_method', False)
110
+ length_slider = datas.get('length_slider', 49)
111
+ overlap_video_length = datas.get('overlap_video_length', 4)
112
+ partial_video_length = datas.get('partial_video_length', 72)
113
+ cfg_scale_slider = datas.get('cfg_scale_slider', 6)
114
+ start_image = datas.get('start_image', None)
115
+ end_image = datas.get('end_image', None)
116
+ validation_video = datas.get('validation_video', None)
117
+ validation_video_mask = datas.get('validation_video_mask', None)
118
+ control_video = datas.get('control_video', None)
119
+ denoise_strength = datas.get('denoise_strength', 0.70)
120
+ seed_textbox = datas.get("seed_textbox", 43)
121
+
122
+ ref_image = datas.get('ref_image', None)
123
+ enable_teacache = datas.get('enable_teacache', True)
124
+ teacache_threshold = datas.get('teacache_threshold', 0.10)
125
+ num_skip_start_steps = datas.get('num_skip_start_steps', 1)
126
+ teacache_offload = datas.get('teacache_offload', False)
127
+ cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
128
+ enable_riflex = datas.get('enable_riflex', False)
129
+ riflex_k = datas.get('riflex_k', 6)
130
+ fps = datas.get('fps', None)
131
+
132
+ generation_method = "Image Generation" if is_image else generation_method
133
+
134
+ if start_image is not None:
135
+ if start_image.startswith('http'):
136
+ start_image = save_url_image(start_image)
137
+ start_image = [Image.open(start_image).convert("RGB")]
138
+ else:
139
+ start_image = base64.b64decode(start_image)
140
+ start_image = [Image.open(BytesIO(start_image)).convert("RGB")]
141
+
142
+ if end_image is not None:
143
+ if end_image.startswith('http'):
144
+ end_image = save_url_image(end_image)
145
+ end_image = [Image.open(end_image).convert("RGB")]
146
+ else:
147
+ end_image = base64.b64decode(end_image)
148
+ end_image = [Image.open(BytesIO(end_image)).convert("RGB")]
149
+
150
+ if validation_video is not None:
151
+ if validation_video.startswith('http'):
152
+ validation_video = save_url_video(validation_video)
153
+ else:
154
+ validation_video = save_base64_video(validation_video)
155
+
156
+ if validation_video_mask is not None:
157
+ if validation_video_mask.startswith('http'):
158
+ validation_video_mask = save_url_image(validation_video_mask)
159
+ else:
160
+ validation_video_mask = save_base64_image(validation_video_mask)
161
+
162
+ if control_video is not None:
163
+ if control_video.startswith('http'):
164
+ control_video = save_url_video(control_video)
165
+ else:
166
+ control_video = save_base64_video(control_video)
167
+
168
+ if ref_image is not None:
169
+ if ref_image.startswith('http'):
170
+ ref_image = save_url_image(ref_image)
171
+ ref_image = [Image.open(ref_image).convert("RGB")]
172
+ else:
173
+ ref_image = base64.b64decode(ref_image)
174
+ ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]
175
+
176
+ try:
177
+ save_sample_path, comment = controller.generate(
178
+ "",
179
+ base_model_path,
180
+ lora_model_path,
181
+ lora_alpha_slider,
182
+ prompt_textbox,
183
+ negative_prompt_textbox,
184
+ sampler_dropdown,
185
+ sample_step_slider,
186
+ resize_method,
187
+ width_slider,
188
+ height_slider,
189
+ base_resolution,
190
+ generation_method,
191
+ length_slider,
192
+ overlap_video_length,
193
+ partial_video_length,
194
+ cfg_scale_slider,
195
+ start_image,
196
+ end_image,
197
+ validation_video,
198
+ validation_video_mask,
199
+ control_video,
200
+ denoise_strength,
201
+ seed_textbox,
202
+ ref_image = ref_image,
203
+ enable_teacache = enable_teacache,
204
+ teacache_threshold = teacache_threshold,
205
+ num_skip_start_steps = num_skip_start_steps,
206
+ teacache_offload = teacache_offload,
207
+ cfg_skip_ratio = cfg_skip_ratio,
208
+ enable_riflex = enable_riflex,
209
+ riflex_k = riflex_k,
210
+ base_model_2_dropdown = base_model_2_path,
211
+ lora_model_2_dropdown = lora_model_2_path,
212
+ fps = fps,
213
+ is_api = True,
214
+ )
215
+ except Exception as e:
216
+ gc.collect()
217
+ torch.cuda.empty_cache()
218
+ torch.cuda.ipc_collect()
219
+ save_sample_path = ""
220
+ comment = f"Error. error information is {str(e)}"
221
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
222
+
223
+ if save_sample_path != "":
224
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
225
+ else:
226
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": None}
videox_fun/api/api_multi_nodes.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py
2
+ import base64
3
+ import gc
4
+ import hashlib
5
+ import io
6
+ import os
7
+ import tempfile
8
+ from io import BytesIO
9
+
10
+ import gradio as gr
11
+ import requests
12
+ import torch
13
+ import torch.distributed as dist
14
+ from fastapi import FastAPI, HTTPException
15
+ from PIL import Image
16
+
17
+ from .api import download_from_url, encode_file_to_base64
18
+
19
+ try:
20
+ import ray
21
+ except ImportError:
22
+ print("Ray is not installed. If you want to use multi gpus api. Please install it by running 'pip install ray'.")
23
+ ray = None
24
+
25
+ def save_base64_video_dist(base64_string):
26
+ video_data = base64.b64decode(base64_string)
27
+
28
+ md5_hash = hashlib.md5(video_data).hexdigest()
29
+ filename = f"{md5_hash}.mp4"
30
+
31
+ temp_dir = tempfile.gettempdir()
32
+ file_path = os.path.join(temp_dir, filename)
33
+
34
+ if dist.is_initialized():
35
+ if dist.get_rank() == 0:
36
+ with open(file_path, 'wb') as video_file:
37
+ video_file.write(video_data)
38
+ dist.barrier()
39
+ else:
40
+ with open(file_path, 'wb') as video_file:
41
+ video_file.write(video_data)
42
+ return file_path
43
+
44
+ def save_base64_image_dist(base64_string):
45
+ video_data = base64.b64decode(base64_string)
46
+
47
+ md5_hash = hashlib.md5(video_data).hexdigest()
48
+ filename = f"{md5_hash}.jpg"
49
+
50
+ temp_dir = tempfile.gettempdir()
51
+ file_path = os.path.join(temp_dir, filename)
52
+
53
+ if dist.is_initialized():
54
+ if dist.get_rank() == 0:
55
+ with open(file_path, 'wb') as video_file:
56
+ video_file.write(video_data)
57
+ dist.barrier()
58
+ else:
59
+ with open(file_path, 'wb') as video_file:
60
+ video_file.write(video_data)
61
+ return file_path
62
+
63
+ def save_url_video_dist(url):
64
+ video_data = download_from_url(url)
65
+ if video_data:
66
+ return save_base64_video_dist(base64.b64encode(video_data))
67
+ return None
68
+
69
+ def save_url_image_dist(url):
70
+ image_data = download_from_url(url)
71
+ if image_data:
72
+ return save_base64_image_dist(base64.b64encode(image_data))
73
+ return None
74
+
75
+ if ray is not None:
76
+ @ray.remote(num_gpus=1)
77
+ class MultiNodesGenerator:
78
+ def __init__(
79
+ self, rank: int, world_size: int, Controller,
80
+ GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
81
+ config_path=None, ulysses_degree=1, ring_degree=1,
82
+ fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
83
+ weight_dtype=None, savedir_sample=None,
84
+ ):
85
+ # Set PyTorch distributed environment variables
86
+ os.environ["RANK"] = str(rank)
87
+ os.environ["WORLD_SIZE"] = str(world_size)
88
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
89
+ os.environ["MASTER_PORT"] = "29500"
90
+
91
+ self.rank = rank
92
+ self.controller = Controller(
93
+ GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
94
+ ulysses_degree=ulysses_degree, ring_degree=ring_degree,
95
+ fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
96
+ weight_dtype=weight_dtype, savedir_sample=savedir_sample,
97
+ )
98
+
99
+ def generate(self, datas):
100
+ try:
101
+ base_model_path = datas.get('base_model_path', 'none')
102
+ base_model_2_path = datas.get('base_model_2_path', 'none')
103
+ lora_model_path = datas.get('lora_model_path', 'none')
104
+ lora_model_2_path = datas.get('lora_model_2_path', 'none')
105
+ lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
106
+ prompt_textbox = datas.get('prompt_textbox', None)
107
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
108
+ sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
109
+ sample_step_slider = datas.get('sample_step_slider', 30)
110
+ resize_method = datas.get('resize_method', "Generate by")
111
+ width_slider = datas.get('width_slider', 672)
112
+ height_slider = datas.get('height_slider', 384)
113
+ base_resolution = datas.get('base_resolution', 512)
114
+ is_image = datas.get('is_image', False)
115
+ generation_method = datas.get('generation_method', False)
116
+ length_slider = datas.get('length_slider', 49)
117
+ overlap_video_length = datas.get('overlap_video_length', 4)
118
+ partial_video_length = datas.get('partial_video_length', 72)
119
+ cfg_scale_slider = datas.get('cfg_scale_slider', 6)
120
+ start_image = datas.get('start_image', None)
121
+ end_image = datas.get('end_image', None)
122
+ validation_video = datas.get('validation_video', None)
123
+ validation_video_mask = datas.get('validation_video_mask', None)
124
+ control_video = datas.get('control_video', None)
125
+ denoise_strength = datas.get('denoise_strength', 0.70)
126
+ seed_textbox = datas.get("seed_textbox", 43)
127
+
128
+ ref_image = datas.get('ref_image', None)
129
+ enable_teacache = datas.get('enable_teacache', True)
130
+ teacache_threshold = datas.get('teacache_threshold', 0.10)
131
+ num_skip_start_steps = datas.get('num_skip_start_steps', 1)
132
+ teacache_offload = datas.get('teacache_offload', False)
133
+ cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
134
+ enable_riflex = datas.get('enable_riflex', False)
135
+ riflex_k = datas.get('riflex_k', 6)
136
+ fps = datas.get('fps', None)
137
+
138
+ generation_method = "Image Generation" if is_image else generation_method
139
+
140
+ if start_image is not None:
141
+ if start_image.startswith('http'):
142
+ start_image = save_url_image_dist(start_image)
143
+ start_image = [Image.open(start_image).convert("RGB")]
144
+ else:
145
+ start_image = base64.b64decode(start_image)
146
+ start_image = [Image.open(BytesIO(start_image)).convert("RGB")]
147
+
148
+ if end_image is not None:
149
+ if end_image.startswith('http'):
150
+ end_image = save_url_image_dist(end_image)
151
+ end_image = [Image.open(end_image).convert("RGB")]
152
+ else:
153
+ end_image = base64.b64decode(end_image)
154
+ end_image = [Image.open(BytesIO(end_image)).convert("RGB")]
155
+
156
+ if validation_video is not None:
157
+ if validation_video.startswith('http'):
158
+ validation_video = save_url_video_dist(validation_video)
159
+ else:
160
+ validation_video = save_base64_video_dist(validation_video)
161
+
162
+ if validation_video_mask is not None:
163
+ if validation_video_mask.startswith('http'):
164
+ validation_video_mask = save_url_image_dist(validation_video_mask)
165
+ else:
166
+ validation_video_mask = save_base64_image_dist(validation_video_mask)
167
+
168
+ if control_video is not None:
169
+ if control_video.startswith('http'):
170
+ control_video = save_url_video_dist(control_video)
171
+ else:
172
+ control_video = save_base64_video_dist(control_video)
173
+
174
+ if ref_image is not None:
175
+ if ref_image.startswith('http'):
176
+ ref_image = save_url_image_dist(ref_image)
177
+ ref_image = [Image.open(ref_image).convert("RGB")]
178
+ else:
179
+ ref_image = base64.b64decode(ref_image)
180
+ ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]
181
+
182
+ try:
183
+ save_sample_path, comment = self.controller.generate(
184
+ "",
185
+ base_model_path,
186
+ lora_model_path,
187
+ lora_alpha_slider,
188
+ prompt_textbox,
189
+ negative_prompt_textbox,
190
+ sampler_dropdown,
191
+ sample_step_slider,
192
+ resize_method,
193
+ width_slider,
194
+ height_slider,
195
+ base_resolution,
196
+ generation_method,
197
+ length_slider,
198
+ overlap_video_length,
199
+ partial_video_length,
200
+ cfg_scale_slider,
201
+ start_image,
202
+ end_image,
203
+ validation_video,
204
+ validation_video_mask,
205
+ control_video,
206
+ denoise_strength,
207
+ seed_textbox,
208
+ ref_image = ref_image,
209
+ enable_teacache = enable_teacache,
210
+ teacache_threshold = teacache_threshold,
211
+ num_skip_start_steps = num_skip_start_steps,
212
+ teacache_offload = teacache_offload,
213
+ cfg_skip_ratio = cfg_skip_ratio,
214
+ enable_riflex = enable_riflex,
215
+ riflex_k = riflex_k,
216
+ base_model_2_dropdown = base_model_2_path,
217
+ lora_model_2_dropdown = lora_model_2_path,
218
+ fps = fps,
219
+ is_api = True,
220
+ )
221
+ except Exception as e:
222
+ gc.collect()
223
+ torch.cuda.empty_cache()
224
+ torch.cuda.ipc_collect()
225
+ save_sample_path = ""
226
+ comment = f"Error. error information is {str(e)}"
227
+ if dist.is_initialized():
228
+ if dist.get_rank() == 0:
229
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
230
+ else:
231
+ return None
232
+ else:
233
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
234
+
235
+
236
+ if dist.is_initialized():
237
+ if dist.get_rank() == 0:
238
+ if save_sample_path != "":
239
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
240
+ else:
241
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
242
+ else:
243
+ return None
244
+ else:
245
+ if save_sample_path != "":
246
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
247
+ else:
248
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
249
+
250
+ except Exception as e:
251
+ print(f"Error generating: {str(e)}")
252
+ comment = f"Error generating: {str(e)}"
253
+ if dist.is_initialized():
254
+ if dist.get_rank() == 0:
255
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
256
+ else:
257
+ return None
258
+ else:
259
+ return {"message": comment, "save_sample_path": None, "base64_encoding": None}
260
+
261
+ class MultiNodesEngine:
262
+ def __init__(
263
+ self,
264
+ world_size,
265
+ Controller,
266
+ GPU_memory_mode,
267
+ scheduler_dict,
268
+ model_name,
269
+ model_type,
270
+ config_path,
271
+ ulysses_degree=1,
272
+ ring_degree=1,
273
+ fsdp_dit=False,
274
+ fsdp_text_encoder=False,
275
+ compile_dit=False,
276
+ weight_dtype=torch.bfloat16,
277
+ savedir_sample="samples"
278
+ ):
279
+ # Ensure Ray is initialized
280
+ if not ray.is_initialized():
281
+ ray.init()
282
+
283
+ num_workers = world_size
284
+ self.workers = [
285
+ MultiNodesGenerator.remote(
286
+ rank, world_size, Controller,
287
+ GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
288
+ ulysses_degree=ulysses_degree, ring_degree=ring_degree,
289
+ fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
290
+ weight_dtype=weight_dtype, savedir_sample=savedir_sample,
291
+ )
292
+ for rank in range(num_workers)
293
+ ]
294
+ print("Update workers done")
295
+
296
+ async def generate(self, data):
297
+ results = ray.get([
298
+ worker.generate.remote(data)
299
+ for worker in self.workers
300
+ ])
301
+
302
+ return next(path for path in results if path is not None)
303
+
304
+ def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine):
305
+
306
+ @app.post("/videox_fun/infer_forward")
307
+ async def _multi_nodes_infer_forward_api(
308
+ datas: dict,
309
+ ):
310
+ try:
311
+ result = await engine.generate(datas)
312
+ return result
313
+ except Exception as e:
314
+ if isinstance(e, HTTPException):
315
+ raise e
316
+ raise HTTPException(status_code=500, detail=str(e))
317
+ else:
318
+ MultiNodesEngine = None
319
+ MultiNodesGenerator = None
320
+ multi_nodes_infer_forward_api = None
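All Ray workers run the same generation, but only rank 0 returns a response dict (the other ranks return None), and MultiNodesEngine.generate keeps the first non-None result. A trivial illustration of that selection:

    # Toy stand-in for the list ray.get(...) returns across four workers.
    results = [
        {"message": "Success", "save_sample_path": "samples/out.mp4", "base64_encoding": "..."},
        None, None, None,
    ]
    final = next(r for r in results if r is not None)
    print(final["message"], final["save_sample_path"])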
videox_fun/data/bucket_sampler.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
4
+ Sized, TypeVar, Union)
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import BatchSampler, Dataset, Sampler
11
+
12
+ # Original presets (commented out per request):
13
+ CUSTOM_ASPECT_RATIOS = {
14
+ "0.5676": [336, 592], # count=133984
15
+ "1.7619": [592, 336], # count=78813
16
+ "0.5682": [400, 704], # count=4421
17
+ "0.5556": [320, 576], # count=2481
18
+ "1.7600": [704, 400], # count=1682
19
+ "0.5319": [400, 752], # count=1235
20
+ "1.8000": [576, 320], # count=924
21
+ "0.5128": [320, 624], # count=711
22
+ "1.8800": [752, 400], # count=400
23
+ "1.9000": [608, 320], # count=226
24
+ "0.4237": [400, 944], # count=29
25
+ }
26
+ # CUSTOM_ASPECT_RATIOS = {
27
+ # "0.5676": [336, 592], # 336x592 (h x w)
28
+ # "1.7619": [592, 336], # 592x336
29
+ # "0.5682": [400, 704], # 400x704
30
+ # "1.7600": [704, 400], # 704x400
31
+ # "0.5319": [400, 752], # 400x752
32
+ # "1.8800": [752, 400], # 752x400
33
+ # "0.4237": [400, 944], # 400x944
34
+ # }
35
+
36
+
37
+ ASPECT_RATIO_512 = {
38
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
39
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
40
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
41
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
42
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
43
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
44
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
45
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
46
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
47
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
48
+ }
49
+ ASPECT_RATIO_RANDOM_CROP_512 = {
50
+ '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
51
+ '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
52
+ '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
53
+ '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
54
+ '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
55
+ }
56
+ ASPECT_RATIO_RANDOM_CROP_PROB = [
57
+ 1, 2,
58
+ 4, 4, 4, 4,
59
+ 8, 8, 8,
60
+ 4, 4, 4, 4,
61
+ 2, 1
62
+ ]
63
+ ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
64
+
65
+ def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
66
+ aspect_ratio = height / width
67
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
68
+ return ratios[closest_ratio], float(closest_ratio)
69
+
70
+ def get_image_size_without_loading(path):
71
+ with Image.open(path) as img:
72
+ return img.size # (width, height)
73
+
74
+ class RandomSampler(Sampler[int]):
75
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
76
+
77
+ If with replacement, then user can specify :attr:`num_samples` to draw.
78
+
79
+ Args:
80
+ data_source (Dataset): dataset to sample from
81
+ replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
82
+ num_samples (int): number of samples to draw, default=`len(dataset)`.
83
+ generator (Generator): Generator used in sampling.
84
+ """
85
+
86
+ data_source: Sized
87
+ replacement: bool
88
+
89
+ def __init__(self, data_source: Sized, replacement: bool = False,
90
+ num_samples: Optional[int] = None, generator=None) -> None:
91
+ self.data_source = data_source
92
+ self.replacement = replacement
93
+ self._num_samples = num_samples
94
+ self.generator = generator
95
+ self._pos_start = 0
96
+
97
+ if not isinstance(self.replacement, bool):
98
+ raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
99
+
100
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
101
+ raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
102
+
103
+ @property
104
+ def num_samples(self) -> int:
105
+ # dataset size might change at runtime
106
+ if self._num_samples is None:
107
+ return len(self.data_source)
108
+ return self._num_samples
109
+
110
+ def __iter__(self) -> Iterator[int]:
111
+ n = len(self.data_source)
112
+ if self.generator is None:
113
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
114
+ generator = torch.Generator()
115
+ generator.manual_seed(seed)
116
+ else:
117
+ generator = self.generator
118
+
119
+ if self.replacement:
120
+ for _ in range(self.num_samples // 32):
121
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
122
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
123
+ else:
124
+ for _ in range(self.num_samples // n):
125
+ xx = torch.randperm(n, generator=generator).tolist()
126
+ if self._pos_start >= n:
127
+ self._pos_start = 0
128
+ for idx in range(self._pos_start, n):
129
+ yield xx[idx]
130
+ self._pos_start = (self._pos_start + 1) % n
131
+ self._pos_start = 0
132
+ yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
133
+
134
+ def __len__(self) -> int:
135
+ return self.num_samples
136
+
137
+ class AspectRatioBatchImageSampler(BatchSampler):
138
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
139
+
140
+ Args:
141
+ sampler (Sampler): Base sampler.
142
+ dataset (Dataset): Dataset providing data information.
143
+ batch_size (int): Size of mini-batch.
144
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
145
+ its size would be less than ``batch_size``.
146
+ aspect_ratios (dict): The predefined aspect ratios.
147
+ """
148
+ def __init__(
149
+ self,
150
+ sampler: Sampler,
151
+ dataset: Dataset,
152
+ batch_size: int,
153
+ train_folder: str = None,
154
+ aspect_ratios: dict = ASPECT_RATIO_512,
155
+ drop_last: bool = False,
156
+ config=None,
157
+ **kwargs
158
+ ) -> None:
159
+ if not isinstance(sampler, Sampler):
160
+ raise TypeError('sampler should be an instance of ``Sampler``, '
161
+ f'but got {sampler}')
162
+ if not isinstance(batch_size, int) or batch_size <= 0:
163
+ raise ValueError('batch_size should be a positive integer value, '
164
+ f'but got batch_size={batch_size}')
165
+ self.sampler = sampler
166
+ self.dataset = dataset
167
+ self.train_folder = train_folder
168
+ self.batch_size = batch_size
169
+ self.aspect_ratios = aspect_ratios
170
+ self.drop_last = drop_last
171
+ self.config = config
172
+ # buckets for each aspect ratio
173
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
174
+ # [str(k) for k, v in aspect_ratios]
175
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
176
+
177
+ def __iter__(self):
178
+ for idx in self.sampler:
179
+ try:
180
+ image_dict = self.dataset[idx]
181
+
182
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
183
+ if width is None or height is None:
184
+ image_id, name = image_dict['file_path'], image_dict['text']
185
+ if self.train_folder is None:
186
+ image_dir = image_id
187
+ else:
188
+ image_dir = os.path.join(self.train_folder, image_id)
189
+
190
+ width, height = get_image_size_without_loading(image_dir)
191
+
192
+ ratio = height / width # self.dataset[idx]
193
+ else:
194
+ height = int(height)
195
+ width = int(width)
196
+ ratio = height / width # self.dataset[idx]
197
+ except Exception as e:
198
+ print(e)
199
+ continue
200
+ # find the closest aspect ratio
201
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
202
+ if closest_ratio not in self.current_available_bucket_keys:
203
+ continue
204
+ bucket = self._aspect_ratio_buckets[closest_ratio]
205
+ bucket.append(idx)
206
+ # yield a batch of indices in the same aspect ratio group
207
+ if len(bucket) == self.batch_size:
208
+ yield bucket[:]
209
+ del bucket[:]
210
+
211
+ class AspectRatioBatchSampler(BatchSampler):
212
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
213
+
214
+ Args:
215
+ sampler (Sampler): Base sampler.
216
+ dataset (Dataset): Dataset providing data information.
217
+ batch_size (int): Size of mini-batch.
218
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
219
+ its size would be less than ``batch_size``.
220
+ aspect_ratios (dict): The predefined aspect ratios.
221
+ """
222
+ def __init__(
223
+ self,
224
+ sampler: Sampler,
225
+ dataset: Dataset,
226
+ batch_size: int,
227
+ video_folder: str = None,
228
+ train_data_format: str = "webvid",
229
+ aspect_ratios: dict = ASPECT_RATIO_512,
230
+ drop_last: bool = False,
231
+ config=None,
232
+ **kwargs
233
+ ) -> None:
234
+ if not isinstance(sampler, Sampler):
235
+ raise TypeError('sampler should be an instance of ``Sampler``, '
236
+ f'but got {sampler}')
237
+ if not isinstance(batch_size, int) or batch_size <= 0:
238
+ raise ValueError('batch_size should be a positive integer value, '
239
+ f'but got batch_size={batch_size}')
240
+ self.sampler = sampler
241
+ self.dataset = dataset
242
+ self.video_folder = video_folder
243
+ self.train_data_format = train_data_format
244
+ self.batch_size = batch_size
245
+ self.aspect_ratios = aspect_ratios
246
+ self.drop_last = drop_last
247
+ self.config = config
248
+ # buckets for each aspect ratio
249
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
250
+ # [str(k) for k, v in aspect_ratios]
251
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
252
+
253
+ def __iter__(self):
254
+ for idx in self.sampler:
255
+ try:
256
+ video_dict = self.dataset[idx]
257
+ width, more = video_dict.get("width", None), video_dict.get("height", None)
258
+
259
+ if width is None or height is None:
260
+ if self.train_data_format == "normal":
261
+ video_id, name = video_dict['file_path'], video_dict['text']
262
+ if self.video_folder is None:
263
+ video_dir = video_id
264
+ else:
265
+ video_dir = os.path.join(self.video_folder, video_id)
266
+ else:
267
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
268
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
269
+ cap = cv2.VideoCapture(video_dir)
270
+
271
+ # 获取视频尺寸
272
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
273
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
274
+
275
+ ratio = height / width # self.dataset[idx]
276
+ else:
277
+ height = int(height)
278
+ width = int(width)
279
+ ratio = height / width # self.dataset[idx]
280
+ except Exception as e:
281
+ print(e, self.dataset[idx], "This item is error, please check it.")
282
+ continue
283
+ # find the closest aspect ratio
284
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
285
+ if closest_ratio not in self.current_available_bucket_keys:
286
+ continue
287
+ bucket = self._aspect_ratio_buckets[closest_ratio]
288
+ bucket.append(idx)
289
+ # yield a batch of indices in the same aspect ratio group
290
+ if len(bucket) == self.batch_size:
291
+ yield bucket[:]
292
+ del bucket[:]
293
+
294
+ class AspectRatioBatchImageVideoSampler(BatchSampler):
295
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
296
+
297
+ Args:
298
+ sampler (Sampler): Base sampler.
299
+ dataset (Dataset): Dataset providing data information.
300
+ batch_size (int): Size of mini-batch.
301
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
302
+ its size would be less than ``batch_size``.
303
+ aspect_ratios (dict): The predefined aspect ratios.
304
+ """
305
+
306
+ def __init__(self,
307
+ sampler: Sampler,
308
+ dataset: Dataset,
309
+ batch_size: int,
310
+ train_folder: str = None,
311
+ aspect_ratios: dict = ASPECT_RATIO_512,
312
+ drop_last: bool = False
313
+ ) -> None:
314
+ if not isinstance(sampler, Sampler):
315
+ raise TypeError('sampler should be an instance of ``Sampler``, '
316
+ f'but got {sampler}')
317
+ if not isinstance(batch_size, int) or batch_size <= 0:
318
+ raise ValueError('batch_size should be a positive integer value, '
319
+ f'but got batch_size={batch_size}')
320
+ self.sampler = sampler
321
+ self.dataset = dataset
322
+ self.train_folder = train_folder
323
+ self.batch_size = batch_size
324
+ self.aspect_ratios = aspect_ratios
325
+ self.drop_last = drop_last
326
+
327
+ # buckets for each aspect ratio
328
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
329
+ self.bucket = {
330
+ 'image':{ratio: [] for ratio in aspect_ratios},
331
+ 'video':{ratio: [] for ratio in aspect_ratios}
332
+ }
333
+
334
+ def __iter__(self):
335
+ for idx in self.sampler:
336
+ content_type = self.dataset[idx].get('type', 'video') # Default to video for video edit datasets
337
+
338
+ try:
339
+ data_dict = self.dataset[idx]
340
+ width, height = data_dict.get("width", None), data_dict.get("height", None)
341
+
342
+ if width is None or height is None:
343
+ if content_type == 'image':
344
+ # Image branch
345
+ image_id = data_dict.get('file_path', '')
346
+ if self.train_folder is None:
347
+ image_dir = image_id
348
+ else:
349
+ image_dir = os.path.join(self.train_folder, image_id)
350
+ width, height = get_image_size_without_loading(image_dir)
351
+ else:
352
+ # Video branch - prefer original_video -> edited_video -> file_path
353
+ video_id = (
354
+ data_dict.get('original_video')
355
+ or data_dict.get('edited_video')
356
+ or data_dict.get('file_path')
357
+ )
358
+ if video_id is None:
359
+ raise ValueError(f"No valid video path found in dataset item: {data_dict}")
360
+ if self.train_folder is None:
361
+ video_dir = video_id
362
+ else:
363
+ video_dir = os.path.join(self.train_folder, video_id)
364
+ cap = cv2.VideoCapture(video_dir)
365
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
366
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
367
+ cap.release()
368
+ if width == 0 or height == 0:
369
+ raise ValueError(f"Invalid video size for {video_dir}: {width}x{height}")
370
+ else:
371
+ height = int(height)
372
+ width = int(width)
373
+
374
+ ratio = height / width
375
+
376
+ except Exception as e:
377
+ print(e, self.dataset[idx], "This item is error, please check it.")
378
+ continue
379
+
380
+ # Find the closest aspect ratio
381
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
382
+ if closest_ratio not in self.current_available_bucket_keys:
383
+ continue
384
+
385
+ # Add to appropriate bucket (image or video)
386
+ bucket = self.bucket[content_type][closest_ratio]
387
+ bucket.append(idx)
388
+
389
+ # Yield a batch when bucket is full (ensures all items are same type)
390
+ if len(bucket) == self.batch_size:
391
+ yield bucket[:]
392
+ del bucket[:]
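How an item ends up in a bucket: get_closest_ratio maps the raw frame size to the nearest key in the ratio table, and the samplers above only batch indices that share that key. A quick sketch (the import path follows this repo's layout):

    from videox_fun.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio

    height, width = 720, 1280                  # e.g. a 16:9 source video
    size, ratio = get_closest_ratio(height, width, ASPECT_RATIO_512)
    print(size, ratio)                         # [384.0, 672.0] 0.57 -> the 384x672 (h x w) bucket

    # AspectRatioBatchSampler then appends this index to its '0.57' bucket and yields
    # a batch only once that bucket holds `batch_size` items of the same shape class.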
videox_fun/data/dataset_image.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+
12
+ class CC15M(Dataset):
13
+ def __init__(
14
+ self,
15
+ json_path,
16
+ video_folder=None,
17
+ resolution=512,
18
+ enable_bucket=False,
19
+ ):
20
+ print(f"loading annotations from {json_path} ...")
21
+ self.dataset = json.load(open(json_path, 'r'))
22
+ self.length = len(self.dataset)
23
+ print(f"data scale: {self.length}")
24
+
25
+ self.enable_bucket = enable_bucket
26
+ self.video_folder = video_folder
27
+
28
+ resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
29
+ self.pixel_transforms = transforms.Compose([
30
+ transforms.Resize(resolution[0]),
31
+ transforms.CenterCrop(resolution),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
34
+ ])
35
+
36
+ def get_batch(self, idx):
37
+ video_dict = self.dataset[idx]
38
+ video_id, name = video_dict['file_path'], video_dict['text']
39
+
40
+ if self.video_folder is None:
41
+ video_dir = video_id
42
+ else:
43
+ video_dir = os.path.join(self.video_folder, video_id)
44
+
45
+ pixel_values = Image.open(video_dir).convert("RGB")
46
+ return pixel_values, name
47
+
48
+ def __len__(self):
49
+ return self.length
50
+
51
+ def __getitem__(self, idx):
52
+ while True:
53
+ try:
54
+ pixel_values, name = self.get_batch(idx)
55
+ break
56
+ except Exception as e:
57
+ print(e)
58
+ idx = random.randint(0, self.length-1)
59
+
60
+ if not self.enable_bucket:
61
+ pixel_values = self.pixel_transforms(pixel_values)
62
+ else:
63
+ pixel_values = np.array(pixel_values)
64
+
65
+ sample = dict(pixel_values=pixel_values, text=name)
66
+ return sample
67
+
68
+ if __name__ == "__main__":
69
+ dataset = CC15M(
70
+ csv_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
71
+ resolution=512,
72
+ )
73
+
74
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
75
+ for idx, batch in enumerate(dataloader):
76
+ print(batch["pixel_values"].shape, len(batch["text"]))
videox_fun/data/dataset_image_video.py ADDED
@@ -0,0 +1,1939 @@
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ import re
9
+ from contextlib import contextmanager
10
+ from random import shuffle
11
+ from threading import Thread
12
+
13
+ import albumentations
14
+ import cv2
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import torchvision.transforms as transforms
19
+ from decord import VideoReader
20
+ from einops import rearrange
21
+ from func_timeout import FunctionTimedOut, func_timeout
22
+ from packaging import version as pver
23
+ from PIL import Image
24
+ from torch.utils.data import BatchSampler, Sampler
25
+ from torch.utils.data.dataset import Dataset
26
+
27
+ VIDEO_READER_TIMEOUT = 20
28
+
29
+ def get_random_mask(shape, image_start_only=False):
30
+ f, c, h, w = shape
31
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
32
+
33
+ if not image_start_only:
34
+ if f != 1:
35
+ mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
36
+ else:
37
+ mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
38
+ if mask_index == 0:
39
+ center_x = torch.randint(0, w, (1,)).item()
40
+ center_y = torch.randint(0, h, (1,)).item()
41
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # block width range
42
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # block height range
43
+
44
+ start_x = max(center_x - block_size_x // 2, 0)
45
+ end_x = min(center_x + block_size_x // 2, w)
46
+ start_y = max(center_y - block_size_y // 2, 0)
47
+ end_y = min(center_y + block_size_y // 2, h)
48
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
49
+ elif mask_index == 1:
50
+ mask[:, :, :, :] = 1
51
+ elif mask_index == 2:
52
+ mask_frame_index = np.random.randint(1, 5)
53
+ mask[mask_frame_index:, :, :, :] = 1
54
+ elif mask_index == 3:
55
+ mask_frame_index = np.random.randint(1, 5)
56
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
57
+ elif mask_index == 4:
58
+ center_x = torch.randint(0, w, (1,)).item()
59
+ center_y = torch.randint(0, h, (1,)).item()
60
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # block width range
61
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # block height range
62
+
63
+ start_x = max(center_x - block_size_x // 2, 0)
64
+ end_x = min(center_x + block_size_x // 2, w)
65
+ start_y = max(center_y - block_size_y // 2, 0)
66
+ end_y = min(center_y + block_size_y // 2, h)
67
+
68
+ mask_frame_before = np.random.randint(0, f // 2)
69
+ mask_frame_after = np.random.randint(f // 2, f)
70
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
71
+ elif mask_index == 5:
72
+ mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
73
+ elif mask_index == 6:
74
+ num_frames_to_mask = random.randint(1, max(f // 2, 1))
75
+ frames_to_mask = random.sample(range(f), num_frames_to_mask)
76
+
77
+ for i in frames_to_mask:
78
+ block_height = random.randint(1, h // 4)
79
+ block_width = random.randint(1, w // 4)
80
+ top_left_y = random.randint(0, h - block_height)
81
+ top_left_x = random.randint(0, w - block_width)
82
+ mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
83
+ elif mask_index == 7:
84
+ center_x = torch.randint(0, w, (1,)).item()
85
+ center_y = torch.randint(0, h, (1,)).item()
86
+ a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # semi-major axis
87
+ b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # semi-minor axis
88
+
89
+ for i in range(h):
90
+ for j in range(w):
91
+ if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
92
+ mask[:, :, i, j] = 1
93
+ elif mask_index == 8:
94
+ center_x = torch.randint(0, w, (1,)).item()
95
+ center_y = torch.randint(0, h, (1,)).item()
96
+ radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
97
+ for i in range(h):
98
+ for j in range(w):
99
+ if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
100
+ mask[:, :, i, j] = 1
101
+ elif mask_index == 9:
102
+ for idx in range(f):
103
+ if np.random.rand() > 0.5:
104
+ mask[idx, :, :, :] = 1
105
+ else:
106
+ raise ValueError(f"The mask_index {mask_index} is not define")
107
+ else:
108
+ if f != 1:
109
+ mask[1:, :, :, :] = 1
110
+ else:
111
+ mask[:, :, :, :] = 1
112
+ return mask
113
+
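# A minimal usage sketch for get_random_mask defined above (shapes only; the
# (frames, channels, H, W) input is arbitrary and the returned mask has a single channel):
mask = get_random_mask((16, 3, 64, 64))
assert mask.shape == (16, 1, 64, 64)
print("masked pixel ratio:", mask.float().mean().item())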
114
+ class Camera(object):
115
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
116
+ """
117
+ def __init__(self, entry):
118
+ fx, fy, cx, cy = entry[1:5]
119
+ self.fx = fx
120
+ self.fy = fy
121
+ self.cx = cx
122
+ self.cy = cy
123
+ w2c_mat = np.array(entry[7:]).reshape(3, 4)
124
+ w2c_mat_4x4 = np.eye(4)
125
+ w2c_mat_4x4[:3, :] = w2c_mat
126
+ self.w2c_mat = w2c_mat_4x4
127
+ self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
128
+
129
+ def custom_meshgrid(*args):
130
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
131
+ """
132
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
133
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
134
+ return torch.meshgrid(*args)
135
+ else:
136
+ return torch.meshgrid(*args, indexing='ij')
137
+
138
+ def get_relative_pose(cam_params):
139
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
140
+ """
141
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
142
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
143
+ cam_to_origin = 0
144
+ target_cam_c2w = np.array([
145
+ [1, 0, 0, 0],
146
+ [0, 1, 0, -cam_to_origin],
147
+ [0, 0, 1, 0],
148
+ [0, 0, 0, 1]
149
+ ])
150
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
151
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
152
+ ret_poses = np.array(ret_poses, dtype=np.float32)
153
+ return ret_poses
154
+
155
+ def ray_condition(K, c2w, H, W, device):
156
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
157
+ """
158
+ # c2w: B, V, 4, 4
159
+ # K: B, V, 4
160
+
161
+ B = K.shape[0]
162
+
163
+ j, i = custom_meshgrid(
164
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
165
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
166
+ )
167
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
168
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
169
+
170
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
171
+
172
+ zs = torch.ones_like(i) # [B, HxW]
173
+ xs = (i - cx) / fx * zs
174
+ ys = (j - cy) / fy * zs
175
+ zs = zs.expand_as(ys)
176
+
177
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
178
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
179
+
180
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
181
+ rays_o = c2w[..., :3, 3] # B, V, 3
182
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
183
+ # c2w @ directions
184
+ rays_dxo = torch.cross(rays_o, rays_d)
185
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
186
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
187
+ # plucker = plucker.permute(0, 1, 4, 2, 3)
188
+ return plucker
189
+
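# A small shape check for ray_condition above: K carries per-frame (fx, fy, cx, cy) and c2w the
# camera-to-world poses; the result is a 6-dim Pluecker ray embedding per pixel (dummy values only).
B, V, H, W = 1, 8, 32, 48
K = torch.tensor([100.0, 100.0, W / 2, H / 2]).expand(B, V, 4)
c2w = torch.eye(4).repeat(B, V, 1, 1)
plucker = ray_condition(K, c2w, H, W, device="cpu")
assert plucker.shape == (B, V, H, W, 6)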
190
+ def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
191
+ """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
192
+ """
193
+ with open(pose_file_path, 'r') as f:
194
+ poses = f.readlines()
195
+
196
+ poses = [pose.strip().split(' ') for pose in poses[1:]]
197
+ cam_params = [[float(x) for x in pose] for pose in poses]
198
+ if return_poses:
199
+ return cam_params
200
+ else:
201
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
202
+
203
+ sample_wh_ratio = width / height
204
+ pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
205
+
206
+ if pose_wh_ratio > sample_wh_ratio:
207
+ resized_ori_w = height * pose_wh_ratio
208
+ for cam_param in cam_params:
209
+ cam_param.fx = resized_ori_w * cam_param.fx / width
210
+ else:
211
+ resized_ori_h = width / pose_wh_ratio
212
+ for cam_param in cam_params:
213
+ cam_param.fy = resized_ori_h * cam_param.fy / height
214
+
215
+ intrinsic = np.asarray([[cam_param.fx * width,
216
+ cam_param.fy * height,
217
+ cam_param.cx * width,
218
+ cam_param.cy * height]
219
+ for cam_param in cam_params], dtype=np.float32)
220
+
221
+ K = torch.as_tensor(intrinsic)[None] # [1, n_frame, 4]
222
+ c2ws = get_relative_pose(cam_params) # relative camera-to-world poses (get_relative_pose is defined above)
223
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
224
+ plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
225
+ plucker_embedding = plucker_embedding[None]
226
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
227
+ return plucker_embedding
228
+
229
+ def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
230
+ """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
231
+ """
232
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
233
+
234
+ sample_wh_ratio = width / height
235
+ pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
236
+
237
+ if pose_wh_ratio > sample_wh_ratio:
238
+ resized_ori_w = height * pose_wh_ratio
239
+ for cam_param in cam_params:
240
+ cam_param.fx = resized_ori_w * cam_param.fx / width
241
+ else:
242
+ resized_ori_h = width / pose_wh_ratio
243
+ for cam_param in cam_params:
244
+ cam_param.fy = resized_ori_h * cam_param.fy / height
245
+
246
+ intrinsic = np.asarray([[cam_param.fx * width,
247
+ cam_param.fy * height,
248
+ cam_param.cx * width,
249
+ cam_param.cy * height]
250
+ for cam_param in cam_params], dtype=np.float32)
251
+
252
+ K = torch.as_tensor(intrinsic)[None] # [1, n_frame, 4]
253
+ c2ws = get_relative_pose(cam_params) # relative camera-to-world poses (get_relative_pose is defined above)
254
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
255
+ plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
256
+ plucker_embedding = plucker_embedding[None]
257
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
258
+ return plucker_embedding
259
+
260
+ def derive_ground_object_from_instruction(instruction: str) -> str:
261
+ s = (instruction or '').strip()
262
+ if not s:
263
+ return 'the target area'
264
+ s = s.rstrip('.').strip()
265
+
266
+ # swap/replace: capture phrase between "replace/swap" and "with/by"
267
+ swap_patterns = [
268
+ r"\breplace\s+(.*?)\s+(?:with|by)\b",
269
+ r"\bswap\s+(.*?)\s+with\b",
270
+ ]
271
+ for pat in swap_patterns:
272
+ m = re.search(pat, s, flags=re.IGNORECASE)
273
+ if m:
274
+ phrase = m.group(1).strip(' .,:;')
275
+ if phrase:
276
+ return phrase
277
+
278
+ # removal: capture object after remove/delete/erase/eliminate up to a preposition or punctuation
279
+ m = re.search(r"\b(?:remove|delete|erase|eliminate)\s+(.*?)(?:\s+(?:from|in|at|on|over|under|near|by)\b|[.,;]|$)", s, flags=re.IGNORECASE)
280
+ if m:
281
+ phrase = m.group(1).strip(' .,:;')
282
+ if phrase:
283
+ return phrase
284
+
285
+ # add/insert: generic target area
286
+ if re.search(r"^\s*(?:add|insert)\b", s, flags=re.IGNORECASE):
287
+ return 'the target area'
288
+
289
+ # local style (change/make ...): take the immediate noun after determiner
290
+ m = re.search(r"\b(?:change|make)\s+(?:(the|a|an)\s+)?([A-Za-z][A-Za-z0-9\-]*)", s, flags=re.IGNORECASE)
291
+ if m:
292
+ det = m.group(1) or ''
293
+ noun = m.group(2)
294
+ phrase = (det + ' ' + noun).strip()
295
+ return phrase
296
+
297
+ return 'the target area'
298
+
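# Example behaviour of derive_ground_object_from_instruction above (each output follows
# directly from the regex rules; the instruction strings are illustrative, not from the dataset):
print(derive_ground_object_from_instruction("replace the dog with a cat"))          # the dog
print(derive_ground_object_from_instruction("remove the red car from the street"))  # the red car
print(derive_ground_object_from_instruction("add a rainbow over the hills"))        # the target area
print(derive_ground_object_from_instruction("make the sky purple"))                 # the sky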
299
+ class ImageVideoSampler(BatchSampler):
300
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
301
+
302
+ Args:
303
+ sampler (Sampler): Base sampler.
304
+ dataset (Dataset): Dataset providing data information.
305
+ batch_size (int): Size of mini-batch.
306
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
307
+ its size would be less than ``batch_size``.
308
309
+ """
310
+
311
+ def __init__(self,
312
+ sampler: Sampler,
313
+ dataset: Dataset,
314
+ batch_size: int,
315
+ drop_last: bool = False
316
+ ) -> None:
317
+ if not isinstance(sampler, Sampler):
318
+ raise TypeError('sampler should be an instance of ``Sampler``, '
319
+ f'but got {sampler}')
320
+ if not isinstance(batch_size, int) or batch_size <= 0:
321
+ raise ValueError('batch_size should be a positive integer value, '
322
+ f'but got batch_size={batch_size}')
323
+ self.sampler = sampler
324
+ self.dataset = dataset
325
+ self.batch_size = batch_size
326
+ self.drop_last = drop_last
327
+
328
+ # buckets for each aspect ratio
329
+ self.bucket = {'image':[], 'video':[]}
330
+
331
+ def __iter__(self):
332
+ for idx in self.sampler:
333
+ content_type = self.dataset.dataset[idx].get('type', 'image')
334
+ self.bucket[content_type].append(idx)
335
+
336
+ # yield a batch of indices in the same aspect ratio group
337
+ if len(self.bucket['video']) == self.batch_size:
338
+ bucket = self.bucket['video']
339
+ yield bucket[:]
340
+ del bucket[:]
341
+ elif len(self.bucket['image']) == self.batch_size:
342
+ bucket = self.bucket['image']
343
+ yield bucket[:]
344
+ del bucket[:]
345
+
346
+ @contextmanager
347
+ def VideoReader_contextmanager(*args, **kwargs):
348
+ vr = VideoReader(*args, **kwargs)
349
+ try:
350
+ yield vr
351
+ finally:
352
+ del vr
353
+ gc.collect()
354
+
355
+ def get_video_reader_batch(video_reader, batch_index):
356
+ frames = video_reader.get_batch(batch_index).asnumpy()
357
+ return frames
358
+
359
+ def resize_frame(frame, target_short_side):
360
+ h, w, _ = frame.shape
361
+ if h < w:
362
+ if target_short_side > h:
363
+ return frame
364
+ new_h = target_short_side
365
+ new_w = int(target_short_side * w / h)
366
+ else:
367
+ if target_short_side > w:
368
+ return frame
369
+ new_w = target_short_side
370
+ new_h = int(target_short_side * h / w)
371
+
372
+ resized_frame = cv2.resize(frame, (new_w, new_h))
373
+ return resized_frame
374
+
375
+ class VideoEditDataset(Dataset):
376
+ def __init__(
377
+ self,
378
+ ann_path,
379
+ data_root=None,
380
+ video_sample_height: int = None, # None to support dynamic resolutions
381
+ video_sample_width: int = None,
382
+ video_sample_stride=1,
383
+ video_sample_n_frames=65, # 9+8=17 for your case
384
+ source_frames=33,
385
+ edit_frames=32,
386
+ text_drop_ratio=0.1,
387
+ enable_bucket=False,
388
+ enable_inpaint=False,
389
+ instruction_template="A video sequence showing two parts: the first half shows the original scene, and the second half shows the same scene but {edit_instruction}",
390
+ ):
391
+ dataset = json.load(open(ann_path))
392
+ if isinstance(dataset, dict):
393
+ new_dataset = []
394
+ for vid_id, info in dataset.items():
395
+ text_content = info["edit_instruction"]
396
+ new_dataset.append({
397
+ "original_video": info["original_video"],
398
+ "edited_video": info["edited_video"],
399
+ "text": text_content,
400
+ "type": info.get("type", "video"),
401
+ # add resolution info to the metadata
402
+ "resolution": info.get("resolution", None)
403
+ })
404
+ dataset = new_dataset
405
+
406
+ self.data_root = data_root
407
+ self.dataset = dataset
408
+ self.length = len(self.dataset)
409
+
410
+ self.source_frames = source_frames
411
+ self.edit_frames = edit_frames
412
+ self.video_sample_n_frames = video_sample_n_frames
413
+
414
+ self.instruction_template = instruction_template
415
+ self.enable_bucket = enable_bucket
416
+ self.text_drop_ratio = text_drop_ratio
417
+ self.enable_inpaint = enable_inpaint
418
+ self.video_sample_stride = video_sample_stride
419
+
420
+ # if bucket training is enabled, do not fix the resolution
421
+ if enable_bucket:
422
+ self.video_sample_height = None
423
+ self.video_sample_width = None
424
+ else:
425
+ self.video_sample_height = video_sample_height
426
+ self.video_sample_width = video_sample_width
427
+
428
+ def load_video_pair(self, original_path, edited_path):
429
+ """加载视频对,保持原始分辨率用于bucket training"""
430
+ if self.data_root is not None:
431
+ original_path = os.path.join(self.data_root, original_path)
432
+ edited_path = os.path.join(self.data_root, edited_path)
433
+
434
+ with VideoReader_contextmanager(original_path, num_threads=2) as orig_reader, \
435
+ VideoReader_contextmanager(edited_path, num_threads=2) as edit_reader:
436
+
437
+ # get video info
438
+ orig_length = len(orig_reader)
439
+ edit_length = len(edit_reader)
440
+ min_length = min(orig_length, edit_length)
441
+
442
+ # unified sampling strategy
443
+ start_idx = 0 # start from the beginning
444
+
445
+ orig_indices = np.linspace(
446
+ start_idx,
447
+ min(start_idx + (self.source_frames - 1) * self.video_sample_stride, orig_length - 1),
448
+ self.source_frames,
449
+ dtype=int
450
+ )
451
+
452
+ edit_indices = np.linspace(
453
+ start_idx,
454
+ min(start_idx + (self.edit_frames - 1) * self.video_sample_stride, edit_length - 1),
455
+ self.edit_frames,
456
+ dtype=int
457
+ )
458
+
459
+ # load frames
460
+ orig_frames = get_video_reader_batch(orig_reader, orig_indices)
461
+ edit_frames = get_video_reader_batch(edit_reader, edit_indices)
462
+
463
+ # before concatenation, align both clips to the same HxW (resize, then center-crop to min(H1,H2) x min(W1,W2))
464
+ def resize_and_center_crop_batch(frames_np, target_h, target_w):
465
+ resized = []
466
+ for i in range(frames_np.shape[0]):
467
+ frame = frames_np[i]
468
+ h, w = frame.shape[0], frame.shape[1]
469
+ scale = max(target_h / h, target_w / w)
470
+ new_h = int(round(h * scale))
471
+ new_w = int(round(w * scale))
472
+ frame_resized = cv2.resize(frame, (new_w, new_h))
473
+ y0 = max((new_h - target_h) // 2, 0)
474
+ x0 = max((new_w - target_w) // 2, 0)
475
+ frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
476
+ resized.append(frame_cropped)
477
+ return np.stack(resized, axis=0)
478
+
479
+ oh, ow = orig_frames.shape[1], orig_frames.shape[2]
480
+ eh, ew = edit_frames.shape[1], edit_frames.shape[2]
481
+ target_h = min(oh, eh)
482
+ target_w = min(ow, ew)
483
+ if (oh != target_h or ow != target_w):
484
+ orig_frames = resize_and_center_crop_batch(orig_frames, target_h, target_w)
485
+ if (eh != target_h or ew != target_w):
486
+ edit_frames = resize_and_center_crop_batch(edit_frames, target_h, target_w)
487
+
488
+ # if bucket training is enabled, return numpy arrays
489
+ if self.enable_bucket:
490
+ return np.concatenate([orig_frames, edit_frames], axis=0)
491
+ else:
492
+ # convert to tensors and normalize
493
+ orig_frames = torch.from_numpy(orig_frames).permute(0, 3, 1, 2).contiguous() / 255.
494
+ edit_frames = torch.from_numpy(edit_frames).permute(0, 3, 1, 2).contiguous() / 255.
495
+ return torch.cat([orig_frames, edit_frames], dim=0)
496
+
497
+ def __len__(self):
498
+ return self.length
499
+
500
+ def __getitem__(self, idx):
501
+ data_info = self.dataset[idx % len(self.dataset)]
502
+
503
+ while True:
504
+ try:
505
+ # load the video pair
506
+ pixel_values = self.load_video_pair(
507
+ data_info['original_video'],
508
+ data_info['edited_video']
509
+ )
510
+
511
+ # prepare the text
512
+ text = data_info['text']
513
+ if self.instruction_template and "{edit_instruction}" in self.instruction_template:
514
+ text = self.instruction_template.format(edit_instruction=text)
515
+
516
+ if random.random() < self.text_drop_ratio:
517
+ text = ''
518
+
519
+ sample = {
520
+ "pixel_values": pixel_values,
521
+ "text": text,
522
+ "data_type": "video",
523
+ "idx": idx,
524
+ }
525
+
526
+ # if inpainting is needed
527
+ if self.enable_inpaint and not self.enable_bucket:
528
+ # add inpaint logic here
529
+ pass
530
+
531
+ return sample
532
+
533
+ except Exception as e:
534
+ try:
535
+ print(
536
+ f"Error loading video pair: {e}\n"
537
+ f" original={os.path.join(self.data_root, data_info.get('original_video','')) if self.data_root else data_info.get('original_video','')}\n"
538
+ f" edited ={os.path.join(self.data_root, data_info.get('edited_video','')) if self.data_root else data_info.get('edited_video','')}"
539
+ )
540
+ except Exception:
541
+ print(f"Error loading video pair: {e}")
542
+ idx = random.randint(0, self.length-1)
+ # refresh data_info so the retry uses the newly sampled item instead of the failing pair
+ data_info = self.dataset[idx % len(self.dataset)]
543
+
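# A minimal usage sketch for VideoEditDataset (the annotation path, data root and batching
# choices below are hypothetical; the JSON maps ids to original_video / edited_video / edit_instruction):
dataset = VideoEditDataset(
    ann_path="annotations/edit_pairs.json",
    data_root="videos/",
    source_frames=33,
    edit_frames=32,
    enable_bucket=True,   # keep native resolution and return numpy frames
)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=0)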
544
+ class VideoEditReasoningDataset(Dataset):
545
+ def __init__(
546
+ self,
547
+ ann_path,
548
+ data_root=None,
549
+ video_sample_height: int = None,
550
+ video_sample_width: int = None,
551
+ video_sample_stride=1,
552
+ video_sample_n_frames=65,
553
+ source_frames=33,
554
+ reasoning_frames=4,
555
+ edit_frames=32,
556
+ text_drop_ratio=0.1,
557
+ enable_bucket=False,
558
+ enable_inpaint=False,
559
+ instruction_template="A video sequence showing three parts: first the original scene, then grounded {ground_instrction}, and finally the same scene but {edit_instruction}",
560
+ ):
561
+ dataset = json.load(open(ann_path))
562
+ if isinstance(dataset, dict):
563
+ new_dataset = []
564
+ for vid_id, info in dataset.items():
565
+ text_content = info.get("edit_instruction", info.get("text", ""))
566
+ # support both 'grounded_video' and 'ground_video'
567
+ grounded_key = "grounded_video" if "grounded_video" in info else "ground_video"
568
+ new_dataset.append({
569
+ "original_video": info["original_video"],
570
+ "grounded_video": info[grounded_key],
571
+ "edited_video": info["edited_video"],
572
+ "text": text_content,
573
+ "edit_instruction": text_content,
574
+ "type": info.get("type", "video"),
575
+ "resolution": info.get("resolution", None),
576
+ })
577
+ dataset = new_dataset
578
+
579
+ self.data_root = data_root
580
+ self.dataset = dataset
581
+ self.length = len(self.dataset)
582
+
583
+ self.source_frames = source_frames
584
+ self.reasoning_frames = reasoning_frames
585
+ self.edit_frames = edit_frames
586
+ self.video_sample_n_frames = video_sample_n_frames
587
+
588
+ self.instruction_template = instruction_template
589
+ self.enable_bucket = enable_bucket
590
+ self.text_drop_ratio = text_drop_ratio
591
+ self.enable_inpaint = enable_inpaint
592
+ self.video_sample_stride = video_sample_stride
593
+
594
+ if enable_bucket:
595
+ self.video_sample_height = None
596
+ self.video_sample_width = None
597
+ else:
598
+ self.video_sample_height = video_sample_height
599
+ self.video_sample_width = video_sample_width
600
+
601
+ def load_video_pair(self, original_path, grounded_path, edited_path):
602
+ if self.data_root is not None:
603
+ original_path = os.path.join(self.data_root, original_path)
604
+ grounded_path = os.path.join(self.data_root, grounded_path)
605
+ edited_path = os.path.join(self.data_root, edited_path)
606
+
607
+ with VideoReader_contextmanager(original_path, num_threads=2) as orig_reader, \
608
+ VideoReader_contextmanager(grounded_path, num_threads=2) as ground_reader, \
609
+ VideoReader_contextmanager(edited_path, num_threads=2) as edit_reader:
610
+
611
+ orig_length = len(orig_reader)
612
+ ground_length = len(ground_reader)
613
+ edit_length = len(edit_reader)
614
+
615
+ start_idx = 0
616
+
617
+ orig_indices = np.linspace(
618
+ start_idx,
619
+ min(start_idx + (self.source_frames - 1) * self.video_sample_stride, max(orig_length - 1, 0)),
620
+ self.source_frames,
621
+ dtype=int
622
+ )
623
+
624
+ # reasoning/grounded indices sampled at an 8-frame interval (e.g. 0, 8, 16, 24, ...)
625
+ interval = 8
626
+ ground_indices_full = np.arange(0, max(ground_length, 1), interval, dtype=int)
627
+ if len(ground_indices_full) == 0:
628
+ ground_indices = np.array([0] * self.reasoning_frames, dtype=int)
629
+ else:
630
+ ground_indices = ground_indices_full[: self.reasoning_frames]
631
+ if len(ground_indices) < self.reasoning_frames:
632
+ pad_value = ground_indices[-1] if len(ground_indices) > 0 else 0
633
+ ground_indices = np.pad(
634
+ ground_indices, (0, self.reasoning_frames - len(ground_indices)), constant_values=pad_value
635
+ )
636
+
637
+ edit_indices = np.linspace(
638
+ start_idx,
639
+ min(start_idx + (self.edit_frames - 1) * self.video_sample_stride, max(edit_length - 1, 0)),
640
+ self.edit_frames,
641
+ dtype=int
642
+ )
643
+
644
+ orig_frames = get_video_reader_batch(orig_reader, orig_indices)
645
+ ground_frames = get_video_reader_batch(ground_reader, ground_indices)
646
+ edit_frames = get_video_reader_batch(edit_reader, edit_indices)
647
+
648
+ def resize_and_center_crop_batch(frames_np, target_h, target_w):
649
+ resized = []
650
+ for i in range(frames_np.shape[0]):
651
+ frame = frames_np[i]
652
+ h, w = frame.shape[0], frame.shape[1]
653
+ scale = max(target_h / h, target_w / w)
654
+ new_h = int(round(h * scale))
655
+ new_w = int(round(w * scale))
656
+ frame_resized = cv2.resize(frame, (new_w, new_h))
657
+ y0 = max((new_h - target_h) // 2, 0)
658
+ x0 = max((new_w - target_w) // 2, 0)
659
+ frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
660
+ resized.append(frame_cropped)
661
+ return np.stack(resized, axis=0)
662
+
663
+ oh, ow = orig_frames.shape[1], orig_frames.shape[2]
664
+ gh, gw = ground_frames.shape[1], ground_frames.shape[2]
665
+ eh, ew = edit_frames.shape[1], edit_frames.shape[2]
666
+ target_h = min(oh, gh, eh)
667
+ target_w = min(ow, gw, ew)
668
+ if (oh != target_h or ow != target_w):
669
+ orig_frames = resize_and_center_crop_batch(orig_frames, target_h, target_w)
670
+ if (gh != target_h or gw != target_w):
671
+ ground_frames = resize_and_center_crop_batch(ground_frames, target_h, target_w)
672
+ if (eh != target_h or ew != target_w):
673
+ edit_frames = resize_and_center_crop_batch(edit_frames, target_h, target_w)
674
+
675
+ if self.enable_bucket:
676
+ return np.concatenate([orig_frames, ground_frames, edit_frames], axis=0)
677
+ else:
678
+ orig_frames = torch.from_numpy(orig_frames).permute(0, 3, 1, 2).contiguous() / 255.
679
+ ground_frames = torch.from_numpy(ground_frames).permute(0, 3, 1, 2).contiguous() / 255.
680
+ edit_frames = torch.from_numpy(edit_frames).permute(0, 3, 1, 2).contiguous() / 255.
681
+ return torch.cat([orig_frames, ground_frames, edit_frames], dim=0)
682
+
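# A standalone sketch of the grounded-frame index selection above: take every `interval`-th frame,
# keep the first `reasoning_frames` indices, and pad with the last index when the clip is short.
ground_length, reasoning_frames, interval = 20, 4, 8
idxs = np.arange(0, max(ground_length, 1), interval, dtype=int)[:reasoning_frames]   # [0 8 16]
if len(idxs) < reasoning_frames:
    idxs = np.pad(idxs, (0, reasoning_frames - len(idxs)), constant_values=idxs[-1])
print(idxs)   # [ 0  8 16 16]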
683
+ def __len__(self):
684
+ return self.length
685
+
686
+ def __getitem__(self, idx):
687
+ data_info = self.dataset[idx % len(self.dataset)]
688
+
689
+ while True:
690
+ try:
691
+ pixel_values = self.load_video_pair(
692
+ data_info['original_video'],
693
+ data_info.get('grounded_video', data_info.get('ground_video')),
694
+ data_info['edited_video'],
695
+ )
696
+
697
+ # Prepare instructions
698
+ edit_text = data_info.get('edit_instruction', data_info.get('text', ''))
699
+ ground_instr = derive_ground_object_from_instruction(edit_text)
700
+
701
+ text = edit_text
702
+ if self.instruction_template:
703
+ text = self.instruction_template.format(edit_instruction=edit_text, ground_instruction=ground_instr)
704
+
705
+ if random.random() < self.text_drop_ratio:
706
+ text = ''
707
+
708
+ sample = {
709
+ "pixel_values": pixel_values,
710
+ "text": text,
711
+ "data_type": "video",
712
+ "idx": idx,
713
+ }
714
+
715
+ if self.enable_inpaint and not self.enable_bucket:
716
+ pass
717
+
718
+ return sample
719
+
720
+ except Exception as e:
721
+ print(f"Error loading video triplet: {e}")
722
+ idx = random.randint(0, self.length-1)
+ # refresh data_info so the retry uses the newly sampled item instead of the failing triplet
+ data_info = self.dataset[idx % len(self.dataset)]
723
+
724
+ class ImageVideoDataset(Dataset):
725
+ def __init__(
726
+ self,
727
+ ann_path, data_root=None,
728
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
729
+ image_sample_size=512,
730
+ video_repeat=0,
731
+ text_drop_ratio=0.1,
732
+ enable_bucket=False,
733
+ video_length_drop_start=0.0,
734
+ video_length_drop_end=1.0,
735
+ enable_inpaint=False,
736
+ return_file_name=False,
737
+ ):
738
+ # Loading annotations from files
739
+ print(f"loading annotations from {ann_path} ...")
740
+ if ann_path.endswith('.csv'):
741
+ with open(ann_path, 'r') as csvfile:
742
+ dataset = list(csv.DictReader(csvfile))
743
+ elif ann_path.endswith('.json'):
744
+ dataset = json.load(open(ann_path))
745
+
746
+ self.data_root = data_root
747
+
748
+ # It's used to balance num of images and videos.
749
+ if video_repeat > 0:
750
+ self.dataset = []
751
+ for data in dataset:
752
+ if data.get('type', 'image') != 'video':
753
+ self.dataset.append(data)
754
+
755
+ for _ in range(video_repeat):
756
+ for data in dataset:
757
+ if data.get('type', 'image') == 'video':
758
+ self.dataset.append(data)
759
+ else:
760
+ self.dataset = dataset
761
+ del dataset
762
+
763
+ self.length = len(self.dataset)
764
+ print(f"data scale: {self.length}")
765
+ # TODO: enable bucket training
766
+ self.enable_bucket = enable_bucket
767
+ self.text_drop_ratio = text_drop_ratio
768
+ self.enable_inpaint = enable_inpaint
769
+ self.return_file_name = return_file_name
770
+
771
+ self.video_length_drop_start = video_length_drop_start
772
+ self.video_length_drop_end = video_length_drop_end
773
+
774
+ # Video params
775
+ self.video_sample_stride = video_sample_stride
776
+ self.video_sample_n_frames = video_sample_n_frames
777
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
778
+ self.video_transforms = transforms.Compose(
779
+ [
780
+ transforms.Resize(min(self.video_sample_size)),
781
+ transforms.CenterCrop(self.video_sample_size),
782
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
783
+ ]
784
+ )
785
+
786
+ # Image params
787
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
788
+ self.image_transforms = transforms.Compose([
789
+ transforms.Resize(min(self.image_sample_size)),
790
+ transforms.CenterCrop(self.image_sample_size),
791
+ transforms.ToTensor(),
792
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
793
+ ])
794
+
795
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
796
+
797
+ def get_batch(self, idx):
798
+ data_info = self.dataset[idx % len(self.dataset)]
799
+
800
+ if data_info.get('type', 'image')=='video':
801
+ video_id, text = data_info['file_path'], data_info['text']
802
+
803
+ if self.data_root is None:
804
+ video_dir = video_id
805
+ else:
806
+ video_dir = os.path.join(self.data_root, video_id)
807
+
808
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
809
+ min_sample_n_frames = min(
810
+ self.video_sample_n_frames,
811
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
812
+ )
813
+ if min_sample_n_frames == 0:
814
+ raise ValueError(f"No Frames in video.")
815
+
816
+ video_length = int(self.video_length_drop_end * len(video_reader))
817
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
818
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
819
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
820
+
821
+ try:
822
+ sample_args = (video_reader, batch_index)
823
+ pixel_values = func_timeout(
824
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
825
+ )
826
+ resized_frames = []
827
+ for i in range(len(pixel_values)):
828
+ frame = pixel_values[i]
829
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
830
+ resized_frames.append(resized_frame)
831
+ pixel_values = np.array(resized_frames)
832
+ except FunctionTimedOut:
833
+ raise ValueError(f"Read {idx} timeout.")
834
+ except Exception as e:
835
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
836
+
837
+ if not self.enable_bucket:
838
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
839
+ pixel_values = pixel_values / 255.
840
+ del video_reader
841
+ else:
842
+ pixel_values = pixel_values
843
+
844
+ if not self.enable_bucket:
845
+ pixel_values = self.video_transforms(pixel_values)
846
+
847
+ # Random use no text generation
848
+ if random.random() < self.text_drop_ratio:
849
+ text = ''
850
+ return pixel_values, text, 'video', video_dir
851
+ else:
852
+ image_path, text = data_info['file_path'], data_info['text']
853
+ if self.data_root is not None:
854
+ image_path = os.path.join(self.data_root, image_path)
855
+ image = Image.open(image_path).convert('RGB')
856
+ if not self.enable_bucket:
857
+ image = self.image_transforms(image).unsqueeze(0)
858
+ else:
859
+ image = np.expand_dims(np.array(image), 0)
860
+ if random.random() < self.text_drop_ratio:
861
+ text = ''
862
+ return image, text, 'image', image_path
863
+
864
+ def __len__(self):
865
+ return self.length
866
+
867
+ def __getitem__(self, idx):
868
+ data_info = self.dataset[idx % len(self.dataset)]
869
+ data_type = data_info.get('type', 'image')
870
+ while True:
871
+ sample = {}
872
+ try:
873
+ data_info_local = self.dataset[idx % len(self.dataset)]
874
+ data_type_local = data_info_local.get('type', 'image')
875
+ if data_type_local != data_type:
876
+ raise ValueError("data_type_local != data_type")
877
+
878
+ pixel_values, name, data_type, file_path = self.get_batch(idx)
879
+ sample["pixel_values"] = pixel_values
880
+ sample["text"] = name
881
+ sample["data_type"] = data_type
882
+ sample["idx"] = idx
883
+ if self.return_file_name:
884
+ sample["file_name"] = os.path.basename(file_path)
885
+
886
+ if len(sample) > 0:
887
+ break
888
+ except Exception as e:
889
+ print(e, self.dataset[idx % len(self.dataset)])
890
+ idx = random.randint(0, self.length-1)
+
+ return sample
891
+
892
+ class ImageVideoEditDataset(Dataset):
893
+ def __init__(
894
+ self,
895
+ ann_path,
896
+ data_root=None,
897
+ video_sample_size=512,
898
+ video_sample_stride=1,
899
+ source_frames=33,
900
+ target_frames=32,
901
+ text_drop_ratio=0.1,
902
+ enable_bucket=False,
903
+ enable_inpaint=False,
904
+ video_length_drop_start=0.0,
905
+ video_length_drop_end=1.0,
906
+ instruction_template="A video sequence showing two parts: the first half shows the original scene, and the second half shows the same scene but {edit_instruction}",
907
+ ):
908
+ dataset = json.load(open(ann_path))
909
+ if isinstance(dataset, dict):
910
+ new_dataset = []
911
+ for _, info in dataset.items():
912
+ # Keep original keys, just standardize text field
913
+ data_type = info.get("type", "video")
914
+ entry = dict(info) # Copy original entry
915
+ # Standardize text field name and handle None/empty values
916
+ if "edit_instruction" in entry:
917
+ entry["text"] = entry["edit_instruction"]
918
+ elif "instruction" in entry:
919
+ entry["text"] = entry["instruction"]
920
+ elif "text" not in entry:
921
+ entry["text"] = ""
922
+
923
+ # Ensure text is not None (convert None to empty string)
924
+ if entry["text"] is None:
925
+ entry["text"] = ""
926
+
927
+ # Add file_path for bucket sampler compatibility
928
+ # Bucket sampler expects 'file_path' to get dimensions
929
+ if data_type == "video":
930
+ entry["file_path"] = entry.get("original_video", "")
931
+ else: # image
932
+ entry["file_path"] = entry.get("original_image", "")
933
+
934
+ new_dataset.append(entry)
935
+ dataset = new_dataset
936
+
937
+ self.data_root = data_root
938
+ self.dataset = dataset
939
+ self.length = len(self.dataset)
940
+
941
+ # sampling params
942
+ self.video_sample_stride = video_sample_stride
943
+ self.source_frames = source_frames
944
+ self.target_frames = target_frames
945
+ self.video_length_drop_start = video_length_drop_start
946
+ self.video_length_drop_end = video_length_drop_end
947
+
948
+ # transforms params (match ImageVideoDataset)
949
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
950
+ self.video_transforms = transforms.Compose(
951
+ [
952
+ transforms.Resize(min(self.video_sample_size)),
953
+ transforms.CenterCrop(self.video_sample_size),
954
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
955
+ ]
956
+ )
957
+
958
+ # Image transforms for non-bucket mode
959
+ self.image_transforms = transforms.Compose([
960
+ transforms.Resize(min(self.video_sample_size)),
961
+ transforms.CenterCrop(self.video_sample_size),
962
+ transforms.ToTensor(),
963
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
964
+ ])
965
+
966
+ self.instruction_template = instruction_template
967
+ self.enable_bucket = enable_bucket
968
+ self.text_drop_ratio = text_drop_ratio
969
+ self.enable_inpaint = enable_inpaint
970
+
971
+ # For pre-resize like ImageVideoDataset
972
+ self.larger_side_of_image_and_video = min(self.video_sample_size)
973
+
974
+ def _resize_and_center_crop_batch(self, frames_np, target_h, target_w):
975
+ resized = []
976
+ for i in range(frames_np.shape[0]):
977
+ frame = frames_np[i]
978
+ h, w = frame.shape[0], frame.shape[1]
979
+ scale = max(target_h / h, target_w / w)
980
+ new_h = int(round(h * scale))
981
+ new_w = int(round(w * scale))
982
+ frame_resized = cv2.resize(frame, (new_w, new_h))
983
+ y0 = max((new_h - target_h) // 2, 0)
984
+ x0 = max((new_w - target_w) // 2, 0)
985
+ frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
986
+ resized.append(frame_cropped)
987
+ return np.stack(resized, axis=0)
988
+
989
+ def _resize_and_center_crop_image(self, image_np, target_h, target_w):
990
+ h, w = image_np.shape[0], image_np.shape[1]
991
+ scale = max(target_h / h, target_w / w)
992
+ new_h = int(round(h * scale))
993
+ new_w = int(round(w * scale))
994
+ image_resized = cv2.resize(image_np, (new_w, new_h))
995
+ y0 = max((new_h - target_h) // 2, 0)
996
+ x0 = max((new_w - target_w) // 2, 0)
997
+ image_cropped = image_resized[y0:y0 + target_h, x0:x0 + target_w]
998
+ return image_cropped
999
+
1000
+ def get_batch(self, idx):
1001
+ data_info = self.dataset[idx % len(self.dataset)]
1002
+
1003
+ data_type = data_info.get('type', 'video')
1004
+
1005
+ # Handle None or empty instruction with safety fallback
1006
+ raw_text = data_info.get('text', '')
1007
+ if raw_text is None or (isinstance(raw_text, str) and not raw_text.strip()):
1008
+ # Use a generic fallback description if instruction is missing
1009
+ raw_text = "the content has been modified"
1010
+
1011
+ # Apply instruction template if available
1012
+ if self.instruction_template and "{edit_instruction}" in self.instruction_template:
1013
+ text = self.instruction_template.format(edit_instruction=raw_text)
1014
+ else:
1015
+ text = raw_text
1016
+
1017
+ if data_type == 'video':
1018
+ # video pair branch (default)
1019
+ src_rel, tgt_rel = data_info['original_video'], data_info['edited_video']
1020
+
1021
+ if self.data_root is not None:
1022
+ src_path = os.path.join(self.data_root, src_rel)
1023
+ tgt_path = os.path.join(self.data_root, tgt_rel)
1024
+ else:
1025
+ src_path = src_rel
1026
+ tgt_path = tgt_rel
1027
+
1028
+ # Force use CPU decoder to read all frames instead of just keyframes
1029
+ from decord import cpu
1030
+ with VideoReader_contextmanager(src_path, num_threads=2, ctx=cpu(0)) as src_reader, \
1031
+ VideoReader_contextmanager(tgt_path, num_threads=2, ctx=cpu(0)) as tgt_reader:
1032
+
1033
+ # Get video lengths
1034
+ src_length = len(src_reader)
1035
+ tgt_length = len(tgt_reader)
1036
+
1037
+ # Check if video has enough frames
1038
+ if src_length < self.source_frames:
1039
+ raise ValueError(f"Source video only has {src_length} frames, but requested {self.source_frames}")
1040
+ if tgt_length < self.target_frames:
1041
+ raise ValueError(f"Target video only has {tgt_length} frames, but requested {self.target_frames}")
1042
+
1043
+ # Unified sampling strategy: start from beginning (same as VideoEditDataset)
1044
+ start_idx = 0
1045
+
1046
+ src_indices = np.linspace(
1047
+ start_idx,
1048
+ min(start_idx + (self.source_frames - 1) * self.video_sample_stride, src_length - 1),
1049
+ self.source_frames,
1050
+ dtype=int
1051
+ )
1052
+
1053
+ tgt_indices = np.linspace(
1054
+ start_idx,
1055
+ min(start_idx + (self.target_frames - 1) * self.video_sample_stride, tgt_length - 1),
1056
+ self.target_frames,
1057
+ dtype=int
1058
+ )
1059
+
1060
+ # read batches with timeout
1061
+ try:
1062
+ src_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(src_reader, src_indices))
1063
+ tgt_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(tgt_reader, tgt_indices))
1064
+ except FunctionTimedOut:
1065
+ raise ValueError(f"Read {idx} timeout.")
1066
+ except Exception as e:
1067
+ raise ValueError(f"Failed to extract frames from pair. Error is {e}.")
1068
+
1069
+ # align HxW between source and target to enable concat
1070
+ sh, sw = src_frames.shape[1], src_frames.shape[2]
1071
+ th, tw = tgt_frames.shape[1], tgt_frames.shape[2]
1072
+ target_h = min(sh, th)
1073
+ target_w = min(sw, tw)
1074
+ if (sh != target_h or sw != target_w):
1075
+ src_frames = self._resize_and_center_crop_batch(src_frames, target_h, target_w)
1076
+ if (th != target_h or tw != target_w):
1077
+ tgt_frames = self._resize_and_center_crop_batch(tgt_frames, target_h, target_w)
1078
+
1079
+ if not self.enable_bucket:
1080
+ src_tensor = torch.from_numpy(src_frames).permute(0, 3, 1, 2).contiguous() / 255.
1081
+ tgt_tensor = torch.from_numpy(tgt_frames).permute(0, 3, 1, 2).contiguous() / 255.
1082
+
1083
+ src_tensor = self.video_transforms(src_tensor)
1084
+ tgt_tensor = self.video_transforms(tgt_tensor)
1085
+ else:
1086
+ src_tensor = src_frames
1087
+ tgt_tensor = tgt_frames
1088
+
1089
+ # Random text drop
1090
+ if random.random() < self.text_drop_ratio:
1091
+ text = ''
1092
+
1093
+ return src_tensor, tgt_tensor, text, 'video'
1094
+ else:
1095
+ # image pair branch (simple like ImageVideoDataset image path)
1096
+ src_img_rel = data_info.get('original_image')
1097
+ tgt_img_rel = data_info.get('edited_image')
1098
+ if src_img_rel is None or tgt_img_rel is None:
1099
+ raise ValueError('Missing original_image/edited_image for image sample')
1100
+
1101
+ if self.data_root is not None:
1102
+ src_img_path = os.path.join(self.data_root, src_img_rel)
1103
+ tgt_img_path = os.path.join(self.data_root, tgt_img_rel)
1104
+ else:
1105
+ src_img_path = src_img_rel
1106
+ tgt_img_path = tgt_img_rel
1107
+
1108
+ src_img = Image.open(src_img_path).convert('RGB')
1109
+ tgt_img = Image.open(tgt_img_path).convert('RGB')
1110
+
1111
+ if not self.enable_bucket:
1112
+ # Apply transforms and add frame dimension
1113
+ src_tensor = self.image_transforms(src_img).unsqueeze(0) # (1, C, H, W)
1114
+ tgt_tensor = self.image_transforms(tgt_img).unsqueeze(0) # (1, C, H, W)
1115
+ else:
1116
+ # For bucket mode, keep as numpy and add frame dimension
1117
+ src_tensor = np.expand_dims(np.array(src_img), axis=0) # (1, H, W, C)
1118
+ tgt_tensor = np.expand_dims(np.array(tgt_img), axis=0) # (1, H, W, C)
1119
+
1120
+ if random.random() < self.text_drop_ratio:
1121
+ text = ''
1122
+
1123
+ return src_tensor, tgt_tensor, text, 'image'
1124
+
1125
+ def __len__(self):
1126
+ return self.length
1127
+
1128
+ def __getitem__(self, idx):
1129
+ data_info = self.dataset[idx % len(self.dataset)]
1130
+ data_type = data_info.get('type', 'video')
1131
+ while True:
1132
+ sample = {}
1133
+ try:
1134
+ data_info_local = self.dataset[idx % len(self.dataset)]
1135
+ data_type_local = data_info_local.get('type', 'video')
1136
+ if data_type_local != data_type:
1137
+ raise ValueError("data_type_local != data_type")
1138
+
1139
+ src_vals, tgt_vals, name, data_type = self.get_batch(idx)
1140
+ if data_type == 'video':
1141
+ sample["pixel_values_src_video"] = src_vals
1142
+ sample["pixel_values_tgt_video"] = tgt_vals
1143
+ else:
1144
+ sample["pixel_values_src_image"] = src_vals
1145
+ sample["pixel_values_tgt_image"] = tgt_vals
1146
+ sample["text"] = name
1147
+ sample["data_type"] = data_type
1148
+ sample["idx"] = idx
1149
+
1150
+ if len(sample) > 0:
1151
+ break
1152
+ except Exception as e:
1153
+ print(e, self.dataset[idx % len(self.dataset)])
1154
+ idx = random.randint(0, self.length-1)
1155
+
1156
+ # Inpaint not applied here to avoid ambiguity across src/tgt branches
1157
+
1158
+ return sample
1159
+
1160
+
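# A hypothetical annotation entry consumed by ImageVideoEditDataset (a dict keyed by id); the
# field names mirror the keys read in __init__/get_batch above, the paths and text are made up:
example_annotation = {
    "clip_0001": {
        "type": "video",
        "original_video": "clips/clip_0001_src.mp4",
        "edited_video": "clips/clip_0001_edit.mp4",
        "edit_instruction": "remove the red car from the street",
    }
}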
1161
+ class ImageVideoCoTDataset(Dataset):
1162
+ """
1163
+ Dataset for Chain-of-Thought (CoT) style image/video editing.
1164
+ - For videos: loads original_video, grounded_video, and edited_video (3-part)
1165
+ - For images: loads original_image and edited_image (2-part, same as ImageVideoEditDataset)
1166
+ """
1167
+ def __init__(
1168
+ self,
1169
+ ann_path,
1170
+ data_root=None,
1171
+ video_sample_size=512,
1172
+ video_sample_stride=1,
1173
+ source_frames=33,
1174
+ reasoning_frames=4,
1175
+ target_frames=33,
1176
+ text_drop_ratio=0.1,
1177
+ enable_bucket=False,
1178
+ enable_inpaint=False,
1179
+ video_length_drop_start=0.0,
1180
+ video_length_drop_end=1.0,
1181
+ instruction_template="A video sequence showing three parts: first the original scene, then grounded {ground_instruction}, and finally the same scene but {edit_instruction}",
1182
+ enable_gradual_ground=False,
1183
+ enable_gray_red_mask=False,
1184
+ enable_gray_black_background=False,
1185
+ enable_gray_alpha_overlay=False,
1186
+ gray_alpha=0.5,
1187
+ gray_intensity_range=(96, 160),
1188
+ gray_tolerance=12,
1189
+ ):
1190
+ dataset = json.load(open(ann_path))
1191
+ if isinstance(dataset, dict):
1192
+ new_dataset = []
1193
+ for _, info in dataset.items():
1194
+ data_type = info.get("type", "video")
1195
+ entry = dict(info) # Copy original entry
1196
+
1197
+ # Standardize text field name and handle None/empty values
1198
+ if "edit_instruction" in entry:
1199
+ entry["text"] = entry["edit_instruction"]
1200
+ elif "instruction" in entry:
1201
+ entry["text"] = entry["instruction"]
1202
+ elif "text" not in entry:
1203
+ entry["text"] = ""
1204
+
1205
+ # Ensure text is not None
1206
+ if entry["text"] is None:
1207
+ entry["text"] = ""
1208
+
1209
+ # Add file_path for bucket sampler compatibility
1210
+ if data_type == "video":
1211
+ entry["file_path"] = entry.get("original_video", "")
1212
+ else: # image
1213
+ entry["file_path"] = entry.get("original_image", "")
1214
+
1215
+ new_dataset.append(entry)
1216
+ dataset = new_dataset
1217
+
1218
+ self.data_root = data_root
1219
+ self.dataset = dataset
1220
+ self.length = len(self.dataset)
1221
+
1222
+ # sampling params
1223
+ self.video_sample_stride = video_sample_stride
1224
+ self.source_frames = source_frames
1225
+ self.reasoning_frames = reasoning_frames
1226
+ self.target_frames = target_frames
1227
+ self.video_length_drop_start = video_length_drop_start
1228
+ self.video_length_drop_end = video_length_drop_end
1229
+
1230
+ # transforms params
1231
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
1232
+ self.video_transforms = transforms.Compose(
1233
+ [
1234
+ transforms.Resize(min(self.video_sample_size)),
1235
+ transforms.CenterCrop(self.video_sample_size),
1236
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
1237
+ ]
1238
+ )
1239
+
1240
+ # Image transforms for non-bucket mode
1241
+ self.image_transforms = transforms.Compose([
1242
+ transforms.Resize(min(self.video_sample_size)),
1243
+ transforms.CenterCrop(self.video_sample_size),
1244
+ transforms.ToTensor(),
1245
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
1246
+ ])
1247
+
1248
+ self.instruction_template = instruction_template
1249
+ self.enable_bucket = enable_bucket
1250
+ self.text_drop_ratio = text_drop_ratio
1251
+ self.enable_inpaint = enable_inpaint
1252
+ self.enable_gradual_ground = enable_gradual_ground
1253
+ # only one visualization mode at a time
1254
+ enabled_modes = int(bool(enable_gray_red_mask)) + int(bool(enable_gray_black_background)) + int(bool(enable_gray_alpha_overlay))
1255
+ if enabled_modes > 1:
1256
+ raise ValueError("enable_gray_red_mask, enable_gray_black_background and enable_gray_alpha_overlay cannot be enabled simultaneously.")
1257
+ self.enable_gray_red_mask = enable_gray_red_mask
1258
+ self.enable_gray_black_background = enable_gray_black_background
1259
+ self.enable_gray_alpha_overlay = enable_gray_alpha_overlay
1260
+ self.gray_alpha = float(gray_alpha)
1261
+ if not (0.0 <= self.gray_alpha <= 1.0):
1262
+ raise ValueError("gray_alpha must be in [0,1].")
1263
+ if not isinstance(gray_intensity_range, (list, tuple)) or len(gray_intensity_range) != 2:
1264
+ raise ValueError("gray_intensity_range must contain exactly two values (min and max intensity).")
1265
+ self.gray_intensity_range = (int(gray_intensity_range[0]), int(gray_intensity_range[1]))
1266
+ if self.gray_intensity_range[0] > self.gray_intensity_range[1]:
1267
+ raise ValueError("gray_intensity_range min value cannot be greater than max value.")
1268
+ self.gray_tolerance = int(gray_tolerance)
1269
+
1270
+ # For pre-resize like ImageVideoDataset
1271
+ self.larger_side_of_image_and_video = min(self.video_sample_size)
1272
+
1273
+ def _resize_and_center_crop_batch(self, frames_np, target_h, target_w):
1274
+ resized = []
1275
+ for i in range(frames_np.shape[0]):
1276
+ frame = frames_np[i]
1277
+ h, w = frame.shape[0], frame.shape[1]
1278
+ scale = max(target_h / h, target_w / w)
1279
+ new_h = int(round(h * scale))
1280
+ new_w = int(round(w * scale))
1281
+ frame_resized = cv2.resize(frame, (new_w, new_h))
1282
+ y0 = max((new_h - target_h) // 2, 0)
1283
+ x0 = max((new_w - target_w) // 2, 0)
1284
+ frame_cropped = frame_resized[y0:y0 + target_h, x0:x0 + target_w]
1285
+ resized.append(frame_cropped)
1286
+ return np.stack(resized, axis=0)
1287
+
1288
+ def _resize_and_center_crop_image(self, image_np, target_h, target_w):
1289
+ h, w = image_np.shape[0], image_np.shape[1]
1290
+ scale = max(target_h / h, target_w / w)
1291
+ new_h = int(round(h * scale))
1292
+ new_w = int(round(w * scale))
1293
+ image_resized = cv2.resize(image_np, (new_w, new_h))
1294
+ y0 = max((new_h - target_h) // 2, 0)
1295
+ x0 = max((new_w - target_w) // 2, 0)
1296
+ image_cropped = image_resized[y0:y0 + target_h, x0:x0 + target_w]
1297
+ return image_cropped
1298
+
1299
+ def _derive_ground_instruction(self, edit_instruction_text: str) -> str:
1300
+ """Derive grounded object phrase from instruction using shared rules."""
1301
+ return derive_ground_object_from_instruction(edit_instruction_text)
1302
+
1303
+ def _ensure_same_size_pair(self, img_a: np.ndarray, img_b: np.ndarray) -> tuple:
1304
+ """Resize img_b to img_a's size if needed to enable per-pixel interpolation."""
1305
+ ha, wa = img_a.shape[:2]
1306
+ hb, wb = img_b.shape[:2]
1307
+ if (ha, wa) == (hb, wb):
1308
+ return img_a, img_b
1309
+ resized_b = cv2.resize(img_b, (wa, ha), interpolation=cv2.INTER_LINEAR)
1310
+ return img_a, resized_b
1311
+
1312
+ def _interpolate_ground_frames(self, ground_first: np.ndarray, target_first: np.ndarray,
1313
+ total_steps: int = 16,
1314
+ pick_indices: tuple = (0, 4, 8, 12)) -> np.ndarray:
1315
+ """
1316
+ Create grounding frames by linearly interpolating between the first frame of
1317
+ the grounding video and the first frame of the edited video, then picking
1318
+ specific indices.
1319
+ Returns array of shape (len(pick_indices), H, W, 3) in uint8.
1320
+ """
1321
+ a_np, b_np = self._ensure_same_size_pair(ground_first, target_first)
1322
+
1323
+ a_t = torch.from_numpy(a_np).float() / 255.0 # H, W, C
1324
+ b_t = torch.from_numpy(b_np).float() / 255.0 # H, W, C
1325
+
1326
+ a_t = a_t.permute(2, 0, 1).contiguous() # C, H, W
1327
+ b_t = b_t.permute(2, 0, 1).contiguous() # C, H, W
1328
+
1329
+ c, h, w = a_t.shape
1330
+ pair = torch.stack([a_t, b_t], dim=0) # 2, C, H, W
1331
+ pair_chw_t = pair.permute(1, 2, 3, 0).contiguous() # C, H, W, 2
1332
+ seq = pair_chw_t.view(1, c * h * w, 2) # 1, (C*H*W), 2
1333
+ with torch.no_grad():
1334
+ seq_interp = F.interpolate(seq, size=int(total_steps), mode="linear", align_corners=True)
1335
+ seq_interp = seq_interp.view(c, h, w, int(total_steps)).permute(3, 0, 1, 2).contiguous() # T, C, H, W
1336
+
1337
+ out_frames = []
1338
+ t_steps = int(total_steps)
1339
+ for idx in pick_indices:
1340
+ safe_idx = max(0, min(int(idx), t_steps - 1))
1341
+ img = (seq_interp[safe_idx].clamp(0.0, 1.0) * 255.0).byte().permute(1, 2, 0).cpu().numpy()
1342
+ out_frames.append(img)
1343
+ return np.stack(out_frames, axis=0)
1344
+
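# The interpolation above is plain linear blending; a standalone sketch of the blend weights it
# produces (align_corners=True maps output index k to weight k / (total_steps - 1)):
a = torch.zeros(3, 8, 8)                      # stand-in for the grounded first frame
b = torch.ones(3, 8, 8)                       # stand-in for the edited first frame
pair = torch.stack([a, b]).permute(1, 2, 3, 0).reshape(1, -1, 2)
interp = F.interpolate(pair, size=16, mode="linear", align_corners=True)
print(interp[0, 0, [0, 4, 8, 12]])            # tensor([0.0000, 0.2667, 0.5333, 0.8000])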
1345
+ def _build_gray_mask(self, frame: np.ndarray) -> np.ndarray:
1346
+ """Detect gray regions in a frame using intensity range and tolerance."""
1347
+ frame_float = frame.astype(np.float32)
1348
+ if frame_float.max() <= 1.0:
1349
+ frame_float = frame_float * 255.0
1350
+ channel_max = frame_float.max(axis=2)
1351
+ channel_min = frame_float.min(axis=2)
1352
+ min_intensity, max_intensity = self.gray_intensity_range
1353
+ tone_flatness = channel_max - channel_min
1354
+ mask = tone_flatness <= float(self.gray_tolerance)
1355
+ mask &= channel_max >= float(min_intensity)
1356
+ mask &= channel_max <= float(max_intensity)
1357
+ return mask
1358
+
1359
+ def _apply_gray_region_effect(self, frames_np: np.ndarray, mode: str) -> np.ndarray:
1360
+ """Apply requested effect on detected gray regions for a batch of frames."""
1361
+ processed_frames = []
1362
+ for frame in frames_np:
1363
+ mask = self._build_gray_mask(frame)
1364
+ if not np.any(mask):
1365
+ processed_frames.append(frame)
1366
+ continue
1367
+ frame_out = frame.copy()
1368
+ if np.issubdtype(frame_out.dtype, np.floating) and frame_out.max() <= 1.0:
1369
+ red_value = np.array([1.0, 0.0, 0.0], dtype=frame_out.dtype)
1370
+ else:
1371
+ red_value = np.array([255, 0, 0], dtype=frame_out.dtype)
1372
+ if mode == "red":
1373
+ frame_out[mask] = red_value
1374
+ else:
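+ # "black" mode: blank the whole frame, then restore only the detected gray regions.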
1375
+ frame_out[:] = 0
1376
+ frame_out[mask] = frame[mask]
1377
+ processed_frames.append(frame_out)
1378
+ return np.stack(processed_frames, axis=0)
1379
+
1380
+ def _apply_gray_overlay_from_reference(self, src_frames_np: np.ndarray, ref_frames_np: np.ndarray,
1381
+ alpha: float = 0.5, gray_value: float = 0.5, num_frames: int = 4) -> np.ndarray:
1382
+ """
1383
+ Detect gray regions on ref frames, and overlay gray with alpha onto the
1384
+ first `num_frames` frames of src frames at the same positions.
1385
+ """
1386
+ n = min(int(num_frames), int(src_frames_np.shape[0]), int(ref_frames_np.shape[0]))
1387
+ if n <= 0:
1388
+ return src_frames_np
1389
+ out = src_frames_np.copy()
1390
+ a = float(alpha)
1391
+ a = 0.0 if a < 0.0 else (1.0 if a > 1.0 else a)
1392
+ gv = float(gray_value)
1393
+ gv = 0.0 if gv < 0.0 else (1.0 if gv > 1.0 else gv)
1394
+ for i in range(n):
1395
+ mask = self._build_gray_mask(ref_frames_np[i])
1396
+ if not np.any(mask):
1397
+ continue
1398
+ src = out[i]
1399
+ # normalize to 0..1 float
1400
+ if np.issubdtype(src.dtype, np.floating):
1401
+ f = src.astype(np.float32)
1402
+ if f.max() > 1.0:
1403
+ f = np.clip(f / 255.0, 0.0, 1.0)
1404
+ back_to_uint8 = False
1405
+ else:
1406
+ f = src.astype(np.float32) / 255.0
1407
+ back_to_uint8 = True
1408
+ gray_color = np.array([gv, gv, gv], dtype=np.float32)
1409
+ # boolean mask is (H,W); f[mask] -> (K,3), broadcast with gray_color (3,)
1410
+ f[mask] = (1.0 - a) * f[mask] + a * gray_color
1411
+ if back_to_uint8:
1412
+ out[i] = (f * 255.0).clip(0, 255).astype(src.dtype)
1413
+ else:
1414
+ out[i] = f.astype(src.dtype)
1415
+ return out
1416
+
1417
+ def get_batch(self, idx):
1418
+ data_info = self.dataset[idx % len(self.dataset)]
1419
+ data_type = data_info.get('type', 'video')
1420
+
1421
+ # Handle None or empty instruction with safety fallback
1422
+ raw_text = data_info.get('text', '')
1423
+ if raw_text is None or (isinstance(raw_text, str) and not raw_text.strip()):
1424
+ raw_text = "the content has been modified"
1425
+
1426
+ if data_type == 'video':
1427
+ # Video triplet branch: original + grounded + edited
1428
+ src_rel = data_info['original_video']
1429
+ # Support both 'grounded_video' and 'ground_video' keys
1430
+ ground_rel = data_info.get('grounded_video', data_info.get('ground_video'))
1431
+ tgt_rel = data_info['edited_video']
1432
+
1433
+ if self.data_root is not None:
1434
+ src_path = os.path.join(self.data_root, src_rel)
1435
+ ground_path = os.path.join(self.data_root, ground_rel)
1436
+ tgt_path = os.path.join(self.data_root, tgt_rel)
1437
+ else:
1438
+ src_path = src_rel
1439
+ ground_path = ground_rel
1440
+ tgt_path = tgt_rel
1441
+
1442
+ # Force the CPU decoder so that all frames can be read
1443
+ from decord import cpu
1444
+ with VideoReader_contextmanager(src_path, num_threads=2, ctx=cpu(0)) as src_reader, \
1445
+ VideoReader_contextmanager(ground_path, num_threads=2, ctx=cpu(0)) as ground_reader, \
1446
+ VideoReader_contextmanager(tgt_path, num_threads=2, ctx=cpu(0)) as tgt_reader:
1447
+
1448
+ # Get video lengths
1449
+ src_length = len(src_reader)
1450
+ ground_length = len(ground_reader)
1451
+ tgt_length = len(tgt_reader)
1452
+
1453
+ # Check if video has enough frames
1454
+ if src_length < self.source_frames:
1455
+ raise ValueError(f"Source video only has {src_length} frames, but requested {self.source_frames}")
1456
+ if tgt_length < self.target_frames:
1457
+ raise ValueError(f"Target video only has {tgt_length} frames, but requested {self.target_frames}")
1458
+
1459
+ # Unified sampling strategy: start from beginning
1460
+ start_idx = 0
1461
+
1462
+ # Sample source frames
1463
+ src_indices = np.linspace(
1464
+ start_idx,
1465
+ min(start_idx + (self.source_frames - 1) * self.video_sample_stride, src_length - 1),
1466
+ self.source_frames,
1467
+ dtype=int
1468
+ )
1469
+
1470
+ # Sample target frames
1471
+ tgt_indices = np.linspace(
1472
+ start_idx,
1473
+ min(start_idx + (self.target_frames - 1) * self.video_sample_stride, tgt_length - 1),
1474
+ self.target_frames,
1475
+ dtype=int
1476
+ )
1477
+
1478
+ # Read batches with timeout
1479
+ try:
1480
+ src_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(src_reader, src_indices))
1481
+ tgt_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(tgt_reader, tgt_indices))
1482
+
1483
+ if self.enable_gradual_ground:
1484
+ # Interpolate between first frame of grounded and edited videos
1485
+ ground_first = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(ground_reader, [0]))
1486
+ # Use the first decoded edited frame if available to avoid double decode
1487
+ tgt_first_frame = tgt_frames[0]
1488
+ # steps: 0..15, pick 0,3,6,9,12 -> 5 grounding frames
1489
+ ground_frames = self._interpolate_ground_frames(
1490
+ ground_first=ground_first[0],
1491
+ target_first=tgt_first_frame,
1492
+ total_steps=16,
1493
+ pick_indices=(0, 3, 6, 9, 12),
1494
+ )
1495
+ else:
1496
+ # # Original behavior: sample grounding frames evenly by stride
1497
+ # ground_indices = np.linspace(
1498
+ # start_idx,
1499
+ # min(start_idx + (self.reasoning_frames - 1) * self.video_sample_stride, ground_length - 1),
1500
+ # self.reasoning_frames,
1501
+ # dtype=int
1502
+ # )
1503
+
1504
+ #==============================================================
1505
+ # New behavior: ground_indices are the first 'reasoning_frames' from src_indices
1506
+ ground_indices = src_indices[:self.reasoning_frames]
1507
+
1508
+ # --- Important safety check ---
1509
+ # Make sure the last frame we want to sample (ground_indices[-1])
1510
+ # does not exceed the total length of the ground video (ground_length)
1511
+ if len(ground_indices) > 0 and ground_indices[-1] >= ground_length:
1512
+ raise ValueError(
1513
+ f"Data inconsistency error: Ground video has only {ground_length} frames, "
1514
+ f"but the source-based sampling (stride={self.video_sample_stride}) "
1515
+ f"requires reading up to frame {ground_indices[-1]}. "
1516
+ f"File: {ground_path}"
1517
+ )
1518
+ ground_frames = func_timeout(VIDEO_READER_TIMEOUT, get_video_reader_batch, args=(ground_reader, ground_indices))
1519
+ except FunctionTimedOut:
1520
+ raise ValueError(f"Read {idx} timeout.")
1521
+ except Exception as e:
1522
+ raise ValueError(f"Failed to extract frames from triplet. Error is {e}.")
1523
+
1524
+ # Align HxW among source, ground, and target to enable concat
1525
+ sh, sw = src_frames.shape[1], src_frames.shape[2]
1526
+ gh, gw = ground_frames.shape[1], ground_frames.shape[2]
1527
+ th, tw = tgt_frames.shape[1], tgt_frames.shape[2]
1528
+ target_h = min(sh, gh, th)
1529
+ target_w = min(sw, gw, tw)
1530
+
1531
+ if (sh != target_h or sw != target_w):
1532
+ src_frames = self._resize_and_center_crop_batch(src_frames, target_h, target_w)
1533
+ if (gh != target_h or gw != target_w):
1534
+ ground_frames = self._resize_and_center_crop_batch(ground_frames, target_h, target_w)
1535
+ if (th != target_h or tw != target_w):
1536
+ tgt_frames = self._resize_and_center_crop_batch(tgt_frames, target_h, target_w)
1537
+
1538
+ if self.enable_gray_red_mask or self.enable_gray_black_background:
1539
+ effect_mode = "red" if self.enable_gray_red_mask else "black"
1540
+ ground_frames = self._apply_gray_region_effect(ground_frames, effect_mode)
1541
+ elif self.enable_gray_alpha_overlay:
1542
+ # Use gray regions detected on grounding frames to overlay 50% gray on the
1543
+ # first 4 frames of the original video.
1544
+ ground_frames = self._apply_gray_overlay_from_reference(
1545
+ src_frames, ground_frames, alpha=self.gray_alpha, gray_value=0.5, num_frames=4
1546
+ )
1547
+
1548
+ if not self.enable_bucket:
1549
+ src_tensor = torch.from_numpy(src_frames).permute(0, 3, 1, 2).contiguous() / 255.
1550
+ ground_tensor = torch.from_numpy(ground_frames).permute(0, 3, 1, 2).contiguous() / 255.
1551
+ tgt_tensor = torch.from_numpy(tgt_frames).permute(0, 3, 1, 2).contiguous() / 255.
1552
+
1553
+ src_tensor = self.video_transforms(src_tensor)
1554
+ ground_tensor = self.video_transforms(ground_tensor)
1555
+ tgt_tensor = self.video_transforms(tgt_tensor)
1556
+ else:
1557
+ src_tensor = src_frames
1558
+ ground_tensor = ground_frames
1559
+ tgt_tensor = tgt_frames
1560
+ # Prepare text with template
1561
+ ground_instr = self._derive_ground_instruction(raw_text)
1562
+ if self.instruction_template and "{edit_instruction}" in self.instruction_template:
1563
+ text = self.instruction_template.format(
1564
+ edit_instruction=raw_text,
1565
+ ground_instruction=ground_instr
1566
+ )
1567
+ else:
1568
+ text = raw_text
1569
+
1570
+ # Random text drop
1571
+ if random.random() < self.text_drop_ratio:
1572
+ text = ''
1573
+
1574
+ return src_tensor, ground_tensor, tgt_tensor, text, 'video'
1575
+
1576
+ else:
1577
+ # Image pair branch (simple like ImageVideoEditDataset)
1578
+ src_img_rel = data_info.get('original_image')
1579
+ tgt_img_rel = data_info.get('edited_image')
1580
+ if src_img_rel is None or tgt_img_rel is None:
1581
+ raise ValueError('Missing original_image/edited_image for image sample')
1582
+
1583
+ if self.data_root is not None:
1584
+ src_img_path = os.path.join(self.data_root, src_img_rel)
1585
+ tgt_img_path = os.path.join(self.data_root, tgt_img_rel)
1586
+ else:
1587
+ src_img_path = src_img_rel
1588
+ tgt_img_path = tgt_img_rel
1589
+
1590
+ src_img = Image.open(src_img_path).convert('RGB')
1591
+ tgt_img = Image.open(tgt_img_path).convert('RGB')
1592
+
1593
+ if not self.enable_bucket:
1594
+ # Apply transforms and add frame dimension
1595
+ src_tensor = self.image_transforms(src_img).unsqueeze(0) # (1, C, H, W)
1596
+ tgt_tensor = self.image_transforms(tgt_img).unsqueeze(0) # (1, C, H, W)
1597
+ else:
1598
+ # For bucket mode, keep as numpy and add frame dimension
1599
+ src_tensor = np.expand_dims(np.array(src_img), axis=0) # (1, H, W, C)
1600
+ tgt_tensor = np.expand_dims(np.array(tgt_img), axis=0) # (1, H, W, C)
1601
+
1602
+ # Apply instruction template if available
1603
+ if self.instruction_template and "{edit_instruction}" in self.instruction_template:
1604
+ text = self.instruction_template.format(edit_instruction=raw_text, ground_instruction="")
1605
+ else:
1606
+ text = raw_text
1607
+
1608
+ if random.random() < self.text_drop_ratio:
1609
+ text = ''
1610
+
1611
+ # For images, ground_tensor is None
1612
+ return src_tensor, None, tgt_tensor, text, 'image'
1613
+
1614
+ def __len__(self):
1615
+ return self.length
1616
+
1617
+ def __getitem__(self, idx):
1618
+ data_info = self.dataset[idx % len(self.dataset)]
1619
+ data_type = data_info.get('type', 'video')
1620
+ while True:
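+ # Retry with a random index on failure so a single corrupted sample does not stall training.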
1621
+ sample = {}
1622
+ try:
1623
+ data_info_local = self.dataset[idx % len(self.dataset)]
1624
+ data_type_local = data_info_local.get('type', 'video')
1625
+ if data_type_local != data_type:
1626
+ raise ValueError("data_type_local != data_type")
1627
+
1628
+ result = self.get_batch(idx)
1629
+
1630
+ if data_type == 'video':
1631
+ src_vals, ground_vals, tgt_vals, name, data_type = result
1632
+ sample["pixel_values_src_video"] = src_vals
1633
+ sample["pixel_values_ground_video"] = ground_vals
1634
+ sample["pixel_values_tgt_video"] = tgt_vals
1635
+ else:
1636
+ src_vals, _, tgt_vals, name, data_type = result
1637
+ sample["pixel_values_src_image"] = src_vals
1638
+ sample["pixel_values_tgt_image"] = tgt_vals
1639
+
1640
+ sample["text"] = name
1641
+ sample["data_type"] = data_type
1642
+ sample["idx"] = idx
1643
+
1644
+ if len(sample) > 0:
1645
+ break
1646
+ except Exception as e:
1647
+ print(e, self.dataset[idx % len(self.dataset)])
1648
+ idx = random.randint(0, self.length-1)
1649
+
1650
+ return sample
1651
+
1652
+ def padding_image(images, new_width, new_height):
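+ # Letterbox the input image onto a white canvas of (new_width, new_height),
+ # preserving aspect ratio and centering the resized image.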
1653
+ new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
1654
+
1655
+ aspect_ratio = images.width / images.height
1656
+ if new_width / new_height > 1:
1657
+ if aspect_ratio > new_width / new_height:
1658
+ new_img_width = new_width
1659
+ new_img_height = int(new_img_width / aspect_ratio)
1660
+ else:
1661
+ new_img_height = new_height
1662
+ new_img_width = int(new_img_height * aspect_ratio)
1663
+ else:
1664
+ if aspect_ratio > new_width / new_height:
1665
+ new_img_width = new_width
1666
+ new_img_height = int(new_img_width / aspect_ratio)
1667
+ else:
1668
+ new_img_height = new_height
1669
+ new_img_width = int(new_img_height * aspect_ratio)
1670
+
1671
+ resized_img = images.resize((new_img_width, new_img_height))
1672
+
1673
+ paste_x = (new_width - new_img_width) // 2
1674
+ paste_y = (new_height - new_img_height) // 2
1675
+
1676
+ new_image.paste(resized_img, (paste_x, paste_y))
1677
+
1678
+ return new_image
1679
+
1680
+ class ImageVideoControlDataset(Dataset):
1681
+ def __init__(
1682
+ self,
1683
+ ann_path, data_root=None,
1684
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
1685
+ image_sample_size=512,
1686
+ video_repeat=0,
1687
+ text_drop_ratio=0.1,
1688
+ enable_bucket=False,
1689
+ video_length_drop_start=0.1,
1690
+ video_length_drop_end=0.9,
1691
+ enable_inpaint=False,
1692
+ enable_camera_info=False,
1693
+ ):
1694
+ # Loading annotations from files
1695
+ if ann_path.endswith('.csv'):
1696
+ with open(ann_path, 'r') as csvfile:
1697
+ dataset = list(csv.DictReader(csvfile))
1698
+ elif ann_path.endswith('.json'):
1699
+ dataset = json.load(open(ann_path))
1700
+
1701
+ self.data_root = data_root
1702
+
1703
+ # It's used to balance num of images and videos.
1704
+ if video_repeat > 0:
1705
+ self.dataset = []
1706
+ for data in dataset:
1707
+ if data.get('type', 'image') != 'video':
1708
+ self.dataset.append(data)
1709
+
1710
+ for _ in range(video_repeat):
1711
+ for data in dataset:
1712
+ if data.get('type', 'image') == 'video':
1713
+ self.dataset.append(data)
1714
+ else:
1715
+ self.dataset = dataset
1716
+ del dataset
1717
+
1718
+ self.length = len(self.dataset)
1719
+ print(f"data scale: {self.length}")
1720
+ # TODO: enable bucket training
1721
+ self.enable_bucket = enable_bucket
1722
+ self.text_drop_ratio = text_drop_ratio
1723
+ self.enable_inpaint = enable_inpaint
1724
+ self.enable_camera_info = enable_camera_info
1725
+
1726
+ self.video_length_drop_start = video_length_drop_start
1727
+ self.video_length_drop_end = video_length_drop_end
1728
+
1729
+ # Video params
1730
+ self.video_sample_stride = video_sample_stride
1731
+ self.video_sample_n_frames = video_sample_n_frames
1732
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
1733
+ self.video_transforms = transforms.Compose(
1734
+ [
1735
+ transforms.Resize(min(self.video_sample_size)),
1736
+ transforms.CenterCrop(self.video_sample_size),
1737
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
1738
+ ]
1739
+ )
1740
+ if self.enable_camera_info:
1741
+ self.video_transforms_camera = transforms.Compose(
1742
+ [
1743
+ transforms.Resize(min(self.video_sample_size)),
1744
+ transforms.CenterCrop(self.video_sample_size)
1745
+ ]
1746
+ )
1747
+
1748
+ # Image params
1749
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
1750
+ self.image_transforms = transforms.Compose([
1751
+ transforms.Resize(min(self.image_sample_size)),
1752
+ transforms.CenterCrop(self.image_sample_size),
1753
+ transforms.ToTensor(),
1754
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
1755
+ ])
1756
+
1757
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
1758
+
1759
+ def get_batch(self, idx):
1760
+ data_info = self.dataset[idx % len(self.dataset)]
1761
+ video_id, text = data_info['file_path'], data_info['text']
1762
+
1763
+ if data_info.get('type', 'image')=='video':
1764
+ if self.data_root is None:
1765
+ video_dir = video_id
1766
+ else:
1767
+ video_dir = os.path.join(self.data_root, video_id)
1768
+
1769
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
1770
+ min_sample_n_frames = min(
1771
+ self.video_sample_n_frames,
1772
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
1773
+ )
1774
+ if min_sample_n_frames == 0:
1775
+ raise ValueError(f"No Frames in video.")
1776
+
1777
+ video_length = int(self.video_length_drop_end * len(video_reader))
1778
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
1779
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
1780
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
1781
+
1782
+ try:
1783
+ sample_args = (video_reader, batch_index)
1784
+ pixel_values = func_timeout(
1785
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
1786
+ )
1787
+ resized_frames = []
1788
+ for i in range(len(pixel_values)):
1789
+ frame = pixel_values[i]
1790
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
1791
+ resized_frames.append(resized_frame)
1792
+ pixel_values = np.array(resized_frames)
1793
+ except FunctionTimedOut:
1794
+ raise ValueError(f"Read {idx} timeout.")
1795
+ except Exception as e:
1796
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
1797
+
1798
+ if not self.enable_bucket:
1799
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
1800
+ pixel_values = pixel_values / 255.
1801
+ del video_reader
1802
+ else:
1803
+ pixel_values = pixel_values
1804
+
1805
+ if not self.enable_bucket:
1806
+ pixel_values = self.video_transforms(pixel_values)
1807
+
1808
+ # Random use no text generation
1809
+ if random.random() < self.text_drop_ratio:
1810
+ text = ''
1811
+
1812
+ control_video_id = data_info['control_file_path']
1813
+
1814
+ if self.data_root is None:
1815
+ control_video_id = control_video_id
1816
+ else:
1817
+ control_video_id = os.path.join(self.data_root, control_video_id)
1818
+
1819
+ if self.enable_camera_info:
1820
+ if control_video_id.lower().endswith('.txt'):
1821
+ if not self.enable_bucket:
1822
+ control_pixel_values = torch.zeros_like(pixel_values)
1823
+
1824
+ control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
1825
+ control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
1826
+ control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)
1827
+ control_camera_values = self.video_transforms_camera(control_camera_values)
1828
+ else:
1829
+ control_pixel_values = np.zeros_like(pixel_values)
1830
+
1831
+ control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
1832
+ control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
1833
+ control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
1834
+ control_camera_values = np.array([control_camera_values[index] for index in batch_index])
1835
+ else:
1836
+ if not self.enable_bucket:
1837
+ control_pixel_values = torch.zeros_like(pixel_values)
1838
+ control_camera_values = None
1839
+ else:
1840
+ control_pixel_values = np.zeros_like(pixel_values)
1841
+ control_camera_values = None
1842
+ else:
1843
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
1844
+ try:
1845
+ sample_args = (control_video_reader, batch_index)
1846
+ control_pixel_values = func_timeout(
1847
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
1848
+ )
1849
+ resized_frames = []
1850
+ for i in range(len(control_pixel_values)):
1851
+ frame = control_pixel_values[i]
1852
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
1853
+ resized_frames.append(resized_frame)
1854
+ control_pixel_values = np.array(resized_frames)
1855
+ except FunctionTimedOut:
1856
+ raise ValueError(f"Read {idx} timeout.")
1857
+ except Exception as e:
1858
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
1859
+
1860
+ if not self.enable_bucket:
1861
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
1862
+ control_pixel_values = control_pixel_values / 255.
1863
+ del control_video_reader
1864
+ else:
1865
+ control_pixel_values = control_pixel_values
1866
+
1867
+ if not self.enable_bucket:
1868
+ control_pixel_values = self.video_transforms(control_pixel_values)
1869
+ control_camera_values = None
1870
+
1871
+ return pixel_values, control_pixel_values, control_camera_values, text, "video"
1872
+ else:
1873
+ image_path, text = data_info['file_path'], data_info['text']
1874
+ if self.data_root is not None:
1875
+ image_path = os.path.join(self.data_root, image_path)
1876
+ image = Image.open(image_path).convert('RGB')
1877
+ if not self.enable_bucket:
1878
+ image = self.image_transforms(image).unsqueeze(0)
1879
+ else:
1880
+ image = np.expand_dims(np.array(image), 0)
1881
+
1882
+ if random.random() < self.text_drop_ratio:
1883
+ text = ''
1884
+
1885
+ control_image_id = data_info['control_file_path']
1886
+
1887
+ if self.data_root is None:
1888
+ control_image_id = control_image_id
1889
+ else:
1890
+ control_image_id = os.path.join(self.data_root, control_image_id)
1891
+
1892
+ control_image = Image.open(control_image_id).convert('RGB')
1893
+ if not self.enable_bucket:
1894
+ control_image = self.image_transforms(control_image).unsqueeze(0)
1895
+ else:
1896
+ control_image = np.expand_dims(np.array(control_image), 0)
1897
+ return image, control_image, None, text, 'image'
1898
+ def __len__(self):
1899
+ return self.length
1900
+
1901
+ def __getitem__(self, idx):
1902
+ data_info = self.dataset[idx % len(self.dataset)]
1903
+ data_type = data_info.get('type', 'image')
1904
+ while True:
1905
+ sample = {}
1906
+ try:
1907
+ data_info_local = self.dataset[idx % len(self.dataset)]
1908
+ data_type_local = data_info_local.get('type', 'image')
1909
+ if data_type_local != data_type:
1910
+ raise ValueError("data_type_local != data_type")
1911
+
1912
+ pixel_values, control_pixel_values, control_camera_values, name, data_type = self.get_batch(idx)
1913
+
1914
+ sample["pixel_values"] = pixel_values
1915
+ sample["control_pixel_values"] = control_pixel_values
1916
+ sample["text"] = name
1917
+ sample["data_type"] = data_type
1918
+ sample["idx"] = idx
1919
+
1920
+ if self.enable_camera_info:
1921
+ sample["control_camera_values"] = control_camera_values
1922
+
1923
+ if len(sample) > 0:
1924
+ break
1925
+ except Exception as e:
1926
+ print(e, self.dataset[idx % len(self.dataset)])
1927
+ idx = random.randint(0, self.length-1)
1928
+
1929
+ if self.enable_inpaint and not self.enable_bucket:
1930
+ mask = get_random_mask(pixel_values.size())
1931
+ mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
1932
+ sample["mask_pixel_values"] = mask_pixel_values
1933
+ sample["mask"] = mask
1934
+
1935
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
1936
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
1937
+ sample["clip_pixel_values"] = clip_pixel_values
1938
+
1939
+ return sample
videox_fun/data/dataset_video.py ADDED
@@ -0,0 +1,262 @@
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from threading import Thread
10
+
11
+ import albumentations
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ from decord import VideoReader
17
+ from einops import rearrange
18
+ from func_timeout import FunctionTimedOut, func_timeout
19
+ from PIL import Image
20
+ from torch.utils.data import BatchSampler, Sampler
21
+ from torch.utils.data.dataset import Dataset
22
+
23
+ VIDEO_READER_TIMEOUT = 20
24
+
25
+ def get_random_mask(shape):
26
+ f, c, h, w = shape
27
+
28
+ mask_index = np.random.randint(0, 4)
29
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
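+ # mask_index 0: mask every frame except the first; 1: mask all but the first and last frame;
+ # 2: mask a random static rectangle in all frames; 3: mask a random rectangle over a random temporal span.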
30
+ if mask_index == 0:
31
+ mask[1:, :, :, :] = 1
32
+ elif mask_index == 1:
33
+ mask_frame_index = 1
34
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
35
+ elif mask_index == 2:
36
+ center_x = torch.randint(0, w, (1,)).item()
37
+ center_y = torch.randint(0, h, (1,)).item()
38
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # width range of the block
39
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # height range of the block
40
+
41
+ start_x = max(center_x - block_size_x // 2, 0)
42
+ end_x = min(center_x + block_size_x // 2, w)
43
+ start_y = max(center_y - block_size_y // 2, 0)
44
+ end_y = min(center_y + block_size_y // 2, h)
45
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
46
+ elif mask_index == 3:
47
+ center_x = torch.randint(0, w, (1,)).item()
48
+ center_y = torch.randint(0, h, (1,)).item()
49
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # width range of the block
50
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # height range of the block
51
+
52
+ start_x = max(center_x - block_size_x // 2, 0)
53
+ end_x = min(center_x + block_size_x // 2, w)
54
+ start_y = max(center_y - block_size_y // 2, 0)
55
+ end_y = min(center_y + block_size_y // 2, h)
56
+
57
+ mask_frame_before = np.random.randint(0, f // 2)
58
+ mask_frame_after = np.random.randint(f // 2, f)
59
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
60
+ else:
61
+ raise ValueError(f"The mask_index {mask_index} is not define")
62
+ return mask
63
+
64
+
65
+ @contextmanager
66
+ def VideoReader_contextmanager(*args, **kwargs):
67
+ vr = VideoReader(*args, **kwargs)
68
+ try:
69
+ yield vr
70
+ finally:
71
+ del vr
72
+ gc.collect()
73
+
74
+
75
+ def get_video_reader_batch(video_reader, batch_index):
76
+ frames = video_reader.get_batch(batch_index).asnumpy()
77
+ return frames
78
+
79
+
80
+ class WebVid10M(Dataset):
81
+ def __init__(
82
+ self,
83
+ csv_path, video_folder,
84
+ sample_size=256, sample_stride=4, sample_n_frames=16,
85
+ enable_bucket=False, enable_inpaint=False, is_image=False,
86
+ ):
87
+ print(f"loading annotations from {csv_path} ...")
88
+ with open(csv_path, 'r') as csvfile:
89
+ self.dataset = list(csv.DictReader(csvfile))
90
+ self.length = len(self.dataset)
91
+ print(f"data scale: {self.length}")
92
+
93
+ self.video_folder = video_folder
94
+ self.sample_stride = sample_stride
95
+ self.sample_n_frames = sample_n_frames
96
+ self.enable_bucket = enable_bucket
97
+ self.enable_inpaint = enable_inpaint
98
+ self.is_image = is_image
99
+
100
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
101
+ self.pixel_transforms = transforms.Compose([
102
+ transforms.Resize(sample_size[0]),
103
+ transforms.CenterCrop(sample_size),
104
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
105
+ ])
106
+
107
+ def get_batch(self, idx):
108
+ video_dict = self.dataset[idx]
109
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
110
+
111
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
112
+ video_reader = VideoReader(video_dir)
113
+ video_length = len(video_reader)
114
+
115
+ if not self.is_image:
116
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
117
+ start_idx = random.randint(0, video_length - clip_length)
118
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
119
+ else:
120
+ batch_index = [random.randint(0, video_length - 1)]
121
+
122
+ if not self.enable_bucket:
123
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
124
+ pixel_values = pixel_values / 255.
125
+ del video_reader
126
+ else:
127
+ pixel_values = video_reader.get_batch(batch_index).asnumpy()
128
+
129
+ if self.is_image:
130
+ pixel_values = pixel_values[0]
131
+ return pixel_values, name
132
+
133
+ def __len__(self):
134
+ return self.length
135
+
136
+ def __getitem__(self, idx):
137
+ while True:
138
+ try:
139
+ pixel_values, name = self.get_batch(idx)
140
+ break
141
+
142
+ except Exception as e:
143
+ print("Error info:", e)
144
+ idx = random.randint(0, self.length-1)
145
+
146
+ if not self.enable_bucket:
147
+ pixel_values = self.pixel_transforms(pixel_values)
148
+ if self.enable_inpaint:
149
+ mask = get_random_mask(pixel_values.size())
150
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
151
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
152
+ else:
153
+ sample = dict(pixel_values=pixel_values, text=name)
154
+ return sample
155
+
156
+
157
+ class VideoDataset(Dataset):
158
+ def __init__(
159
+ self,
160
+ json_path, video_folder=None,
161
+ sample_size=256, sample_stride=4, sample_n_frames=16,
162
+ enable_bucket=False, enable_inpaint=False
163
+ ):
164
+ print(f"loading annotations from {json_path} ...")
165
+ self.dataset = json.load(open(json_path, 'r'))
166
+ self.length = len(self.dataset)
167
+ print(f"data scale: {self.length}")
168
+
169
+ self.video_folder = video_folder
170
+ self.sample_stride = sample_stride
171
+ self.sample_n_frames = sample_n_frames
172
+ self.enable_bucket = enable_bucket
173
+ self.enable_inpaint = enable_inpaint
174
+
175
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
176
+ self.pixel_transforms = transforms.Compose(
177
+ [
178
+ transforms.Resize(sample_size[0]),
179
+ transforms.CenterCrop(sample_size),
180
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
181
+ ]
182
+ )
183
+
184
+ def get_batch(self, idx):
185
+ video_dict = self.dataset[idx]
186
+ video_id, name = video_dict['file_path'], video_dict['text']
187
+
188
+ if self.video_folder is None:
189
+ video_dir = video_id
190
+ else:
191
+ video_dir = os.path.join(self.video_folder, video_id)
192
+
193
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
194
+ video_length = len(video_reader)
195
+
196
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
197
+ start_idx = random.randint(0, video_length - clip_length)
198
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
199
+
200
+ try:
201
+ sample_args = (video_reader, batch_index)
202
+ pixel_values = func_timeout(
203
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
204
+ )
205
+ except FunctionTimedOut:
206
+ raise ValueError(f"Read {idx} timeout.")
207
+ except Exception as e:
208
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
209
+
210
+ if not self.enable_bucket:
211
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
212
+ pixel_values = pixel_values / 255.
213
+ del video_reader
214
+ else:
215
+ pixel_values = pixel_values
216
+
217
+ return pixel_values, name
218
+
219
+ def __len__(self):
220
+ return self.length
221
+
222
+ def __getitem__(self, idx):
223
+ while True:
224
+ try:
225
+ pixel_values, name = self.get_batch(idx)
226
+ break
227
+
228
+ except Exception as e:
229
+ print("Error info:", e)
230
+ idx = random.randint(0, self.length-1)
231
+
232
+ if not self.enable_bucket:
233
+ pixel_values = self.pixel_transforms(pixel_values)
234
+ if self.enable_inpaint:
235
+ mask = get_random_mask(pixel_values.size())
236
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
237
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
238
+ else:
239
+ sample = dict(pixel_values=pixel_values, text=name)
240
+ return sample
241
+
242
+
243
+ if __name__ == "__main__":
244
+ if 1:
245
+ dataset = VideoDataset(
246
+ json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
247
+ sample_size=256,
248
+ sample_stride=4, sample_n_frames=16,
249
+ )
250
+
251
+ if 0:
252
+ dataset = WebVid10M(
253
+ csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
254
+ video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
255
+ sample_size=256,
256
+ sample_stride=4, sample_n_frames=16,
257
+ is_image=False,
258
+ )
259
+
260
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
261
+ for idx, batch in enumerate(dataloader):
262
+ print(batch["pixel_values"].shape, len(batch["text"]))
videox_fun/dist/__init__.py ADDED
@@ -0,0 +1,66 @@
1
+ import importlib.util
2
+
3
+ from .cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
4
+ from .fsdp import shard_model
5
+ from .fuser import (get_sequence_parallel_rank,
6
+ get_sequence_parallel_world_size, get_sp_group,
7
+ get_world_group, init_distributed_environment,
8
+ initialize_model_parallel, set_multi_gpus_devices,
9
+ xFuserLongContextAttention)
10
+ from .wan_xfuser import usp_attn_forward, usp_attn_s2v_forward
11
+ from .qwen_xfuser import QwenImageMultiGPUsAttnProcessor2_0
12
+ from .flux_xfuser import FluxMultiGPUsAttnProcessor2_0
13
+
14
+ # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
15
+ if importlib.util.find_spec("paifuser") is not None:
16
+ # --------------------------------------------------------------- #
17
+ # The simple_wrapper is used to solve the problem
18
+ # about conflicts between cython and torch.compile
19
+ # --------------------------------------------------------------- #
20
+ def simple_wrapper(func):
21
+ def inner(*args, **kwargs):
22
+ return func(*args, **kwargs)
23
+ return inner
24
+
25
+ # --------------------------------------------------------------- #
26
+ # Sparse Attention Kernel
27
+ # --------------------------------------------------------------- #
28
+ from paifuser.models import parallel_magvit_vae
29
+ from paifuser.ops import wan_usp_sparse_attention_wrapper
30
+ from . import wan_xfuser
31
+
32
+ # --------------------------------------------------------------- #
33
+ # Sparse Attention
34
+ # --------------------------------------------------------------- #
35
+ usp_sparse_attn_wrap_forward = simple_wrapper(wan_usp_sparse_attention_wrapper()(wan_xfuser.usp_attn_forward))
36
+ wan_xfuser.usp_attn_forward = usp_sparse_attn_wrap_forward
37
+ usp_attn_forward = usp_sparse_attn_wrap_forward
38
+ print("Import PAI VAE Turbo and Sparse Attention")
39
+
40
+ # --------------------------------------------------------------- #
41
+ # Fast Rope Kernel
42
+ # --------------------------------------------------------------- #
43
+ import types
44
+ import torch
45
+ from paifuser.ops import (ENABLE_KERNEL, usp_fast_rope_apply_qk,
46
+ usp_rope_apply_real_qk)
47
+
48
+ def deepcopy_function(f):
49
+ return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__,closure=f.__closure__)
50
+
51
+ local_rope_apply_qk = deepcopy_function(wan_xfuser.rope_apply_qk)
52
+
53
+ if ENABLE_KERNEL:
54
+ def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
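+ # The fused rope kernel is inference-only, so fall back to the original implementation when gradients are enabled.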
55
+ if torch.is_grad_enabled():
56
+ return local_rope_apply_qk(q, k, grid_sizes, freqs)
57
+ else:
58
+ return usp_fast_rope_apply_qk(q, k, grid_sizes, freqs)
59
+
60
+ else:
61
+ def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
62
+ return usp_rope_apply_real_qk(q, k, grid_sizes, freqs)
63
+
64
+ wan_xfuser.rope_apply_qk = adaptive_fast_usp_rope_apply_qk
65
+ rope_apply_qk = adaptive_fast_usp_rope_apply_qk
66
+ print("Import PAI Fast rope")
videox_fun/dist/cogvideox_xfuser.py ADDED
@@ -0,0 +1,105 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention import Attention
6
+ from diffusers.models.embeddings import apply_rotary_emb
7
+
8
+ from .fuser import (get_sequence_parallel_rank,
9
+ get_sequence_parallel_world_size, get_sp_group,
10
+ init_distributed_environment, initialize_model_parallel,
11
+ xFuserLongContextAttention)
12
+
13
+ class CogVideoXMultiGPUsAttnProcessor2_0:
14
+ r"""
15
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
16
+ query and key vectors, but does not include spatial normalization.
17
+ """
18
+
19
+ def __init__(self):
20
+ if xFuserLongContextAttention is not None:
21
+ try:
22
+ self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
23
+ except Exception:
24
+ self.hybrid_seq_parallel_attn = None
25
+ else:
26
+ self.hybrid_seq_parallel_attn = None
27
+ if not hasattr(F, "scaled_dot_product_attention"):
28
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
29
+
30
+ def __call__(
31
+ self,
32
+ attn: Attention,
33
+ hidden_states: torch.Tensor,
34
+ encoder_hidden_states: torch.Tensor,
35
+ attention_mask: Optional[torch.Tensor] = None,
36
+ image_rotary_emb: Optional[torch.Tensor] = None,
37
+ ) -> torch.Tensor:
38
+ text_seq_length = encoder_hidden_states.size(1)
39
+
40
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
41
+
42
+ batch_size, sequence_length, _ = (
43
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
44
+ )
45
+
46
+ if attention_mask is not None:
47
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
48
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
49
+
50
+ query = attn.to_q(hidden_states)
51
+ key = attn.to_k(hidden_states)
52
+ value = attn.to_v(hidden_states)
53
+
54
+ inner_dim = key.shape[-1]
55
+ head_dim = inner_dim // attn.heads
56
+
57
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
58
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
59
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
60
+
61
+ if attn.norm_q is not None:
62
+ query = attn.norm_q(query)
63
+ if attn.norm_k is not None:
64
+ key = attn.norm_k(key)
65
+
66
+ # Apply RoPE if needed
67
+ if image_rotary_emb is not None:
68
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
69
+ if not attn.is_cross_attention:
70
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
71
+
72
+ if self.hybrid_seq_parallel_attn is None:
73
+ hidden_states = F.scaled_dot_product_attention(
74
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
75
+ )
76
+ hidden_states = hidden_states
77
+ else:
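+ # Sequence-parallel path: image tokens are sharded across ranks, while the short text tokens
+ # are passed as joint tensors ('front' strategy) to every rank.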
78
+ img_q = query[:, :, text_seq_length:].transpose(1, 2)
79
+ txt_q = query[:, :, :text_seq_length].transpose(1, 2)
80
+ img_k = key[:, :, text_seq_length:].transpose(1, 2)
81
+ txt_k = key[:, :, :text_seq_length].transpose(1, 2)
82
+ img_v = value[:, :, text_seq_length:].transpose(1, 2)
83
+ txt_v = value[:, :, :text_seq_length].transpose(1, 2)
84
+
85
+ hidden_states = self.hybrid_seq_parallel_attn(
86
+ None,
87
+ img_q, img_k, img_v, dropout_p=0.0, causal=False,
88
+ joint_tensor_query=txt_q,
89
+ joint_tensor_key=txt_k,
90
+ joint_tensor_value=txt_v,
91
+ joint_strategy='front',
92
+ ).transpose(1, 2)
93
+
94
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
95
+
96
+ # linear proj
97
+ hidden_states = attn.to_out[0](hidden_states)
98
+ # dropout
99
+ hidden_states = attn.to_out[1](hidden_states)
100
+
101
+ encoder_hidden_states, hidden_states = hidden_states.split(
102
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
103
+ )
104
+ return hidden_states, encoder_hidden_states
105
+
videox_fun/dist/flux_xfuser.py ADDED
@@ -0,0 +1,168 @@
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention_processor import Attention
6
+
7
+ from .fuser import xFuserLongContextAttention
8
+
9
+
10
+ def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
11
+ query = attn.to_q(hidden_states)
12
+ key = attn.to_k(hidden_states)
13
+ value = attn.to_v(hidden_states)
14
+
15
+ encoder_query = encoder_key = encoder_value = None
16
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
17
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
18
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
19
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
20
+
21
+ return query, key, value, encoder_query, encoder_key, encoder_value
22
+
23
+
24
+ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
25
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
26
+
27
+
28
+ def apply_rotary_emb(
29
+ x: torch.Tensor,
30
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
31
+ use_real: bool = True,
32
+ use_real_unbind_dim: int = -1,
33
+ sequence_dim: int = 2,
34
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
35
+ """
36
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
37
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
38
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
39
+ tensors contain rotary embeddings and are returned as real tensors.
40
+
41
+ Args:
42
+ x (`torch.Tensor`):
43
+ Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
44
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
45
+
46
+ Returns:
47
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
48
+ """
49
+ if use_real:
50
+ cos, sin = freqs_cis # [S, D]
51
+ if sequence_dim == 2:
52
+ cos = cos[None, None, :, :]
53
+ sin = sin[None, None, :, :]
54
+ elif sequence_dim == 1:
55
+ cos = cos[None, :, None, :]
56
+ sin = sin[None, :, None, :]
57
+ else:
58
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
59
+
60
+ cos, sin = cos.to(x.device), sin.to(x.device)
61
+
62
+ if use_real_unbind_dim == -1:
63
+ # Used for flux, cogvideox, hunyuan-dit
64
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
65
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
66
+ elif use_real_unbind_dim == -2:
67
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
68
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
69
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
70
+ else:
71
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
72
+
73
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
74
+
75
+ return out
76
+ else:
77
+ # used for lumina
78
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
79
+ freqs_cis = freqs_cis.unsqueeze(2)
80
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
81
+
82
+ return x_out.type_as(x)
83
+
84
+
85
+ class FluxMultiGPUsAttnProcessor2_0:
86
+ r"""
87
+ Processor for implementing multi-GPU (sequence-parallel) scaled dot-product attention for the Flux model. It applies
88
+ a rotary embedding on query and key vectors, but does not include spatial normalization.
89
+ """
90
+
91
+ def __init__(self):
92
+ if not hasattr(F, "scaled_dot_product_attention"):
93
+ raise ImportError("FluxMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
94
+
95
+ def __call__(
96
+ self,
97
+ attn: "FluxAttention",
98
+ hidden_states: torch.Tensor,
99
+ encoder_hidden_states: torch.Tensor = None,
100
+ attention_mask: Optional[torch.Tensor] = None,
101
+ image_rotary_emb: Optional[torch.Tensor] = None,
102
+ text_seq_len: int = None,
103
+ ) -> torch.FloatTensor:
104
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
105
+ attn, hidden_states, encoder_hidden_states
106
+ )
107
+
108
+ query = query.unflatten(-1, (attn.heads, -1))
109
+ key = key.unflatten(-1, (attn.heads, -1))
110
+ value = value.unflatten(-1, (attn.heads, -1))
111
+
112
+ query = attn.norm_q(query)
113
+ key = attn.norm_k(key)
114
+
115
+ if attn.added_kv_proj_dim is not None:
116
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
117
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
118
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
119
+
120
+ encoder_query = attn.norm_added_q(encoder_query)
121
+ encoder_key = attn.norm_added_k(encoder_key)
122
+
123
+ query = torch.cat([encoder_query, query], dim=1)
124
+ key = torch.cat([encoder_key, key], dim=1)
125
+ value = torch.cat([encoder_value, value], dim=1)
126
+
127
+ if image_rotary_emb is not None:
128
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
129
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
130
+
131
+ text_seq_len = encoder_query.shape[1]
132
+ txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
133
+ img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
134
+ else:
135
+ if image_rotary_emb is not None:
136
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
137
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
138
+ txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
139
+ img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
140
+
141
+ half_dtypes = (torch.float16, torch.bfloat16)
142
+ def half(x):
143
+ return x if x.dtype in half_dtypes else x.to(torch.bfloat16) # assume bfloat16 as the half-precision target; `dtype` was undefined here
144
+
145
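+ # Text tokens are attached as joint tensors ('front') so every rank attends over the full text plus its shard of image tokens.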
+ hidden_states = xFuserLongContextAttention()(
146
+ None,
147
+ half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
148
+ joint_tensor_query=half(txt_query) if txt_query is not None else None,
149
+ joint_tensor_key=half(txt_key) if txt_key is not None else None,
150
+ joint_tensor_value=half(txt_value) if txt_value is not None else None,
151
+ joint_strategy='front',
152
+ )
153
+
154
+ # Reshape back
155
+ hidden_states = hidden_states.flatten(2, 3)
156
+ hidden_states = hidden_states.to(img_query.dtype)
157
+
158
+ if encoder_hidden_states is not None:
159
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
160
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
161
+ )
162
+ hidden_states = attn.to_out[0](hidden_states)
163
+ hidden_states = attn.to_out[1](hidden_states)
164
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
165
+
166
+ return hidden_states, encoder_hidden_states
167
+ else:
168
+ return hidden_states
videox_fun/dist/fsdp.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copied from https://github.com/Wan-Video/Wan2.1/blob/main/wan/distributed/fsdp.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import gc
4
+ from functools import partial
5
+
6
+ import torch
7
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
8
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
9
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
10
+ from torch.distributed.utils import _free_storage
11
+
12
+
13
+ def shard_model(
14
+ model,
15
+ device_id,
16
+ param_dtype=torch.bfloat16,
17
+ reduce_dtype=torch.float32,
18
+ buffer_dtype=torch.float32,
19
+ process_group=None,
20
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
21
+ sync_module_states=True,
22
+ module_to_wrapper=None,
23
+ ):
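+ # Each transformer block in model.blocks becomes its own FSDP unit unless an explicit module set is given via module_to_wrapper.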
24
+ model = FSDP(
25
+ module=model,
26
+ process_group=process_group,
27
+ sharding_strategy=sharding_strategy,
28
+ auto_wrap_policy=partial(
29
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks if module_to_wrapper is None else module_to_wrapper),
30
+ mixed_precision=MixedPrecision(
31
+ param_dtype=param_dtype,
32
+ reduce_dtype=reduce_dtype,
33
+ buffer_dtype=buffer_dtype),
34
+ device_id=device_id,
35
+ sync_module_states=sync_module_states)
36
+ return model
37
+
38
+ def free_model(model):
39
+ for m in model.modules():
40
+ if isinstance(m, FSDP):
41
+ _free_storage(m._handle.flat_param.data)
42
+ del model
43
+ gc.collect()
44
+ torch.cuda.empty_cache()
videox_fun/dist/fuser.py ADDED
@@ -0,0 +1,55 @@
1
+ import importlib.util
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+ try:
7
+ # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
8
+ if importlib.util.find_spec("paifuser") is not None:
9
+ import paifuser
10
+ from paifuser.xfuser.core.distributed import (
11
+ get_sequence_parallel_rank, get_sequence_parallel_world_size,
12
+ get_sp_group, get_world_group, init_distributed_environment,
13
+ initialize_model_parallel)
14
+ from paifuser.xfuser.core.long_ctx_attention import \
15
+ xFuserLongContextAttention
16
+ print("Import PAI DiT Turbo")
17
+ else:
18
+ import xfuser
19
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
20
+ get_sequence_parallel_world_size,
21
+ get_sp_group, get_world_group,
22
+ init_distributed_environment,
23
+ initialize_model_parallel)
24
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
25
+ print("Xfuser import sucessful")
26
+ except Exception as ex:
27
+ get_sequence_parallel_world_size = None
28
+ get_sequence_parallel_rank = None
29
+ xFuserLongContextAttention = None
30
+ get_sp_group = None
31
+ get_world_group = None
32
+ init_distributed_environment = None
33
+ initialize_model_parallel = None
34
+
35
+ def set_multi_gpus_devices(ulysses_degree, ring_degree, classifier_free_guidance_degree=1):
36
+ if ulysses_degree > 1 or ring_degree > 1 or classifier_free_guidance_degree > 1:
37
+ if get_sp_group is None:
38
+ raise RuntimeError("xfuser is not installed.")
39
+ dist.init_process_group("nccl")
40
+ print('parallel inference enabled: ulysses_degree=%d ring_degree=%d classifier_free_guidance_degree=%d rank=%d world_size=%d' % (
41
+ ulysses_degree, ring_degree, classifier_free_guidance_degree, dist.get_rank(),
42
+ dist.get_world_size()))
43
+ assert dist.get_world_size() == ring_degree * ulysses_degree * classifier_free_guidance_degree, \
44
+ "number of GPUs(%d) should be equal to ring_degree * ulysses_degree * classifier_free_guidance_degree." % dist.get_world_size()
45
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
46
+ initialize_model_parallel(sequence_parallel_degree=ring_degree * ulysses_degree,
47
+ classifier_free_guidance_degree=classifier_free_guidance_degree,
48
+ ring_degree=ring_degree,
49
+ ulysses_degree=ulysses_degree)
50
+ # device = torch.device("cuda:%d" % dist.get_rank())
51
+ device = torch.device(f"cuda:{get_world_group().local_rank}")
52
+ print('rank=%d device=%s' % (get_world_group().rank, str(device)))
53
+ else:
54
+ device = "cuda"
55
+ return device
videox_fun/dist/qwen_xfuser.py ADDED
@@ -0,0 +1,176 @@
1
+ import functools
2
+ import glob
3
+ import json
4
+ import math
5
+ import os
6
+ import types
7
+ import warnings
8
+ from typing import Any, Dict, List, Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.cuda.amp as amp
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
16
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
17
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
18
+ from diffusers.models.attention import FeedForward
19
+ from diffusers.models.attention_processor import Attention
20
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
21
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
24
+ from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
25
+ scale_lora_layers, unscale_lora_layers)
26
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
27
+ from torch import nn
28
+ from .fuser import (get_sequence_parallel_rank,
29
+ get_sequence_parallel_world_size, get_sp_group,
30
+ init_distributed_environment, initialize_model_parallel,
31
+ xFuserLongContextAttention)
32
+
33
+ def apply_rotary_emb_qwen(
34
+ x: torch.Tensor,
35
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
36
+ use_real: bool = True,
37
+ use_real_unbind_dim: int = -1,
38
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
39
+ """
40
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
41
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
42
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
43
+ tensors contain rotary embeddings and are returned as real tensors.
44
+
45
+ Args:
46
+ x (`torch.Tensor`):
47
+ Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
48
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
49
+
50
+ Returns:
51
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
52
+ """
53
+ if use_real:
54
+ cos, sin = freqs_cis # [S, D]
55
+ cos = cos[None, None]
56
+ sin = sin[None, None]
57
+ cos, sin = cos.to(x.device), sin.to(x.device)
58
+
59
+ if use_real_unbind_dim == -1:
60
+ # Used for flux, cogvideox, hunyuan-dit
61
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
62
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
63
+ elif use_real_unbind_dim == -2:
64
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
65
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
66
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
67
+ else:
68
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
69
+
70
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
71
+
72
+ return out
73
+ else:
74
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
75
+ freqs_cis = freqs_cis.unsqueeze(1)
76
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
77
+
78
+ return x_out.type_as(x)
79
+
80
+
81
+ class QwenImageMultiGPUsAttnProcessor2_0:
82
+ r"""
83
+ Processor for implementing scaled dot-product attention for the QwenImage double-stream blocks. It applies a rotary embedding on
84
+ query and key vectors, but does not include spatial normalization.
85
+ """
86
+
87
+ def __init__(self):
88
+ if not hasattr(F, "scaled_dot_product_attention"):
89
+ raise ImportError("QwenImageMultiGPUsAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
90
+
91
+ def __call__(
92
+ self,
93
+ attn: Attention,
94
+ hidden_states: torch.FloatTensor, # Image stream
95
+ encoder_hidden_states: torch.FloatTensor = None, # Text stream
96
+ encoder_hidden_states_mask: torch.FloatTensor = None,
97
+ attention_mask: Optional[torch.FloatTensor] = None,
98
+ image_rotary_emb: Optional[torch.Tensor] = None,
99
+ ) -> torch.FloatTensor:
100
+ if encoder_hidden_states is None:
101
+ raise ValueError("QwenImageMultiGPUsAttnProcessor2_0 requires encoder_hidden_states (text stream)")
102
+
103
+ seq_txt = encoder_hidden_states.shape[1]
104
+
105
+ # Compute QKV for image stream (sample projections)
106
+ img_query = attn.to_q(hidden_states)
107
+ img_key = attn.to_k(hidden_states)
108
+ img_value = attn.to_v(hidden_states)
109
+
110
+ # Compute QKV for text stream (context projections)
111
+ txt_query = attn.add_q_proj(encoder_hidden_states)
112
+ txt_key = attn.add_k_proj(encoder_hidden_states)
113
+ txt_value = attn.add_v_proj(encoder_hidden_states)
114
+
115
+ # Reshape for multi-head attention
116
+ img_query = img_query.unflatten(-1, (attn.heads, -1))
117
+ img_key = img_key.unflatten(-1, (attn.heads, -1))
118
+ img_value = img_value.unflatten(-1, (attn.heads, -1))
119
+
120
+ txt_query = txt_query.unflatten(-1, (attn.heads, -1))
121
+ txt_key = txt_key.unflatten(-1, (attn.heads, -1))
122
+ txt_value = txt_value.unflatten(-1, (attn.heads, -1))
123
+
124
+ # Apply QK normalization
125
+ if attn.norm_q is not None:
126
+ img_query = attn.norm_q(img_query)
127
+ if attn.norm_k is not None:
128
+ img_key = attn.norm_k(img_key)
129
+ if attn.norm_added_q is not None:
130
+ txt_query = attn.norm_added_q(txt_query)
131
+ if attn.norm_added_k is not None:
132
+ txt_key = attn.norm_added_k(txt_key)
133
+
134
+ # Apply RoPE
135
+ if image_rotary_emb is not None:
136
+ img_freqs, txt_freqs = image_rotary_emb
137
+ img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
138
+ img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
139
+ txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
140
+ txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
141
+
142
+ # Concatenate for joint attention
143
+ # Order: [text, image]
144
+ # joint_query = torch.cat([txt_query, img_query], dim=1)
145
+ # joint_key = torch.cat([txt_key, img_key], dim=1)
146
+ # joint_value = torch.cat([txt_value, img_value], dim=1)
147
+
148
+ half_dtypes = (torch.float16, torch.bfloat16)
149
+ def half(x):
150
+ return x if x.dtype in half_dtypes else x.to(torch.bfloat16)  # NOTE: no dtype variable exists in this scope; fall back to bfloat16
151
+
152
+ joint_hidden_states = xFuserLongContextAttention()(
153
+ None,
154
+ half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
155
+ joint_tensor_query=half(txt_query),
156
+ joint_tensor_key=half(txt_key),
157
+ joint_tensor_value=half(txt_value),
158
+ joint_strategy='front',
159
+ )
160
+
161
+ # Reshape back
162
+ joint_hidden_states = joint_hidden_states.flatten(2, 3)
163
+ joint_hidden_states = joint_hidden_states.to(img_query.dtype)
164
+
165
+ # Split attention outputs back
166
+ txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
167
+ img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
168
+
169
+ # Apply output projections
170
+ img_attn_output = attn.to_out[0](img_attn_output)
171
+ if len(attn.to_out) > 1:
172
+ img_attn_output = attn.to_out[1](img_attn_output) # dropout
173
+
174
+ txt_attn_output = attn.to_add_out(txt_attn_output)
175
+
176
+ return img_attn_output, txt_attn_output
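Note (not part of this commit): `xFuserLongContextAttention` with `joint_strategy='front'` distributes a joint text+image attention across sequence-parallel ranks, with the text tokens prepended to the image tokens. A minimal single-GPU sketch of the equivalent computation, under the assumption that all inputs are shaped `[B, S, H, D]` and that the function name below is illustrative:

```python
import torch
import torch.nn.functional as F

def joint_attention_reference(img_q, img_k, img_v, txt_q, txt_k, txt_v):
    # prepend the text stream to the image stream along the sequence dimension
    q = torch.cat([txt_q, img_q], dim=1)  # [B, S_txt + S_img, H, D]
    k = torch.cat([txt_k, img_k], dim=1)
    v = torch.cat([txt_v, img_v], dim=1)
    # SDPA expects [B, H, S, D]; transpose in and out
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)                      # back to [B, S, H, D]
    return out

if __name__ == "__main__":
    B, H, D = 1, 4, 16
    img = [torch.randn(B, 32, H, D) for _ in range(3)]
    txt = [torch.randn(B, 8, H, D) for _ in range(3)]
    print(joint_attention_reference(*img, *txt).shape)  # torch.Size([1, 40, 4, 16])
```

The processor above then splits the joint output back into a text part (`[:, :seq_txt]`) and an image part (`[:, seq_txt:]`) before the output projections.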
videox_fun/dist/wan_xfuser.py ADDED
@@ -0,0 +1,180 @@
1
+ import torch
2
+ import torch.cuda.amp as amp
3
+
4
+ from .fuser import (get_sequence_parallel_rank,
5
+ get_sequence_parallel_world_size, get_sp_group,
6
+ init_distributed_environment, initialize_model_parallel,
7
+ xFuserLongContextAttention)
8
+
9
+
10
+ def pad_freqs(original_tensor, target_len):
11
+ seq_len, s1, s2 = original_tensor.shape
12
+ pad_size = target_len - seq_len
13
+ padding_tensor = torch.ones(
14
+ pad_size,
15
+ s1,
16
+ s2,
17
+ dtype=original_tensor.dtype,
18
+ device=original_tensor.device)
19
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
20
+ return padded_tensor
21
+
22
+ @amp.autocast(enabled=False)
23
+ @torch.compiler.disable()
24
+ def rope_apply(x, grid_sizes, freqs):
25
+ """
26
+ x: [B, L, N, C].
27
+ grid_sizes: [B, 3].
28
+ freqs: [M, C // 2].
29
+ """
30
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
31
+ # split freqs
32
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
33
+
34
+ # loop over samples
35
+ output = []
36
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
37
+ seq_len = f * h * w
38
+
39
+ # precompute multipliers
40
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
41
+ s, n, -1, 2))
42
+ freqs_i = torch.cat([
43
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
44
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
45
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
46
+ ],
47
+ dim=-1).reshape(seq_len, 1, -1)
48
+
49
+ # apply rotary embedding
50
+ sp_size = get_sequence_parallel_world_size()
51
+ sp_rank = get_sequence_parallel_rank()
52
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
53
+ s_per_rank = s
54
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
55
+ s_per_rank), :, :]
56
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
57
+ x_i = torch.cat([x_i, x[i, s:]])
58
+
59
+ # append to collection
60
+ output.append(x_i)
61
+ return torch.stack(output)
62
+
63
+ def rope_apply_qk(q, k, grid_sizes, freqs):
64
+ q = rope_apply(q, grid_sizes, freqs)
65
+ k = rope_apply(k, grid_sizes, freqs)
66
+ return q, k
67
+
68
+ def usp_attn_forward(self,
69
+ x,
70
+ seq_lens,
71
+ grid_sizes,
72
+ freqs,
73
+ dtype=torch.bfloat16,
74
+ t=0):
75
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
76
+ half_dtypes = (torch.float16, torch.bfloat16)
77
+
78
+ def half(x):
79
+ return x if x.dtype in half_dtypes else x.to(dtype)
80
+
81
+ # query, key, value function
82
+ def qkv_fn(x):
83
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
84
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
85
+ v = self.v(x).view(b, s, n, d)
86
+ return q, k, v
87
+
88
+ q, k, v = qkv_fn(x)
89
+ q, k = rope_apply_qk(q, k, grid_sizes, freqs)
90
+
91
+ # TODO: We should use unpadded q, k, v for attention.
92
+ # k_lens = seq_lens // get_sequence_parallel_world_size()
93
+ # if k_lens is not None:
94
+ # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
95
+ # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
96
+ # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
97
+
98
+ x = xFuserLongContextAttention()(
99
+ None,
100
+ query=half(q),
101
+ key=half(k),
102
+ value=half(v),
103
+ window_size=self.window_size)
104
+
105
+ # TODO: padding after attention.
106
+ # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
107
+
108
+ # output
109
+ x = x.flatten(2)
110
+ x = self.o(x)
111
+ return x
112
+
113
+ @amp.autocast(enabled=False)
114
+ @torch.compiler.disable()
115
+ def s2v_rope_apply(x, grid_sizes, freqs):
116
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
117
+ # loop over samples
118
+ output = []
119
+ for i, _ in enumerate(x):
120
+ s = x.size(1)
121
+ # precompute multipliers
122
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
123
+ s, n, -1, 2))
124
+ freqs_i = freqs[i]
125
+ freqs_i_rank = pad_freqs(freqs_i, s)
126
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
127
+ x_i = torch.cat([x_i, x[i, s:]])
128
+ # append to collection
129
+ output.append(x_i)
130
+ return torch.stack(output).float()
131
+
132
+ def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
133
+ q = s2v_rope_apply(q, grid_sizes, freqs)
134
+ k = s2v_rope_apply(k, grid_sizes, freqs)
135
+ return q, k
136
+
137
+ def usp_attn_s2v_forward(self,
138
+ x,
139
+ seq_lens,
140
+ grid_sizes,
141
+ freqs,
142
+ dtype=torch.bfloat16,
143
+ t=0):
144
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
145
+ half_dtypes = (torch.float16, torch.bfloat16)
146
+
147
+ def half(x):
148
+ return x if x.dtype in half_dtypes else x.to(dtype)
149
+
150
+ # query, key, value function
151
+ def qkv_fn(x):
152
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
153
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
154
+ v = self.v(x).view(b, s, n, d)
155
+ return q, k, v
156
+
157
+ q, k, v = qkv_fn(x)
158
+ q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)
159
+
160
+ # TODO: We should use unpadded q, k, v for attention.
161
+ # k_lens = seq_lens // get_sequence_parallel_world_size()
162
+ # if k_lens is not None:
163
+ # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
164
+ # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
165
+ # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
166
+
167
+ x = xFuserLongContextAttention()(
168
+ None,
169
+ query=half(q),
170
+ key=half(k),
171
+ value=half(v),
172
+ window_size=self.window_size)
173
+
174
+ # TODO: padding after attention.
175
+ # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
176
+
177
+ # output
178
+ x = x.flatten(2)
179
+ x = self.o(x)
180
+ return x
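Note (not part of this commit): in `rope_apply` above, each rank only holds its local shard of `s` tokens, so the full-sequence frequency table is padded to `s * sp_size` and every rank slices out its own window before rotating. A small standalone illustration of that padding/slicing step (all sizes below are made up):

```python
import torch

def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    padding = torch.ones(target_len - seq_len, s1, s2,
                         dtype=original_tensor.dtype, device=original_tensor.device)
    return torch.cat([original_tensor, padding], dim=0)

sp_size, s = 4, 10                       # 4 sequence-parallel ranks, 10 tokens per rank
freqs = torch.randn(37, 1, 8)            # full-sequence freqs (37 < 40, so padding is needed)
freqs = pad_freqs(freqs, s * sp_size)    # pad to s * sp_size with ones (identity rotation)
for sp_rank in range(sp_size):
    chunk = freqs[sp_rank * s:(sp_rank + 1) * s]  # each rank rotates only its own tokens
    print(sp_rank, chunk.shape)          # (10, 1, 8) on every rank
```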
videox_fun/pipeline/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ from .pipeline_wan import WanPipeline
2
+ from .pipeline_wan2_2 import Wan2_2Pipeline
3
+
4
+ WanFunPipeline = WanPipeline
5
+ Wan2_2FunPipeline = Wan2_2Pipeline
6
+
7
+ import importlib.util
8
+
9
+ if importlib.util.find_spec("paifuser") is not None:
10
+ # --------------------------------------------------------------- #
11
+ # Sparse Attention
12
+ # --------------------------------------------------------------- #
13
+ from paifuser.ops import sparse_reset
14
+
15
+ # Wan2.1
16
+ WanFunPipeline.__call__ = sparse_reset(WanFunPipeline.__call__)
17
+ WanPipeline.__call__ = sparse_reset(WanPipeline.__call__)
18
+
19
+ # Wan2.2
20
+ Wan2_2FunPipeline.__call__ = sparse_reset(Wan2_2FunPipeline.__call__)
21
+ Wan2_2Pipeline.__call__ = sparse_reset(Wan2_2Pipeline.__call__)
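Note (not part of this commit): the `__init__.py` above only patches the pipelines when the optional `paifuser` accelerator is installed. A generic, runnable sketch of this optional-dependency pattern; `DummyPipeline` and `sparse_reset_like` are illustrative stand-ins, not real APIs:

```python
import functools
import importlib.util

class DummyPipeline:
    def __call__(self, x):
        return x * 2

def sparse_reset_like(fn):
    """Stand-in for an accelerator decorator: wraps __call__ without changing its result."""
    @functools.wraps(fn)
    def wrapped(self, *args, **kwargs):
        # a real wrapper would reset sparse-attention state here
        return fn(self, *args, **kwargs)
    return wrapped

if importlib.util.find_spec("paifuser") is not None:   # only patch when the package is importable
    DummyPipeline.__call__ = sparse_reset_like(DummyPipeline.__call__)

print(DummyPipeline()(3))  # 6, whether or not paifuser is installed
```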
videox_fun/pipeline/pipeline_wan.py ADDED
@@ -0,0 +1,799 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import FlowMatchEulerDiscreteScheduler
9
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
10
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+ from diffusers.video_processor import VideoProcessor
14
+
15
+ from ..models import (AutoencoderKLWan, AutoTokenizer,
16
+ WanT5EncoderModel, WanTransformer3DModel)
17
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
18
+ get_sampling_sigmas)
19
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
20
+
21
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
+
23
+
24
+ EXAMPLE_DOC_STRING = """
25
+ Examples:
26
+ ```python
27
+ pass
28
+ ```
29
+ """
30
+
31
+
32
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
33
+ def retrieve_timesteps(
34
+ scheduler,
35
+ num_inference_steps: Optional[int] = None,
36
+ device: Optional[Union[str, torch.device]] = None,
37
+ timesteps: Optional[List[int]] = None,
38
+ sigmas: Optional[List[float]] = None,
39
+ **kwargs,
40
+ ):
41
+ """
42
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
43
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
44
+
45
+ Args:
46
+ scheduler (`SchedulerMixin`):
47
+ The scheduler to get timesteps from.
48
+ num_inference_steps (`int`):
49
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
50
+ must be `None`.
51
+ device (`str` or `torch.device`, *optional*):
52
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
53
+ timesteps (`List[int]`, *optional*):
54
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
55
+ `num_inference_steps` and `sigmas` must be `None`.
56
+ sigmas (`List[float]`, *optional*):
57
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
58
+ `num_inference_steps` and `timesteps` must be `None`.
59
+
60
+ Returns:
61
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
62
+ second element is the number of inference steps.
63
+ """
64
+ if timesteps is not None and sigmas is not None:
65
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
66
+ if timesteps is not None:
67
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
68
+ if not accepts_timesteps:
69
+ raise ValueError(
70
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
71
+ f" timestep schedules. Please check whether you are using the correct scheduler."
72
+ )
73
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
74
+ timesteps = scheduler.timesteps
75
+ num_inference_steps = len(timesteps)
76
+ elif sigmas is not None:
77
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
78
+ if not accept_sigmas:
79
+ raise ValueError(
80
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
81
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
82
+ )
83
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
84
+ timesteps = scheduler.timesteps
85
+ num_inference_steps = len(timesteps)
86
+ else:
87
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
88
+ timesteps = scheduler.timesteps
89
+ return timesteps, num_inference_steps
90
+
91
+
92
+ @dataclass
93
+ class WanPipelineOutput(BaseOutput):
94
+ r"""
95
+ Output class for Wan pipelines.
96
+
97
+ Args:
98
+ videos: full decoded video tensor
99
+ ground_videos: decoded grounding segment (optional)
100
+ edit_videos: decoded edited segment (optional)
101
+ """
102
+
103
+ videos: torch.Tensor
104
+ ground_videos: Optional[torch.Tensor] = None
105
+ edit_videos: Optional[torch.Tensor] = None
106
+
107
+
108
+ class WanPipeline(DiffusionPipeline):
109
+ r"""
110
+ Pipeline for text-to-video generation using Wan.
111
+
112
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
113
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
114
+ """
115
+
116
+ _optional_components = []
117
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
118
+
119
+ _callback_tensor_inputs = [
120
+ "latents",
121
+ "prompt_embeds",
122
+ "negative_prompt_embeds",
123
+ ]
124
+
125
+ def __init__(
126
+ self,
127
+ tokenizer: AutoTokenizer,
128
+ text_encoder: WanT5EncoderModel,
129
+ vae: AutoencoderKLWan,
130
+ transformer: WanTransformer3DModel,
131
+ scheduler: FlowMatchEulerDiscreteScheduler,
132
+ ):
133
+ super().__init__()
134
+
135
+ self.register_modules(
136
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
137
+ )
138
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
139
+
140
+ def _get_t5_prompt_embeds(
141
+ self,
142
+ prompt: Union[str, List[str]] = None,
143
+ num_videos_per_prompt: int = 1,
144
+ max_sequence_length: int = 512,
145
+ device: Optional[torch.device] = None,
146
+ dtype: Optional[torch.dtype] = None,
147
+ ):
148
+ device = device or self._execution_device
149
+ dtype = dtype or self.text_encoder.dtype
150
+
151
+ prompt = [prompt] if isinstance(prompt, str) else prompt
152
+ batch_size = len(prompt)
153
+
154
+ text_inputs = self.tokenizer(
155
+ prompt,
156
+ padding="max_length",
157
+ max_length=max_sequence_length,
158
+ truncation=True,
159
+ add_special_tokens=True,
160
+ return_tensors="pt",
161
+ )
162
+ text_input_ids = text_inputs.input_ids
163
+ prompt_attention_mask = text_inputs.attention_mask
164
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
165
+
166
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
167
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
168
+ logger.warning(
169
+ "The following part of your input was truncated because `max_sequence_length` is set to "
170
+ f" {max_sequence_length} tokens: {removed_text}"
171
+ )
172
+
173
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
174
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
175
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
176
+
177
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
178
+ _, seq_len, _ = prompt_embeds.shape
179
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
180
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
181
+
182
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
183
+
184
+ def encode_prompt(
185
+ self,
186
+ prompt: Union[str, List[str]],
187
+ negative_prompt: Optional[Union[str, List[str]]] = None,
188
+ do_classifier_free_guidance: bool = True,
189
+ num_videos_per_prompt: int = 1,
190
+ prompt_embeds: Optional[torch.Tensor] = None,
191
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
192
+ max_sequence_length: int = 512,
193
+ device: Optional[torch.device] = None,
194
+ dtype: Optional[torch.dtype] = None,
195
+ ):
196
+ r"""
197
+ Encodes the prompt into text encoder hidden states.
198
+
199
+ Args:
200
+ prompt (`str` or `List[str]`, *optional*):
201
+ prompt to be encoded
202
+ negative_prompt (`str` or `List[str]`, *optional*):
203
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
204
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
205
+ less than `1`).
206
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
207
+ Whether to use classifier free guidance or not.
208
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
209
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
210
+ prompt_embeds (`torch.Tensor`, *optional*):
211
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
212
+ provided, text embeddings will be generated from `prompt` input argument.
213
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
214
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
215
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
216
+ argument.
217
+ device: (`torch.device`, *optional*):
218
+ torch device
219
+ dtype: (`torch.dtype`, *optional*):
220
+ torch dtype
221
+ """
222
+ device = device or self._execution_device
223
+
224
+ prompt = [prompt] if isinstance(prompt, str) else prompt
225
+ if prompt is not None:
226
+ batch_size = len(prompt)
227
+ else:
228
+ batch_size = prompt_embeds.shape[0]
229
+
230
+ if prompt_embeds is None:
231
+ prompt_embeds = self._get_t5_prompt_embeds(
232
+ prompt=prompt,
233
+ num_videos_per_prompt=num_videos_per_prompt,
234
+ max_sequence_length=max_sequence_length,
235
+ device=device,
236
+ dtype=dtype,
237
+ )
238
+
239
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
240
+ negative_prompt = negative_prompt or ""
241
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
242
+
243
+ if prompt is not None and type(prompt) is not type(negative_prompt):
244
+ raise TypeError(
245
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
246
+ f" {type(prompt)}."
247
+ )
248
+ elif batch_size != len(negative_prompt):
249
+ raise ValueError(
250
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
251
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
252
+ " the batch size of `prompt`."
253
+ )
254
+
255
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
256
+ prompt=negative_prompt,
257
+ num_videos_per_prompt=num_videos_per_prompt,
258
+ max_sequence_length=max_sequence_length,
259
+ device=device,
260
+ dtype=dtype,
261
+ )
262
+
263
+ return prompt_embeds, negative_prompt_embeds
264
+
265
+ def prepare_latents(
266
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
267
+ ):
268
+ if isinstance(generator, list) and len(generator) != batch_size:
269
+ raise ValueError(
270
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
271
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
272
+ )
273
+
274
+ shape = (
275
+ batch_size,
276
+ num_channels_latents,
277
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
278
+ height // self.vae.spatial_compression_ratio,
279
+ width // self.vae.spatial_compression_ratio,
280
+ )
281
+
282
+ if latents is None:
283
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
284
+ else:
285
+ latents = latents.to(device)
286
+
287
+ # scale the initial noise by the standard deviation required by the scheduler
288
+ if hasattr(self.scheduler, "init_noise_sigma"):
289
+ latents = latents * self.scheduler.init_noise_sigma
290
+ return latents
291
+
292
+ def prepare_video_latents(
293
+ self,
294
+ video: torch.Tensor,
295
+ batch_size: int = 1,
296
+ num_channels_latents: int = 16,
297
+ height: int = 480,
298
+ width: int = 832,
299
+ dtype: torch.dtype = torch.float32,
300
+ device: torch.device = None,
301
+ generator: torch.Generator = None,
302
+ condition_count: int = None,
303
+ latents: torch.Tensor = None,
304
+ timestep: torch.Tensor = None,
305
+ ):
306
+
307
+ video = video.to(device=device, dtype=dtype)
308
+ num_latent_frames = (video.shape[2] - 1) // self.vae.temporal_compression_ratio + 1
309
+
310
+ shape = (
311
+ batch_size,
312
+ num_channels_latents,
313
+ num_latent_frames,
314
+ height // self.vae.spatial_compression_ratio,
315
+ width // self.vae.spatial_compression_ratio,
316
+ )
317
+
318
+ if latents is not None:
319
+ return latents.to(device=device, dtype=dtype)
320
+
321
+ video_latents = []
322
+ print('video',video.shape)
323
+ for i in range(video.shape[0]):
324
+ # self.vae.encode is assumed to return (LatentDistribution, ...)
325
+ latent_dist = self.vae.encode(video[i : i + 1])[0]
326
+ latent = latent_dist.mode() # take the distribution mode directly, no mean/std sampling
327
+ video_latents.append(latent)
328
+ init_latents = torch.cat(video_latents, dim=0) # (B, C, T, H', W')
329
+
330
+ # replace latent frames from index condition_count onwards with random noise
331
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
332
+ init_latents[:, :, condition_count:, :, :] = noise[:, :, condition_count:, :, :]
333
+
334
+ # alternatively, noise could be added via the scheduler (kept below for reference)
335
+ # init_latents[:, :, condition_count:, :, :] = self.scheduler.add_noise(
336
+ # init_latents[:, :, condition_count:, :, :],
337
+ # noise[:, :, condition_count:, :, :],
338
+ # timestep
339
+ # )
340
+ # print('init_latents shape',init_latents.shape)
341
+ return init_latents
342
+
343
+ def prepare_video_latents_new(
344
+ self,
345
+ video: torch.Tensor,
346
+ batch_size: int = 1,
347
+ num_channels_latents: int = 16,
348
+ height: int = 480,
349
+ width: int = 832,
350
+ dtype: torch.dtype = torch.float32,
351
+ device: torch.device = None,
352
+ generator: torch.Generator = None,
353
+ condition_count: int = None,
354
+ latents: torch.Tensor = None,
355
+ timestep: torch.Tensor = None,
356
+ ):
357
+
358
+ video = video.to(device=device, dtype=dtype)
359
+
360
+ if latents is not None:
361
+ return latents.to(device=device, dtype=dtype)
362
+
363
+ video_latents = []
364
+ print('video',video.shape)
365
+ for i in range(video.shape[0]):
366
+ # self.vae.encode is assumed to return (LatentDistribution, ...)
367
+ latent_dist = self.vae.encode(video[i : i + 1])[0]
368
+ latent = latent_dist.mode() # take the distribution mode directly, no mean/std sampling
369
+ video_latents.append(latent)
370
+ org_latents = torch.cat(video_latents, dim=0) # (B, C, T, H', W')
371
+ print('org_latents',org_latents.shape)
372
+
373
+ # append random noise frames after the source latents; the noise has the same shape as org_latents
374
+ noise = randn_tensor(org_latents.shape, generator=generator, device=device, dtype=dtype)
375
+ print('noise',noise.shape)
376
+ init_latents = torch.cat([org_latents, noise], dim=2)
377
+ print('init_latents',init_latents.shape)
378
+ return init_latents
379
+
380
+
381
+ def prepare_cot_video_latents(
382
+ self,
383
+ video: torch.Tensor,
384
+ reasoning_latent_count: int = 1,
385
+ batch_size: int = 1,
386
+ num_channels_latents: int = 16,
387
+ height: int = 480,
388
+ width: int = 832,
389
+ dtype: torch.dtype = torch.float32,
390
+ device: torch.device = None,
391
+ generator: torch.Generator = None,
392
+ condition_count: int = None,
393
+ latents: torch.Tensor = None,
394
+ timestep: torch.Tensor = None,
395
+ ):
396
+
397
+ video = video.to(device=device, dtype=dtype)
398
+
399
+ if latents is not None:
400
+ return latents.to(device=device, dtype=dtype)
401
+
402
+ video_latents = []
403
+ #print('video',video.shape)
404
+ for i in range(video.shape[0]):
405
+ # self.vae.encode is assumed to return (LatentDistribution, ...)
406
+ latent_dist = self.vae.encode(video[i : i + 1])[0]
407
+ latent = latent_dist.mode() # take the distribution mode directly, no mean/std sampling
408
+ video_latents.append(latent)
409
+ org_latents = torch.cat(video_latents, dim=0) # (B, C, T, H', W')
410
+ print('org_latents',org_latents.shape)
411
+ batch_size, num_channels_latents, num_frames_latent, height_latent, width_latent = org_latents.shape
412
+ tgt_frames = num_frames_latent + reasoning_latent_count
413
+ noise_latents_shape = (batch_size, num_channels_latents, tgt_frames, height_latent, width_latent)
414
+ # append random noise frames (source length + reasoning_latent_count) after the source latents
415
+ noise = randn_tensor(noise_latents_shape, generator=generator, device=device, dtype=dtype)
416
+ print('noise',noise.shape)
417
+ init_latents = torch.cat([org_latents, noise], dim=2)
418
+ print('init_latents',init_latents.shape)
419
+ return init_latents
420
+
421
+
422
+
423
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
424
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
425
+ frames = (frames / 2 + 0.5).clamp(0, 1)
426
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
427
+ frames = frames.cpu().float().numpy()
428
+ return frames
429
+
430
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
431
+ def prepare_extra_step_kwargs(self, generator, eta):
432
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
433
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
434
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
435
+ # and should be between [0, 1]
436
+
437
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
438
+ extra_step_kwargs = {}
439
+ if accepts_eta:
440
+ extra_step_kwargs["eta"] = eta
441
+
442
+ # check if the scheduler accepts generator
443
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
444
+ if accepts_generator:
445
+ extra_step_kwargs["generator"] = generator
446
+ return extra_step_kwargs
447
+
448
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
449
+ def check_inputs(
450
+ self,
451
+ prompt,
452
+ height,
453
+ width,
454
+ negative_prompt,
455
+ callback_on_step_end_tensor_inputs,
456
+ prompt_embeds=None,
457
+ negative_prompt_embeds=None,
458
+ ):
459
+ if height % 8 != 0 or width % 8 != 0:
460
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
461
+
462
+ if callback_on_step_end_tensor_inputs is not None and not all(
463
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
464
+ ):
465
+ raise ValueError(
466
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
467
+ )
468
+ if prompt is not None and prompt_embeds is not None:
469
+ raise ValueError(
470
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
471
+ " only forward one of the two."
472
+ )
473
+ elif prompt is None and prompt_embeds is None:
474
+ raise ValueError(
475
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
476
+ )
477
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
478
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
479
+
480
+ if prompt is not None and negative_prompt_embeds is not None:
481
+ raise ValueError(
482
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
483
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
484
+ )
485
+
486
+ if negative_prompt is not None and negative_prompt_embeds is not None:
487
+ raise ValueError(
488
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
489
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
490
+ )
491
+
492
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
493
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
494
+ raise ValueError(
495
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
496
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
497
+ f" {negative_prompt_embeds.shape}."
498
+ )
499
+
500
+ @property
501
+ def guidance_scale(self):
502
+ return self._guidance_scale
503
+
504
+ @property
505
+ def num_timesteps(self):
506
+ return self._num_timesteps
507
+
508
+ @property
509
+ def attention_kwargs(self):
510
+ return self._attention_kwargs
511
+
512
+ @property
513
+ def interrupt(self):
514
+ return self._interrupt
515
+
516
+ @torch.no_grad()
517
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
518
+ def __call__(
519
+ self,
520
+ video: Union[torch.FloatTensor] = None,
521
+ prompt: Optional[Union[str, List[str]]] = None,
522
+ negative_prompt: Optional[Union[str, List[str]]] = None,
523
+ height: int = 480,
524
+ width: int = 720,
525
+ num_frames: int = 49,
526
+ source_frames: int = 33,
527
+ reasoning_frames: int = 4,
528
+ num_inference_steps: int = 50,
529
+ timesteps: Optional[List[int]] = None,
530
+ guidance_scale: float = 6,
531
+ num_videos_per_prompt: int = 1,
532
+ eta: float = 0.0,
533
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
534
+ latents: Optional[torch.FloatTensor] = None,
535
+ prompt_embeds: Optional[torch.FloatTensor] = None,
536
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
537
+ output_type: str = "numpy",
538
+ return_dict: bool = False,
539
+ callback_on_step_end: Optional[
540
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
541
+ ] = None,
542
+ attention_kwargs: Optional[Dict[str, Any]] = None,
543
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
544
+ max_sequence_length: int = 512,
545
+ comfyui_progressbar: bool = False,
546
+ shift: int = 5,
547
+ repeat_rope: bool = True,
548
+ cot: bool = False,
549
+ ) -> Union[WanPipelineOutput, Tuple]:
550
+ """
551
+ Function invoked when calling the pipeline for generation.
552
+ Args:
553
+
554
+ Examples:
555
+
556
+ Returns:
557
+
558
+ """
559
+
560
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
561
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
562
+ num_videos_per_prompt = 1
563
+
564
+ # 1. Check inputs. Raise error if not correct
565
+ self.check_inputs(
566
+ prompt,
567
+ height,
568
+ width,
569
+ negative_prompt,
570
+ callback_on_step_end_tensor_inputs,
571
+ prompt_embeds,
572
+ negative_prompt_embeds,
573
+ )
574
+ self._guidance_scale = guidance_scale
575
+ self._attention_kwargs = attention_kwargs
576
+ self._interrupt = False
577
+
578
+ # 2. Default call parameters
579
+ if prompt is not None and isinstance(prompt, str):
580
+ batch_size = 1
581
+ elif prompt is not None and isinstance(prompt, list):
582
+ batch_size = len(prompt)
583
+ else:
584
+ batch_size = prompt_embeds.shape[0]
585
+
586
+ device = self._execution_device
587
+ weight_dtype = self.text_encoder.dtype
588
+
589
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
590
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
591
+ # corresponds to doing no classifier free guidance.
592
+ do_classifier_free_guidance = guidance_scale > 1.0
593
+
594
+ # 3. Encode input prompt
595
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
596
+ prompt,
597
+ negative_prompt,
598
+ do_classifier_free_guidance,
599
+ num_videos_per_prompt=num_videos_per_prompt,
600
+ prompt_embeds=prompt_embeds,
601
+ negative_prompt_embeds=negative_prompt_embeds,
602
+ max_sequence_length=max_sequence_length,
603
+ device=device,
604
+ )
605
+ if do_classifier_free_guidance:
606
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
607
+ else:
608
+ in_prompt_embeds = prompt_embeds
609
+
610
+ # 4. Prepare timesteps
611
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
612
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
613
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
614
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
615
+ timesteps = self.scheduler.timesteps
616
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
617
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
618
+ timesteps, _ = retrieve_timesteps(
619
+ self.scheduler,
620
+ device=device,
621
+ sigmas=sampling_sigmas)
622
+ else:
623
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
624
+ self._num_timesteps = len(timesteps)
625
+ if comfyui_progressbar:
626
+ from comfy.utils import ProgressBar
627
+ pbar = ProgressBar(num_inference_steps + 1)
628
+
629
+ # compute latent source length consistent with training: (F-1)//ratio + 1, or 1 when F==1
630
+ compression_ratio = getattr(self.vae, "temporal_compression_ratio", 4)
631
+ condition_count = 1 if source_frames == 1 else (source_frames - 1) // compression_ratio + 1
632
+
633
+ # 5. Prepare latents (unified across org/repeat/cot)
634
+ latent_channels = self.transformer.config.in_channels
635
+ if cot:
636
+ # latent grounding segment length from pixel-space reasoning_frames (used only when cot=True)
637
+ ground_latent_count = 1 if reasoning_frames <= 1 else (reasoning_frames - 1) // compression_ratio + 1
638
+ print('ground_latent_count',ground_latent_count)
639
+ latents = self.prepare_cot_video_latents(
640
+ video,
641
+ ground_latent_count,
642
+ batch_size,
643
+ latent_channels,
644
+ height,
645
+ width,
646
+ weight_dtype,
647
+ device,
648
+ generator,
649
+ condition_count,
650
+ latents,
651
+ )
652
+ elif repeat_rope:
653
+ latents = self.prepare_video_latents_new(
654
+ video,
655
+ batch_size,
656
+ latent_channels,
657
+ height,
658
+ width,
659
+ weight_dtype,
660
+ device,
661
+ generator,
662
+ condition_count,
663
+ latents,
664
+ )
665
+ else:
666
+ latents = self.prepare_video_latents_new(
667
+ video,
668
+ batch_size,
669
+ latent_channels,
670
+ height,
671
+ width,
672
+ weight_dtype,
673
+ device,
674
+ generator,
675
+ condition_count,
676
+ latents,
677
+ )
678
+ if comfyui_progressbar:
679
+ pbar.update(1)
680
+
681
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
682
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
683
+
684
+ # Get actual latent dimensions (consistent with training)
685
+ #print('latents',latents.shape)
686
+ bsz, channel, actual_num_frames, actual_height, actual_width = latents.size()
687
+ target_shape = (self.vae.latent_channels, actual_num_frames, actual_height, actual_width)
688
+ #print('target_shape',target_shape)
689
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
690
+ # 7. Denoising loop
691
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
692
+ self.transformer.num_inference_steps = num_inference_steps
693
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
694
+ for i, t in enumerate(timesteps):
695
+ self.transformer.current_steps = i
696
+
697
+ if self.interrupt:
698
+ continue
699
+
700
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
701
+ if hasattr(self.scheduler, "scale_model_input"):
702
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
703
+
704
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
705
+ timestep = t.expand(latent_model_input.shape[0])
706
+
707
+ # predict noise model_output
708
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
709
+ # frame_split_indices enables repeat temporal RoPE for paired (src+tgt) inputs
710
+ frame_split_indices = None
711
+ ground_frame_indices = None
712
+ if repeat_rope and video is not None:
713
+ frame_split_indices = [condition_count] * latent_model_input.shape[0]
714
+ if cot:
715
+ # grounding frames should use temporal RoPE position 0
716
+ ground_frame_indices = [
717
+ (condition_count, condition_count + ground_latent_count)
718
+ ] * latent_model_input.shape[0]
719
+ # print('ground_frame_indices',ground_frame_indices)
720
+ # print('frame_split_indices',frame_split_indices)
721
+ noise_pred = self.transformer(
722
+ x=latent_model_input,
723
+ context=in_prompt_embeds,
724
+ t=timestep,
725
+ seq_len=seq_len,
726
+ frame_split_indices=frame_split_indices,
727
+ ground_frame_indices=ground_frame_indices,
728
+ )
729
+
730
+ # perform guidance
731
+ if do_classifier_free_guidance:
732
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
733
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
734
+
735
+ ######source video no noise pred################
736
+ noise_pred[:, :, :condition_count] = 0
737
+ ######source video no noise pred################
738
+
739
+ # compute the previous noisy sample x_t -> x_t-1
740
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
741
+
742
+ if callback_on_step_end is not None:
743
+ callback_kwargs = {}
744
+ for k in callback_on_step_end_tensor_inputs:
745
+ callback_kwargs[k] = locals()[k]
746
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
747
+
748
+ latents = callback_outputs.pop("latents", latents)
749
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
750
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
751
+
752
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
753
+ progress_bar.update()
754
+ if comfyui_progressbar:
755
+ pbar.update(1)
756
+
757
+ # Optionally decode outputs. For cot=True, segment into src/ground/edit; otherwise decode whole latents
758
+ ground_video = None
759
+ edit_video = None
760
+ if cot:
761
+ if output_type == "numpy":
762
+ ground_start = condition_count
763
+ ground_end = condition_count + ground_latent_count
764
+ src_lat = latents[:, :, :ground_start] if ground_start > 0 else None
765
+ ground_lat = latents[:, :, ground_start:ground_end] if ground_end > ground_start and ground_start < latents.shape[2] else None
766
+ edit_lat = latents[:, :, ground_end:] if ground_end < latents.shape[2] else None
767
+
768
+ parts = []
769
+ ## only ground and edit
770
+ if ground_lat is not None and ground_lat.shape[2] > 0:
771
+ ground_video = self.decode_latents(ground_lat)
772
+ parts.append(ground_video)
773
+ if edit_lat is not None and edit_lat.shape[2] > 0:
774
+ edit_video = self.decode_latents(edit_lat)
775
+ parts.append(edit_video)
776
+ print('ground_video', None if ground_video is None else ground_video.shape, 'edit_video', None if edit_video is None else edit_video.shape)
777
+ video = np.concatenate(parts, axis=2)
778
+ else:
779
+ # org/repeat: split by condition_count -> src + edit, then temporal concat
780
+ if output_type == "numpy":
781
+ src_lat = latents[:, :, :condition_count] if condition_count > 0 else None
782
+ edit_lat = latents[:, :, condition_count:] if condition_count < latents.shape[2] else None
783
+ ## only decode edit video
784
+ if edit_lat is not None and edit_lat.shape[2] > 0:
785
+ edit_video = self.decode_latents(edit_lat)
786
+ video = edit_video
787
+
788
+ # Offload all models
789
+ self.maybe_free_model_hooks()
790
+
791
+ if not return_dict:
792
+ if isinstance(video, np.ndarray):
793
+ video = torch.from_numpy(video)
794
+ if ground_video is not None and isinstance(ground_video, np.ndarray):
795
+ ground_video = torch.from_numpy(ground_video)
796
+ if edit_video is not None and isinstance(edit_video, np.ndarray):
797
+ edit_video = torch.from_numpy(edit_video)
798
+
799
+ return WanPipelineOutput(videos=video, ground_videos=ground_video, edit_videos=edit_video)
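Note (not part of this commit): the pipeline above splits the latent sequence at `condition_count`, computed from the pixel-space `source_frames` and the VAE's temporal compression ratio; only frames after that index are denoised (`noise_pred[:, :, :condition_count] = 0`). A tiny standalone check of that arithmetic, assuming the default compression ratio of 4:

```python
def latent_condition_count(source_frames: int, compression_ratio: int = 4) -> int:
    # mirrors: 1 if source_frames == 1 else (source_frames - 1) // compression_ratio + 1
    return 1 if source_frames == 1 else (source_frames - 1) // compression_ratio + 1

for frames in (1, 33, 49):
    print(frames, "->", latent_condition_count(frames))
# 1 -> 1, 33 -> 9, 49 -> 13 latent condition frames
```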
videox_fun/pipeline/pipeline_wan2_2.py ADDED
@@ -0,0 +1,591 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import FlowMatchEulerDiscreteScheduler
9
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
10
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+ from diffusers.video_processor import VideoProcessor
14
+
15
+ from ..models import (AutoencoderKLWan, AutoTokenizer,
16
+ WanT5EncoderModel, Wan2_2Transformer3DModel)
17
+ from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
18
+ get_sampling_sigmas)
19
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
20
+
21
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
+
23
+
24
+ EXAMPLE_DOC_STRING = """
25
+ Examples:
26
+ ```python
27
+ pass
28
+ ```
29
+ """
30
+
31
+
32
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
33
+ def retrieve_timesteps(
34
+ scheduler,
35
+ num_inference_steps: Optional[int] = None,
36
+ device: Optional[Union[str, torch.device]] = None,
37
+ timesteps: Optional[List[int]] = None,
38
+ sigmas: Optional[List[float]] = None,
39
+ **kwargs,
40
+ ):
41
+ """
42
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
43
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
44
+
45
+ Args:
46
+ scheduler (`SchedulerMixin`):
47
+ The scheduler to get timesteps from.
48
+ num_inference_steps (`int`):
49
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
50
+ must be `None`.
51
+ device (`str` or `torch.device`, *optional*):
52
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
53
+ timesteps (`List[int]`, *optional*):
54
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
55
+ `num_inference_steps` and `sigmas` must be `None`.
56
+ sigmas (`List[float]`, *optional*):
57
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
58
+ `num_inference_steps` and `timesteps` must be `None`.
59
+
60
+ Returns:
61
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
62
+ second element is the number of inference steps.
63
+ """
64
+ if timesteps is not None and sigmas is not None:
65
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
66
+ if timesteps is not None:
67
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
68
+ if not accepts_timesteps:
69
+ raise ValueError(
70
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
71
+ f" timestep schedules. Please check whether you are using the correct scheduler."
72
+ )
73
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
74
+ timesteps = scheduler.timesteps
75
+ num_inference_steps = len(timesteps)
76
+ elif sigmas is not None:
77
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
78
+ if not accept_sigmas:
79
+ raise ValueError(
80
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
81
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
82
+ )
83
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
84
+ timesteps = scheduler.timesteps
85
+ num_inference_steps = len(timesteps)
86
+ else:
87
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
88
+ timesteps = scheduler.timesteps
89
+ return timesteps, num_inference_steps
90
+
91
+
92
+ @dataclass
93
+ class WanPipelineOutput(BaseOutput):
94
+ r"""
95
+ Output class for Wan pipelines.
96
+
97
+ Args:
98
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
99
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
100
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
101
+ `(batch_size, num_frames, channels, height, width)`.
102
+ """
103
+
104
+ videos: torch.Tensor
105
+
106
+
107
+ class Wan2_2Pipeline(DiffusionPipeline):
108
+ r"""
109
+ Pipeline for text-to-video generation using Wan.
110
+
111
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
112
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
113
+ """
114
+
115
+ _optional_components = ["transformer_2"]
116
+ model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae"
117
+
118
+ _callback_tensor_inputs = [
119
+ "latents",
120
+ "prompt_embeds",
121
+ "negative_prompt_embeds",
122
+ ]
123
+
124
+ def __init__(
125
+ self,
126
+ tokenizer: AutoTokenizer,
127
+ text_encoder: WanT5EncoderModel,
128
+ vae: AutoencoderKLWan,
129
+ transformer: Wan2_2Transformer3DModel,
130
+ transformer_2: Wan2_2Transformer3DModel = None,
131
+ scheduler: FlowMatchEulerDiscreteScheduler = None,
132
+ ):
133
+ super().__init__()
134
+
135
+ self.register_modules(
136
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
137
+ transformer_2=transformer_2, scheduler=scheduler
138
+ )
139
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)
140
+
141
+ def _get_t5_prompt_embeds(
142
+ self,
143
+ prompt: Union[str, List[str]] = None,
144
+ num_videos_per_prompt: int = 1,
145
+ max_sequence_length: int = 512,
146
+ device: Optional[torch.device] = None,
147
+ dtype: Optional[torch.dtype] = None,
148
+ ):
149
+ device = device or self._execution_device
150
+ dtype = dtype or self.text_encoder.dtype
151
+
152
+ prompt = [prompt] if isinstance(prompt, str) else prompt
153
+ batch_size = len(prompt)
154
+
155
+ text_inputs = self.tokenizer(
156
+ prompt,
157
+ padding="max_length",
158
+ max_length=max_sequence_length,
159
+ truncation=True,
160
+ add_special_tokens=True,
161
+ return_tensors="pt",
162
+ )
163
+ text_input_ids = text_inputs.input_ids
164
+ prompt_attention_mask = text_inputs.attention_mask
165
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
166
+
167
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
168
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
169
+ logger.warning(
170
+ "The following part of your input was truncated because `max_sequence_length` is set to "
171
+ f" {max_sequence_length} tokens: {removed_text}"
172
+ )
173
+
174
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
175
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
176
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
177
+
178
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
179
+ _, seq_len, _ = prompt_embeds.shape
180
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
181
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
182
+
183
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
184
+
185
+ def encode_prompt(
186
+ self,
187
+ prompt: Union[str, List[str]],
188
+ negative_prompt: Optional[Union[str, List[str]]] = None,
189
+ do_classifier_free_guidance: bool = True,
190
+ num_videos_per_prompt: int = 1,
191
+ prompt_embeds: Optional[torch.Tensor] = None,
192
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
193
+ max_sequence_length: int = 512,
194
+ device: Optional[torch.device] = None,
195
+ dtype: Optional[torch.dtype] = None,
196
+ ):
197
+ r"""
198
+ Encodes the prompt into text encoder hidden states.
199
+
200
+ Args:
201
+ prompt (`str` or `List[str]`, *optional*):
202
+ prompt to be encoded
203
+ negative_prompt (`str` or `List[str]`, *optional*):
204
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
205
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
206
+ less than `1`).
207
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
208
+ Whether to use classifier free guidance or not.
209
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
210
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
211
+ prompt_embeds (`torch.Tensor`, *optional*):
212
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
213
+ provided, text embeddings will be generated from `prompt` input argument.
214
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
215
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
216
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
217
+ argument.
218
+ device: (`torch.device`, *optional*):
219
+ torch device
220
+ dtype: (`torch.dtype`, *optional*):
221
+ torch dtype
222
+ """
223
+ device = device or self._execution_device
224
+
225
+ prompt = [prompt] if isinstance(prompt, str) else prompt
226
+ if prompt is not None:
227
+ batch_size = len(prompt)
228
+ else:
229
+ batch_size = prompt_embeds.shape[0]
230
+
231
+ if prompt_embeds is None:
232
+ prompt_embeds = self._get_t5_prompt_embeds(
233
+ prompt=prompt,
234
+ num_videos_per_prompt=num_videos_per_prompt,
235
+ max_sequence_length=max_sequence_length,
236
+ device=device,
237
+ dtype=dtype,
238
+ )
239
+
240
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
241
+ negative_prompt = negative_prompt or ""
242
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
243
+
244
+ if prompt is not None and type(prompt) is not type(negative_prompt):
245
+ raise TypeError(
246
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
247
+ f" {type(prompt)}."
248
+ )
249
+ elif batch_size != len(negative_prompt):
250
+ raise ValueError(
251
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
252
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
253
+ " the batch size of `prompt`."
254
+ )
255
+
256
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
257
+ prompt=negative_prompt,
258
+ num_videos_per_prompt=num_videos_per_prompt,
259
+ max_sequence_length=max_sequence_length,
260
+ device=device,
261
+ dtype=dtype,
262
+ )
263
+
264
+ return prompt_embeds, negative_prompt_embeds
265
+
266
+ def prepare_latents(
267
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
268
+ ):
269
+ if isinstance(generator, list) and len(generator) != batch_size:
270
+ raise ValueError(
271
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
272
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
273
+ )
274
+
275
+ shape = (
276
+ batch_size,
277
+ num_channels_latents,
278
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
279
+ height // self.vae.spatial_compression_ratio,
280
+ width // self.vae.spatial_compression_ratio,
281
+ )
282
+
283
+ if latents is None:
284
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
285
+ else:
286
+ latents = latents.to(device)
287
+
288
+ # scale the initial noise by the standard deviation required by the scheduler
289
+ if hasattr(self.scheduler, "init_noise_sigma"):
290
+ latents = latents * self.scheduler.init_noise_sigma
291
+ return latents
292
+
293
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
294
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
295
+ frames = (frames / 2 + 0.5).clamp(0, 1)
296
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
297
+ frames = frames.cpu().float().numpy()
298
+ return frames
299
+
300
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
301
+ def prepare_extra_step_kwargs(self, generator, eta):
302
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
303
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
304
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
305
+ # and should be between [0, 1]
306
+
307
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
308
+ extra_step_kwargs = {}
309
+ if accepts_eta:
310
+ extra_step_kwargs["eta"] = eta
311
+
312
+ # check if the scheduler accepts generator
313
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
314
+ if accepts_generator:
315
+ extra_step_kwargs["generator"] = generator
316
+ return extra_step_kwargs
317
+
318
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
319
+ def check_inputs(
320
+ self,
321
+ prompt,
322
+ height,
323
+ width,
324
+ negative_prompt,
325
+ callback_on_step_end_tensor_inputs,
326
+ prompt_embeds=None,
327
+ negative_prompt_embeds=None,
328
+ ):
329
+ if height % 8 != 0 or width % 8 != 0:
330
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
331
+
332
+ if callback_on_step_end_tensor_inputs is not None and not all(
333
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
334
+ ):
335
+ raise ValueError(
336
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
337
+ )
338
+ if prompt is not None and prompt_embeds is not None:
339
+ raise ValueError(
340
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
341
+ " only forward one of the two."
342
+ )
343
+ elif prompt is None and prompt_embeds is None:
344
+ raise ValueError(
345
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
346
+ )
347
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
348
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
349
+
350
+ if prompt is not None and negative_prompt_embeds is not None:
351
+ raise ValueError(
352
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
353
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
354
+ )
355
+
356
+ if negative_prompt is not None and negative_prompt_embeds is not None:
357
+ raise ValueError(
358
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
359
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
360
+ )
361
+
362
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
363
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
364
+ raise ValueError(
365
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
366
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
367
+ f" {negative_prompt_embeds.shape}."
368
+ )
369
+
370
+ @property
371
+ def guidance_scale(self):
372
+ return self._guidance_scale
373
+
374
+ @property
375
+ def num_timesteps(self):
376
+ return self._num_timesteps
377
+
378
+ @property
379
+ def attention_kwargs(self):
380
+ return self._attention_kwargs
381
+
382
+ @property
383
+ def interrupt(self):
384
+ return self._interrupt
385
+
386
+ @torch.no_grad()
387
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
388
+ def __call__(
389
+ self,
390
+ prompt: Optional[Union[str, List[str]]] = None,
391
+ negative_prompt: Optional[Union[str, List[str]]] = None,
392
+ height: int = 480,
393
+ width: int = 720,
394
+ num_frames: int = 49,
395
+ num_inference_steps: int = 50,
396
+ timesteps: Optional[List[int]] = None,
397
+ guidance_scale: float = 6,
398
+ num_videos_per_prompt: int = 1,
399
+ eta: float = 0.0,
400
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
401
+ latents: Optional[torch.FloatTensor] = None,
402
+ prompt_embeds: Optional[torch.FloatTensor] = None,
403
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
404
+ output_type: str = "numpy",
405
+ return_dict: bool = False,
406
+ callback_on_step_end: Optional[
407
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
408
+ ] = None,
409
+ attention_kwargs: Optional[Dict[str, Any]] = None,
410
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
411
+ max_sequence_length: int = 512,
412
+ boundary: float = 0.875,
413
+ comfyui_progressbar: bool = False,
414
+ shift: int = 5,
415
+ ) -> Union[WanPipelineOutput, Tuple]:
416
+ """
417
+ Function invoked when calling the pipeline for generation.
418
+ Args:
419
+
420
+ Examples:
421
+
422
+ Returns:
423
+
424
+ """
425
+
426
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
427
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
428
+ num_videos_per_prompt = 1
429
+
430
+ # 1. Check inputs. Raise error if not correct
431
+ self.check_inputs(
432
+ prompt,
433
+ height,
434
+ width,
435
+ negative_prompt,
436
+ callback_on_step_end_tensor_inputs,
437
+ prompt_embeds,
438
+ negative_prompt_embeds,
439
+ )
440
+ self._guidance_scale = guidance_scale
441
+ self._attention_kwargs = attention_kwargs
442
+ self._interrupt = False
443
+
444
+ # 2. Default call parameters
445
+ if prompt is not None and isinstance(prompt, str):
446
+ batch_size = 1
447
+ elif prompt is not None and isinstance(prompt, list):
448
+ batch_size = len(prompt)
449
+ else:
450
+ batch_size = prompt_embeds.shape[0]
451
+
452
+ device = self._execution_device
453
+ weight_dtype = self.text_encoder.dtype
454
+
455
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
456
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
457
+ # corresponds to doing no classifier free guidance.
458
+ do_classifier_free_guidance = guidance_scale > 1.0
459
+
460
+ # 3. Encode input prompt
461
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
462
+ prompt,
463
+ negative_prompt,
464
+ do_classifier_free_guidance,
465
+ num_videos_per_prompt=num_videos_per_prompt,
466
+ prompt_embeds=prompt_embeds,
467
+ negative_prompt_embeds=negative_prompt_embeds,
468
+ max_sequence_length=max_sequence_length,
469
+ device=device,
470
+ )
471
+ if do_classifier_free_guidance:
472
+ in_prompt_embeds = negative_prompt_embeds + prompt_embeds
473
+ else:
474
+ in_prompt_embeds = prompt_embeds
475
+
476
+ # 4. Prepare timesteps
477
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
478
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
479
+ elif isinstance(self.scheduler, FlowUniPCMultistepScheduler):
480
+ self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift)
481
+ timesteps = self.scheduler.timesteps
482
+ elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler):
483
+ sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift)
484
+ timesteps, _ = retrieve_timesteps(
485
+ self.scheduler,
486
+ device=device,
487
+ sigmas=sampling_sigmas)
488
+ else:
489
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
490
+ self._num_timesteps = len(timesteps)
491
+ if comfyui_progressbar:
492
+ from comfy.utils import ProgressBar
493
+ pbar = ProgressBar(num_inference_steps + 1)
494
+
495
+ # 5. Prepare latents
496
+ latent_channels = self.transformer.config.in_channels
497
+ latents = self.prepare_latents(
498
+ batch_size * num_videos_per_prompt,
499
+ latent_channels,
500
+ num_frames,
501
+ height,
502
+ width,
503
+ weight_dtype,
504
+ device,
505
+ generator,
506
+ latents,
507
+ )
508
+ if comfyui_progressbar:
509
+ pbar.update(1)
510
+
511
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
512
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
513
+
514
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio)
515
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
516
+ # 7. Denoising loop
517
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
518
+ self.transformer.num_inference_steps = num_inference_steps
519
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
520
+ for i, t in enumerate(timesteps):
521
+ self.transformer.current_steps = i
522
+
523
+ if self.interrupt:
524
+ continue
525
+
526
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
527
+ if hasattr(self.scheduler, "scale_model_input"):
528
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
529
+
530
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
531
+ timestep = t.expand(latent_model_input.shape[0])
532
+
533
+ if self.transformer_2 is not None:
534
+ if t >= boundary * self.scheduler.config.num_train_timesteps:
535
+ local_transformer = self.transformer_2
536
+ else:
537
+ local_transformer = self.transformer
538
+ else:
539
+ local_transformer = self.transformer
540
+
541
+ # predict noise model_output
542
+ with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device):
543
+ noise_pred = local_transformer(
544
+ x=latent_model_input,
545
+ context=in_prompt_embeds,
546
+ t=timestep,
547
+ seq_len=seq_len,
548
+ )
549
+
550
+ # perform guidance
551
+ if do_classifier_free_guidance:
552
+ if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))):
553
+ sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0]
554
+ else:
555
+ sample_guide_scale = self.guidance_scale
556
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
557
+ noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond)
558
+
559
+ # compute the previous noisy sample x_t -> x_t-1
560
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
561
+
562
+ if callback_on_step_end is not None:
563
+ callback_kwargs = {}
564
+ for k in callback_on_step_end_tensor_inputs:
565
+ callback_kwargs[k] = locals()[k]
566
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
567
+
568
+ latents = callback_outputs.pop("latents", latents)
569
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
570
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
571
+
572
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
573
+ progress_bar.update()
574
+ if comfyui_progressbar:
575
+ pbar.update(1)
576
+
577
+ if output_type == "numpy":
578
+ video = self.decode_latents(latents)
579
+ elif not output_type == "latent":
580
+ video = self.decode_latents(latents)
581
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
582
+ else:
583
+ video = latents
584
+
585
+ # Offload all models
586
+ self.maybe_free_model_hooks()
587
+
588
+ if not return_dict:
589
+ video = torch.from_numpy(video)
590
+
591
+ return WanPipelineOutput(videos=video)
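
The `__call__` method above ties the pieces together: prompt encoding, scheduler-specific timestep setup, the optional dual-transformer switch controlled by `boundary`, classifier-free guidance, and latent decoding. A minimal invocation sketch follows; it assumes `pipe` is an already-constructed instance of this pipeline (building it from its VAE, text encoder, transformer(s), and scheduler is handled elsewhere, e.g. by the UI controller later in this commit), and the prompt, resolution, and seed are illustrative only.

```python
# Hedged sketch: calling the pipeline defined above. `pipe` is assumed to be an
# already-constructed instance with its components loaded; values are illustrative.
import torch

output = pipe(
    prompt="A corgi running along a beach at sunset",
    negative_prompt="blurry, low quality",
    height=480,
    width=720,
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
    boundary=0.875,   # steps with t >= boundary * num_train_timesteps use transformer_2
    shift=5,          # sigma shift used by the flow-matching schedulers
    generator=torch.Generator("cuda").manual_seed(42),
    output_type="numpy",
)
video = output.videos  # WanPipelineOutput; decoded frames clamped to [0, 1]
```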
videox_fun/ui/cogvideox_fun_ui.py ADDED
@@ -0,0 +1,722 @@
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import os
4
+ import random
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+ from safetensors import safe_open
12
+
13
+ from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
14
+ from ..models import (AutoencoderKLCogVideoX, CogVideoXTransformer3DModel,
15
+ T5EncoderModel, T5Tokenizer)
16
+ from ..pipeline import (CogVideoXFunControlPipeline,
17
+ CogVideoXFunInpaintPipeline, CogVideoXFunPipeline)
18
+ from ..utils.fp8_optimization import (convert_model_weight_to_float8, replace_parameters_by_name,
19
+ convert_weight_dtype_wrapper)
20
+ from ..utils.lora_utils import merge_lora, unmerge_lora
21
+ from ..utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, timer,
22
+ get_video_to_video_latent, save_videos_grid)
23
+ from .controller import (Fun_Controller, Fun_Controller_Client,
24
+ all_cheduler_dict, css, ddpm_scheduler_dict,
25
+ flow_scheduler_dict, gradio_version,
26
+ gradio_version_is_above_4)
27
+ from .ui import (create_cfg_and_seedbox,
28
+ create_fake_finetune_models_checkpoints,
29
+ create_fake_height_width, create_fake_model_checkpoints,
30
+ create_fake_model_type, create_finetune_models_checkpoints,
31
+ create_generation_method,
32
+ create_generation_methods_and_video_length,
33
+ create_height_width, create_model_checkpoints,
34
+ create_model_type, create_prompts, create_samplers,
35
+ create_ui_outputs)
36
+ from ..dist import set_multi_gpus_devices, shard_model
37
+
38
+
39
+ class CogVideoXFunController(Fun_Controller):
40
+ def update_diffusion_transformer(self, diffusion_transformer_dropdown):
41
+ print(f"Update diffusion transformer: {diffusion_transformer_dropdown}")
42
+ self.diffusion_transformer_dropdown = diffusion_transformer_dropdown
43
+ if diffusion_transformer_dropdown == "none":
44
+ return gr.update()
45
+ self.vae = AutoencoderKLCogVideoX.from_pretrained(
46
+ diffusion_transformer_dropdown,
47
+ subfolder="vae",
48
+ ).to(self.weight_dtype)
49
+
50
+ # Get Transformer
51
+ self.transformer = CogVideoXTransformer3DModel.from_pretrained(
52
+ diffusion_transformer_dropdown,
53
+ subfolder="transformer",
54
+ low_cpu_mem_usage=True,
55
+ ).to(self.weight_dtype)
56
+
57
+ # Get tokenizer and text_encoder
58
+ tokenizer = T5Tokenizer.from_pretrained(
59
+ diffusion_transformer_dropdown, subfolder="tokenizer"
60
+ )
61
+ text_encoder = T5EncoderModel.from_pretrained(
62
+ diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
63
+ )
64
+
65
+ # Get pipeline
66
+ if self.model_type == "Inpaint":
67
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
68
+ self.pipeline = CogVideoXFunInpaintPipeline(
69
+ tokenizer=tokenizer,
70
+ text_encoder=text_encoder,
71
+ vae=self.vae,
72
+ transformer=self.transformer,
73
+ scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
74
+ )
75
+ else:
76
+ self.pipeline = CogVideoXFunPipeline(
77
+ tokenizer=tokenizer,
78
+ text_encoder=text_encoder,
79
+ vae=self.vae,
80
+ transformer=self.transformer,
81
+ scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
82
+ )
83
+ else:
84
+ self.pipeline = CogVideoXFunControlPipeline(
85
+ diffusion_transformer_dropdown,
86
+ vae=self.vae,
87
+ transformer=self.transformer,
88
+ scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
89
+ torch_dtype=self.weight_dtype
90
+ )
91
+
92
+ if self.ulysses_degree > 1 or self.ring_degree > 1:
93
+ from functools import partial
94
+ self.transformer.enable_multi_gpus_inference()
95
+ if self.fsdp_dit:
96
+ shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype)
97
+ self.pipeline.transformer = shard_fn(self.pipeline.transformer)
98
+ print("Add FSDP DIT")
99
+ if self.fsdp_text_encoder:
100
+ shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype)
101
+ self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder)
102
+ print("Add FSDP TEXT ENCODER")
103
+
104
+ if self.compile_dit:
105
+ for i in range(len(self.pipeline.transformer.transformer_blocks)):
106
+ self.pipeline.transformer.transformer_blocks[i] = torch.compile(self.pipeline.transformer.transformer_blocks[i])
107
+ print("Add Compile")
108
+
109
+ if self.GPU_memory_mode == "sequential_cpu_offload":
110
+ self.pipeline.enable_sequential_cpu_offload(device=self.device)
111
+ elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
112
+ convert_model_weight_to_float8(self.pipeline.transformer, exclude_module_name=[], device=self.device)
113
+ convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
114
+ self.pipeline.enable_model_cpu_offload(device=self.device)
115
+ elif self.GPU_memory_mode == "model_cpu_offload":
116
+ self.pipeline.enable_model_cpu_offload(device=self.device)
117
+ elif self.GPU_memory_mode == "model_full_load_and_qfloat8":
118
+ convert_model_weight_to_float8(self.pipeline.transformer, exclude_module_name=[], device=self.device)
119
+ convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
120
+ self.pipeline.to(self.device)
121
+ else:
122
+ self.pipeline.to(self.device)
123
+ print("Update diffusion transformer done")
124
+ return gr.update()
125
+
126
+ @timer
127
+ def generate(
128
+ self,
129
+ diffusion_transformer_dropdown,
130
+ base_model_dropdown,
131
+ lora_model_dropdown,
132
+ lora_alpha_slider,
133
+ prompt_textbox,
134
+ negative_prompt_textbox,
135
+ sampler_dropdown,
136
+ sample_step_slider,
137
+ resize_method,
138
+ width_slider,
139
+ height_slider,
140
+ base_resolution,
141
+ generation_method,
142
+ length_slider,
143
+ overlap_video_length,
144
+ partial_video_length,
145
+ cfg_scale_slider,
146
+ start_image,
147
+ end_image,
148
+ validation_video,
149
+ validation_video_mask,
150
+ control_video,
151
+ denoise_strength,
152
+ seed_textbox,
153
+ ref_image = None,
154
+ enable_teacache = None,
155
+ teacache_threshold = None,
156
+ num_skip_start_steps = None,
157
+ teacache_offload = None,
158
+ cfg_skip_ratio = None,
159
+ enable_riflex = None,
160
+ riflex_k = None,
161
+ base_model_2_dropdown=None,
162
+ lora_model_2_dropdown=None,
163
+ fps = None,
164
+ is_api = False,
165
+ ):
166
+ self.clear_cache()
167
+
168
+ print(f"Input checking.")
169
+ _, comment = self.input_check(
170
+ resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api
171
+ )
172
+ print(f"Input checking down")
173
+ if comment != "OK":
174
+ return "", comment
175
+ is_image = True if generation_method == "Image Generation" else False
176
+
177
+ if self.base_model_path != base_model_dropdown:
178
+ self.update_base_model(base_model_dropdown)
179
+
180
+ if self.lora_model_path != lora_model_dropdown:
181
+ self.update_lora_model(lora_model_dropdown)
182
+
183
+ print(f"Load scheduler.")
184
+ self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
185
+ print(f"Load scheduler down.")
186
+
187
+ if resize_method == "Resize according to Reference":
188
+ print(f"Calculate height and width according to Reference.")
189
+ height_slider, width_slider = self.get_height_width_from_reference(
190
+ base_resolution, start_image, validation_video, control_video,
191
+ )
192
+
193
+ if self.lora_model_path != "none":
194
+ print(f"Merge Lora.")
195
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
196
+ print(f"Merge Lora done.")
197
+
198
+ if fps is None:
199
+ fps = 8
200
+
201
+ print(f"Generate seed.")
202
+ if seed_textbox != "" and int(seed_textbox) != -1: torch.manual_seed(int(seed_textbox))
203
+ else: seed_textbox = np.random.randint(0, 1e10)
204
+ generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))
205
+ print(f"Generate seed done.")
206
+
207
+ try:
208
+ print(f"Generation.")
209
+ if self.model_type == "Inpaint":
210
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
211
+ if generation_method == "Long Video Generation":
212
+ if validation_video is not None:
213
+ raise gr.Error(f"Video to Video is not Support Long Video Generation now.")
214
+ init_frames = 0
215
+ last_frames = init_frames + partial_video_length
216
+ while init_frames < length_slider:
217
+ if last_frames >= length_slider:
218
+ _partial_video_length = length_slider - init_frames
219
+ _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
220
+
221
+ if _partial_video_length <= 0:
222
+ break
223
+ else:
224
+ _partial_video_length = partial_video_length
225
+
226
+ if last_frames >= length_slider:
227
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
228
+ else:
229
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
230
+
231
+ with torch.no_grad():
232
+ sample = self.pipeline(
233
+ prompt_textbox,
234
+ negative_prompt = negative_prompt_textbox,
235
+ num_inference_steps = sample_step_slider,
236
+ guidance_scale = cfg_scale_slider,
237
+ width = width_slider,
238
+ height = height_slider,
239
+ num_frames = _partial_video_length,
240
+ generator = generator,
241
+
242
+ video = input_video,
243
+ mask_video = input_video_mask,
244
+ strength = 1,
245
+ ).videos
246
+
247
+ if init_frames != 0:
248
+ mix_ratio = torch.from_numpy(
249
+ np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
250
+ ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
251
+
252
+ new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
253
+ sample[:, :, :overlap_video_length] * mix_ratio
254
+ new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
255
+
256
+ sample = new_sample
257
+ else:
258
+ new_sample = sample
259
+
260
+ if last_frames >= length_slider:
261
+ break
262
+
263
+ start_image = [
264
+ Image.fromarray(
265
+ (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
266
+ ) for _index in range(-overlap_video_length, 0)
267
+ ]
268
+
269
+ init_frames = init_frames + _partial_video_length - overlap_video_length
270
+ last_frames = init_frames + _partial_video_length
271
+ else:
272
+ if validation_video is not None:
273
+ input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps)
274
+ strength = denoise_strength
275
+ else:
276
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
277
+ strength = 1
278
+
279
+ sample = self.pipeline(
280
+ prompt_textbox,
281
+ negative_prompt = negative_prompt_textbox,
282
+ num_inference_steps = sample_step_slider,
283
+ guidance_scale = cfg_scale_slider,
284
+ width = width_slider,
285
+ height = height_slider,
286
+ num_frames = length_slider if not is_image else 1,
287
+ generator = generator,
288
+
289
+ video = input_video,
290
+ mask_video = input_video_mask,
291
+ strength = strength,
292
+ ).videos
293
+ else:
294
+ sample = self.pipeline(
295
+ prompt_textbox,
296
+ negative_prompt = negative_prompt_textbox,
297
+ num_inference_steps = sample_step_slider,
298
+ guidance_scale = cfg_scale_slider,
299
+ width = width_slider,
300
+ height = height_slider,
301
+ num_frames = length_slider if not is_image else 1,
302
+ generator = generator
303
+ ).videos
304
+ else:
305
+ input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps)
306
+
307
+ sample = self.pipeline(
308
+ prompt_textbox,
309
+ negative_prompt = negative_prompt_textbox,
310
+ num_inference_steps = sample_step_slider,
311
+ guidance_scale = cfg_scale_slider,
312
+ width = width_slider,
313
+ height = height_slider,
314
+ num_frames = length_slider if not is_image else 1,
315
+ generator = generator,
316
+
317
+ control_video = input_video,
318
+ ).videos
319
+ except Exception as e:
320
+ self.auto_model_clear_cache(self.pipeline.transformer)
321
+ self.auto_model_clear_cache(self.pipeline.text_encoder)
322
+ self.auto_model_clear_cache(self.pipeline.vae)
323
+ self.clear_cache()
324
+
325
+ print(f"Error. error information is {str(e)}")
326
+ if self.lora_model_path != "none":
327
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
328
+ if is_api:
329
+ return "", f"Error. error information is {str(e)}"
330
+ else:
331
+ return gr.update(), gr.update(), f"Error. error information is {str(e)}"
332
+
333
+ self.clear_cache()
334
+ # lora part
335
+ if self.lora_model_path != "none":
336
+ print(f"Unmerge Lora.")
337
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
338
+ print(f"Unmerge Lora done.")
339
+
340
+ print(f"Saving outputs.")
341
+ save_sample_path = self.save_outputs(
342
+ is_image, length_slider, sample, fps=fps
343
+ )
344
+ print(f"Saving outputs done.")
345
+
346
+ if is_image or length_slider == 1:
347
+ if is_api:
348
+ return save_sample_path, "Success"
349
+ else:
350
+ if gradio_version_is_above_4:
351
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
352
+ else:
353
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
354
+ else:
355
+ if is_api:
356
+ return save_sample_path, "Success"
357
+ else:
358
+ if gradio_version_is_above_4:
359
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
360
+ else:
361
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
362
+
363
+ CogVideoXFunController_Host = CogVideoXFunController
364
+ CogVideoXFunController_Client = Fun_Controller_Client
365
+
366
+ def ui(GPU_memory_mode, scheduler_dict, compile_dit, weight_dtype, savedir_sample=None):
367
+ controller = CogVideoXFunController(
368
+ GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
369
+ compile_dit=compile_dit,
370
+ weight_dtype=weight_dtype, savedir_sample=savedir_sample,
371
+ )
372
+
373
+ with gr.Blocks(css=css) as demo:
374
+ gr.Markdown(
375
+ """
376
+ # CogVideoX-Fun:
377
+
378
+ A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.
379
+
380
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
381
+ """
382
+ )
383
+ with gr.Column(variant="panel"):
384
+ model_type = create_model_type(visible=True)
385
+ diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \
386
+ create_model_checkpoints(controller, visible=True)
387
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \
388
+ create_finetune_models_checkpoints(controller, visible=True)
389
+
390
+ with gr.Column(variant="panel"):
391
+ prompt_textbox, negative_prompt_textbox = create_prompts()
392
+
393
+ with gr.Row():
394
+ with gr.Column():
395
+ sampler_dropdown, sample_step_slider = create_samplers(controller)
396
+
397
+ resize_method, width_slider, height_slider, base_resolution = create_height_width(
398
+ default_height = 384, default_width = 672, maximum_height = 1344,
399
+ maximum_width = 1344,
400
+ )
401
+ gr.Markdown(
402
+ """
403
+ V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
404
+ (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
405
+ """
406
+ )
407
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
408
+ create_generation_methods_and_video_length(
409
+ ["Video Generation", "Image Generation", "Long Video Generation"],
410
+ default_video_length=49,
411
+ maximum_video_length=85,
412
+ )
413
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
414
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox
415
+ )
416
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
417
+
418
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
419
+
420
+ result_image, result_video, infer_progress = create_ui_outputs()
421
+
422
+ model_type.change(
423
+ fn=controller.update_model_type,
424
+ inputs=[model_type],
425
+ outputs=[]
426
+ )
427
+
428
+ def upload_generation_method(generation_method):
429
+ if generation_method == "Video Generation":
430
+ return [gr.update(visible=True, maximum=85, value=49, interactive=True), gr.update(visible=False), gr.update(visible=False)]
431
+ elif generation_method == "Image Generation":
432
+ return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)]
433
+ else:
434
+ return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
435
+ generation_method.change(
436
+ upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
437
+ )
438
+
439
+ def upload_source_method(source_method):
440
+ if source_method == "Text to Video (文本到视频)":
441
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
442
+ elif source_method == "Image to Video (图片到视频)":
443
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
444
+ elif source_method == "Video to Video (视频到视频)":
445
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
446
+ else:
447
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
448
+ source_method.change(
449
+ upload_source_method, source_method, [
450
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
451
+ validation_video, validation_video_mask, control_video
452
+ ]
453
+ )
454
+
455
+ def upload_resize_method(resize_method):
456
+ if resize_method == "Generate by":
457
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
458
+ else:
459
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
460
+ resize_method.change(
461
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
462
+ )
463
+
464
+ generate_button.click(
465
+ fn=controller.generate,
466
+ inputs=[
467
+ diffusion_transformer_dropdown,
468
+ base_model_dropdown,
469
+ lora_model_dropdown,
470
+ lora_alpha_slider,
471
+ prompt_textbox,
472
+ negative_prompt_textbox,
473
+ sampler_dropdown,
474
+ sample_step_slider,
475
+ resize_method,
476
+ width_slider,
477
+ height_slider,
478
+ base_resolution,
479
+ generation_method,
480
+ length_slider,
481
+ overlap_video_length,
482
+ partial_video_length,
483
+ cfg_scale_slider,
484
+ start_image,
485
+ end_image,
486
+ validation_video,
487
+ validation_video_mask,
488
+ control_video,
489
+ denoise_strength,
490
+ seed_textbox,
491
+ ],
492
+ outputs=[result_image, result_video, infer_progress]
493
+ )
494
+ return demo, controller
495
+
496
+ def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, compile_dit, weight_dtype, savedir_sample=None):
497
+ controller = CogVideoXFunController_Host(
498
+ GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type,
499
+ compile_dit=compile_dit,
500
+ weight_dtype=weight_dtype, savedir_sample=savedir_sample,
501
+ )
502
+
503
+ with gr.Blocks(css=css) as demo:
504
+ gr.Markdown(
505
+ """
506
+ # CogVideoX-Fun
507
+
508
+ A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.
509
+
510
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
511
+ """
512
+ )
513
+ with gr.Column(variant="panel"):
514
+ model_type = create_fake_model_type(visible=False)
515
+ diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
516
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)
517
+
518
+ with gr.Column(variant="panel"):
519
+ prompt_textbox, negative_prompt_textbox = create_prompts()
520
+
521
+ with gr.Row():
522
+ with gr.Column():
523
+ sampler_dropdown, sample_step_slider = create_samplers(controller)
524
+
525
+ resize_method, width_slider, height_slider, base_resolution = create_height_width(
526
+ default_height = 384, default_width = 672, maximum_height = 1344,
527
+ maximum_width = 1344,
528
+ )
529
+ gr.Markdown(
530
+ """
531
+ V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
532
+ (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
533
+ """
534
+ )
535
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
536
+ create_generation_methods_and_video_length(
537
+ ["Video Generation", "Image Generation"],
538
+ default_video_length=49,
539
+ maximum_video_length=85,
540
+ )
541
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
542
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox
543
+ )
544
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
545
+
546
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
547
+
548
+ result_image, result_video, infer_progress = create_ui_outputs()
549
+
550
+ def upload_generation_method(generation_method):
551
+ if generation_method == "Video Generation":
552
+ return gr.update(visible=True, minimum=8, maximum=85, value=49, interactive=True)
553
+ elif generation_method == "Image Generation":
554
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
555
+ generation_method.change(
556
+ upload_generation_method, generation_method, [length_slider]
557
+ )
558
+
559
+ def upload_source_method(source_method):
560
+ if source_method == "Text to Video (文本到视频)":
561
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
562
+ elif source_method == "Image to Video (图片到视频)":
563
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
564
+ elif source_method == "Video to Video (视频到视频)":
565
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
566
+ else:
567
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
568
+ source_method.change(
569
+ upload_source_method, source_method, [
570
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
571
+ validation_video, validation_video_mask, control_video
572
+ ]
573
+ )
574
+
575
+ def upload_resize_method(resize_method):
576
+ if resize_method == "Generate by":
577
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
578
+ else:
579
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
580
+ resize_method.change(
581
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
582
+ )
583
+
584
+ generate_button.click(
585
+ fn=controller.generate,
586
+ inputs=[
587
+ diffusion_transformer_dropdown,
588
+ base_model_dropdown,
589
+ lora_model_dropdown,
590
+ lora_alpha_slider,
591
+ prompt_textbox,
592
+ negative_prompt_textbox,
593
+ sampler_dropdown,
594
+ sample_step_slider,
595
+ resize_method,
596
+ width_slider,
597
+ height_slider,
598
+ base_resolution,
599
+ generation_method,
600
+ length_slider,
601
+ overlap_video_length,
602
+ partial_video_length,
603
+ cfg_scale_slider,
604
+ start_image,
605
+ end_image,
606
+ validation_video,
607
+ validation_video_mask,
608
+ control_video,
609
+ denoise_strength,
610
+ seed_textbox,
611
+ ],
612
+ outputs=[result_image, result_video, infer_progress]
613
+ )
614
+ return demo, controller
615
+
616
+ def ui_client(scheduler_dict, model_name, savedir_sample=None):
617
+ controller = CogVideoXFunController_Client(scheduler_dict, savedir_sample)
618
+
619
+ with gr.Blocks(css=css) as demo:
620
+ gr.Markdown(
621
+ """
622
+ # CogVideoX-Fun
623
+
624
+ A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos.
625
+
626
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
627
+ """
628
+ )
629
+ with gr.Column(variant="panel"):
630
+ diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
631
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)
632
+
633
+ with gr.Column(variant="panel"):
634
+ prompt_textbox, negative_prompt_textbox = create_prompts()
635
+
636
+ with gr.Row():
637
+ with gr.Column():
638
+ sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50)
639
+
640
+ resize_method, width_slider, height_slider, base_resolution = create_fake_height_width(
641
+ default_height = 384, default_width = 672, maximum_height = 1344,
642
+ maximum_width = 1344,
643
+ )
644
+ gr.Markdown(
645
+ """
646
+ V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
647
+ (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
648
+ """
649
+ )
650
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
651
+ create_generation_methods_and_video_length(
652
+ ["Video Generation", "Image Generation"],
653
+ default_video_length=49,
654
+ maximum_video_length=85,
655
+ )
656
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
657
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"], prompt_textbox
658
+ )
659
+
660
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
661
+
662
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
663
+
664
+ result_image, result_video, infer_progress = create_ui_outputs()
665
+
666
+ def upload_generation_method(generation_method):
667
+ if generation_method == "Video Generation":
668
+ return gr.update(visible=True, minimum=5, maximum=85, value=49, interactive=True)
669
+ elif generation_method == "Image Generation":
670
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
671
+ generation_method.change(
672
+ upload_generation_method, generation_method, [length_slider]
673
+ )
674
+
675
+ def upload_source_method(source_method):
676
+ if source_method == "Text to Video (文本到视频)":
677
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
678
+ elif source_method == "Image to Video (图片到视频)":
679
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
680
+ else:
681
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
682
+ source_method.change(
683
+ upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
684
+ )
685
+
686
+ def upload_resize_method(resize_method):
687
+ if resize_method == "Generate by":
688
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
689
+ else:
690
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
691
+ resize_method.change(
692
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
693
+ )
694
+
695
+ generate_button.click(
696
+ fn=controller.generate,
697
+ inputs=[
698
+ diffusion_transformer_dropdown,
699
+ base_model_dropdown,
700
+ lora_model_dropdown,
701
+ lora_alpha_slider,
702
+ prompt_textbox,
703
+ negative_prompt_textbox,
704
+ sampler_dropdown,
705
+ sample_step_slider,
706
+ resize_method,
707
+ width_slider,
708
+ height_slider,
709
+ base_resolution,
710
+ generation_method,
711
+ length_slider,
712
+ cfg_scale_slider,
713
+ start_image,
714
+ end_image,
715
+ validation_video,
716
+ validation_video_mask,
717
+ denoise_strength,
718
+ seed_textbox,
719
+ ],
720
+ outputs=[result_image, result_video, infer_progress]
721
+ )
722
+ return demo, controller
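
The `ui`, `ui_host`, and `ui_client` builders above each return a `(demo, controller)` pair. A minimal launch sketch follows; the scheduler dictionary, GPU memory mode, dtype, and port are illustrative assumptions, and an actual entry point (such as the repository's app.py) may wire these up differently.

```python
# Hedged sketch: launching the CogVideoX-Fun Gradio demo built by ui().
# GPU_memory_mode, dtype, and port values are illustrative assumptions.
import torch

from videox_fun.ui.cogvideox_fun_ui import ui
from videox_fun.ui.controller import all_cheduler_dict  # name as defined in controller.py

demo, controller = ui(
    GPU_memory_mode="model_cpu_offload",  # one of the modes handled in update_diffusion_transformer
    scheduler_dict=all_cheduler_dict,
    compile_dit=False,
    weight_dtype=torch.bfloat16,
)
demo.launch(server_name="0.0.0.0", server_port=7860)
```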
videox_fun/ui/controller.py ADDED
@@ -0,0 +1,514 @@
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import base64
4
+ import gc
5
+ import json
6
+ import os
7
+ import hashlib
8
+ import random
9
+ from datetime import datetime
10
+ from glob import glob
11
+
12
+ import cv2
13
+ import gradio as gr
14
+ import numpy as np
15
+ import pkg_resources
16
+ import requests
17
+ import torch
18
+ from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler,
19
+ DPMSolverMultistepScheduler,
20
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
21
+ FlowMatchEulerDiscreteScheduler, PNDMScheduler)
22
+ from omegaconf import OmegaConf
23
+ from PIL import Image
24
+ from safetensors import safe_open
25
+
26
+ from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
27
+ from ..utils.utils import save_videos_grid
28
+ from ..utils.fm_solvers import FlowDPMSolverMultistepScheduler
29
+ from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
30
+ from ..dist import set_multi_gpus_devices
31
+
32
+ gradio_version = pkg_resources.get_distribution("gradio").version
33
+ gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False
34
+
35
+ css = """
36
+ .toolbutton {
37
+ margin-buttom: 0em 0em 0em 0em;
38
+ max-width: 2.5em;
39
+ min-width: 2.5em !important;
40
+ height: 2.5em;
41
+ }
42
+ """
43
+
44
+ ddpm_scheduler_dict = {
45
+ "Euler": EulerDiscreteScheduler,
46
+ "Euler A": EulerAncestralDiscreteScheduler,
47
+ "DPM++": DPMSolverMultistepScheduler,
48
+ "PNDM": PNDMScheduler,
49
+ "DDIM": DDIMScheduler,
50
+ "DDIM_Origin": DDIMScheduler,
51
+ "DDIM_Cog": CogVideoXDDIMScheduler,
52
+ }
53
+ flow_scheduler_dict = {
54
+ "Flow": FlowMatchEulerDiscreteScheduler,
55
+ "Flow_Unipc": FlowUniPCMultistepScheduler,
56
+ "Flow_DPM++": FlowDPMSolverMultistepScheduler,
57
+ }
58
+ all_cheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
59
+
60
+ class Fun_Controller:
61
+ def __init__(
62
+ self, GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
63
+ config_path=None, ulysses_degree=1, ring_degree=1,
64
+ fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
65
+ weight_dtype=None, savedir_sample=None,
66
+ ):
67
+ # config dirs
68
+ self.basedir = os.getcwd()
69
+ self.config_dir = os.path.join(self.basedir, "config")
70
+ self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
71
+ self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
72
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
73
+ if savedir_sample is None:
74
+ self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
75
+ else:
76
+ self.savedir_sample = savedir_sample
77
+ os.makedirs(self.savedir_sample, exist_ok=True)
78
+
79
+ self.GPU_memory_mode = GPU_memory_mode
80
+ self.model_name = model_name
81
+ self.diffusion_transformer_dropdown = model_name
82
+ self.scheduler_dict = scheduler_dict
83
+ self.model_type = model_type
84
+ if config_path is not None:
85
+ self.config_path = os.path.realpath(config_path)
86
+ self.config = OmegaConf.load(config_path)
87
+ else:
88
+ self.config_path = None
89
+ self.ulysses_degree = ulysses_degree
90
+ self.ring_degree = ring_degree
91
+ self.fsdp_dit = fsdp_dit
92
+ self.fsdp_text_encoder = fsdp_text_encoder
93
+ self.compile_dit = compile_dit
94
+ self.weight_dtype = weight_dtype
95
+ self.device = set_multi_gpus_devices(self.ulysses_degree, self.ring_degree)
96
+
97
+ self.diffusion_transformer_list = []
98
+ self.motion_module_list = []
99
+ self.personalized_model_list = []
100
+ self.config_list = []
101
+
102
+ # config models
103
+ self.tokenizer = None
104
+ self.text_encoder = None
105
+ self.vae = None
106
+ self.transformer = None
107
+ self.transformer_2 = None
108
+ self.pipeline = None
109
+ self.base_model_path = "none"
110
+ self.base_model_2_path = "none"
111
+ self.lora_model_path = "none"
112
+ self.lora_model_2_path = "none"
113
+
114
+ self.refresh_config()
115
+ self.refresh_diffusion_transformer()
116
+ self.refresh_personalized_model()
117
+ if model_name != None:
118
+ self.update_diffusion_transformer(model_name)
119
+
120
+ def refresh_config(self):
121
+ config_list = []
122
+ for root, dirs, files in os.walk(self.config_dir):
123
+ for file in files:
124
+ if file.endswith(('.yaml', '.yml')):
125
+ full_path = os.path.join(root, file)
126
+ config_list.append(full_path)
127
+ self.config_list = config_list
128
+
129
+ def refresh_diffusion_transformer(self):
130
+ self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
131
+
132
+ def refresh_personalized_model(self):
133
+ personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
134
+ self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
135
+
136
+ def update_model_type(self, model_type):
137
+ self.model_type = model_type
138
+
139
+ def update_config(self, config_dropdown):
140
+ self.config_path = config_dropdown
141
+ self.config = OmegaConf.load(config_dropdown)
142
+ print(f"Update config: {config_dropdown}")
143
+
144
+ def update_diffusion_transformer(self, diffusion_transformer_dropdown):
145
+ pass
146
+
147
+ def update_base_model(self, base_model_dropdown, is_checkpoint_2=False):
148
+ if not is_checkpoint_2:
149
+ self.base_model_path = base_model_dropdown
150
+ else:
151
+ self.base_model_2_path = base_model_dropdown
152
+ print(f"Update base model: {base_model_dropdown}")
153
+ if base_model_dropdown == "none":
154
+ return gr.update()
155
+ if self.transformer is None and not is_checkpoint_2:
156
+ gr.Info(f"Please select a pretrained model path.")
157
+ print(f"Please select a pretrained model path.")
158
+ return gr.update(value=None)
159
+ elif self.transformer_2 is None and is_checkpoint_2:
160
+ gr.Info(f"Please select a pretrained model path.")
161
+ print(f"Please select a pretrained model path.")
162
+ return gr.update(value=None)
163
+ else:
164
+ base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
165
+ base_model_state_dict = {}
166
+ with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
167
+ for key in f.keys():
168
+ base_model_state_dict[key] = f.get_tensor(key)
169
+ if not is_checkpoint_2:
170
+ self.transformer.load_state_dict(base_model_state_dict, strict=False)
171
+ else:
172
+ self.transformer_2.load_state_dict(base_model_state_dict, strict=False)
173
+ print("Update base model done")
174
+ return gr.update()
175
+
176
+ def update_lora_model(self, lora_model_dropdown, is_checkpoint_2=False):
177
+ print(f"Update lora model: {lora_model_dropdown}")
178
+ if lora_model_dropdown == "none":
179
+ self.lora_model_path = "none"
180
+ return gr.update()
181
+ lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
182
+ if not is_checkpoint_2:
183
+ self.lora_model_path = lora_model_dropdown
184
+ else:
185
+ self.lora_model_2_path = lora_model_dropdown
186
+ return gr.update()
187
+
188
+ def clear_cache(self,):
189
+ gc.collect()
190
+ torch.cuda.empty_cache()
191
+ torch.cuda.ipc_collect()
192
+
193
+ def auto_model_clear_cache(self, model):
194
+ origin_device = model.device
195
+ model = model.to("cpu")
196
+ gc.collect()
197
+ torch.cuda.empty_cache()
198
+ torch.cuda.ipc_collect()
199
+ model = model.to(origin_device)
200
+
201
+ def input_check(self,
202
+ resize_method,
203
+ generation_method,
204
+ start_image,
205
+ end_image,
206
+ validation_video,
207
+ control_video,
208
+ is_api = False,
209
+ ):
210
+ if self.transformer is None:
211
+ if is_api:
212
+ return "", f"Please select a pretrained model path."
213
+ else:
214
+ raise gr.Error(f"Please select a pretrained model path.")
215
+
216
+ if control_video is not None and self.model_type == "Inpaint":
217
+ if is_api:
218
+ return "", f"If specifying the control video, please set the model_type == \"Control\". "
219
+ else:
220
+ raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ")
221
+
222
+ if control_video is None and self.model_type == "Control":
223
+ if is_api:
224
+ return "", f"If set the model_type == \"Control\", please specifying the control video. "
225
+ else:
226
+ raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ")
227
+
228
+ if resize_method == "Resize according to Reference":
229
+ if start_image is None and validation_video is None and control_video is None:
230
+ if is_api:
231
+ return "", f"Please upload an image when using \"Resize according to Reference\"."
232
+ else:
233
+ raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
234
+
235
+ if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
236
+ if is_api:
237
+ return "", f"Please select an image to video pretrained model while using image to video."
238
+ else:
239
+ raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
240
+
241
+ if self.transformer.config.in_channels == self.vae.config.latent_channels and generation_method == "Long Video Generation":
242
+ if is_api:
243
+ return "", f"Please select an image to video pretrained model while using long video generation."
244
+ else:
245
+ raise gr.Error(f"Please select an image to video pretrained model while using long video generation.")
246
+
247
+ if start_image is None and end_image is not None:
248
+ if is_api:
249
+ return "", f"If specifying the ending image of the video, please specify a starting image of the video."
250
+ else:
251
+ raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
252
+ return "", "OK"
253
+
254
+ def get_height_width_from_reference(
255
+ self,
256
+ base_resolution,
257
+ start_image,
258
+ validation_video,
259
+ control_video,
260
+ ):
261
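+ # Pick the aspect-ratio bucket (scaled from the 512 base to base_resolution) closest to the reference media, then round height and width to a multiple of twice the VAE spatial compression ratio.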
+ spatial_compression_ratio = self.vae.config.spatial_compression_ratio if hasattr(self.vae.config, "spatial_compression_ratio") else 8
262
+ aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
263
+ if self.model_type == "Inpaint":
264
+ if validation_video is not None:
265
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
266
+ else:
267
+ original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
268
+ else:
269
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
270
+ closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
271
+ height_slider, width_slider = [int(x / spatial_compression_ratio / 2) * spatial_compression_ratio * 2 for x in closest_size]
272
+ return height_slider, width_slider
273
+
274
+ def save_outputs(self, is_image, length_slider, sample, fps):
275
+ def save_results():
276
+ if not os.path.exists(self.savedir_sample):
277
+ os.makedirs(self.savedir_sample, exist_ok=True)
278
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
279
+ prefix = str(index).zfill(8)
280
+
281
+ md5_hash = hashlib.md5(sample.cpu().numpy().tobytes()).hexdigest()
282
+
283
+ if is_image or length_slider == 1:
284
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png")
285
+ print(f"Saving to {save_sample_path}")
286
+ image = sample[0, :, 0]
287
+ image = image.transpose(0, 1).transpose(1, 2)
288
+ image = (image * 255).numpy().astype(np.uint8)
289
+ image = Image.fromarray(image)
290
+ image.save(save_sample_path)
291
+
292
+ else:
293
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4")
294
+ print(f"Saving to {save_sample_path}")
295
+ save_videos_grid(sample, save_sample_path, fps=fps)
296
+ return save_sample_path
297
+
298
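+ # Under sequence-parallel inference (ulysses/ring), only rank 0 writes the output file; other ranks return None.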
+ if self.ulysses_degree * self.ring_degree > 1:
299
+ import torch.distributed as dist
300
+ if dist.get_rank() == 0:
301
+ save_sample_path = save_results()
302
+ else:
303
+ save_sample_path = None
304
+ else:
305
+ save_sample_path = save_results()
306
+ return save_sample_path
307
+
308
+ def generate(
309
+ self,
310
+ diffusion_transformer_dropdown,
311
+ base_model_dropdown,
312
+ lora_model_dropdown,
313
+ lora_alpha_slider,
314
+ prompt_textbox,
315
+ negative_prompt_textbox,
316
+ sampler_dropdown,
317
+ sample_step_slider,
318
+ resize_method,
319
+ width_slider,
320
+ height_slider,
321
+ base_resolution,
322
+ generation_method,
323
+ length_slider,
324
+ overlap_video_length,
325
+ partial_video_length,
326
+ cfg_scale_slider,
327
+ start_image,
328
+ end_image,
329
+ validation_video,
330
+ validation_video_mask,
331
+ control_video,
332
+ denoise_strength,
333
+ seed_textbox,
334
+ enable_teacache = None,
335
+ teacache_threshold = None,
336
+ num_skip_start_steps = None,
337
+ teacache_offload = None,
338
+ cfg_skip_ratio = None,
339
+ enable_riflex = None,
340
+ riflex_k = None,
341
+ is_api = False,
342
+ ):
343
+ pass
344
+
345
+ def post_to_host(
346
+ diffusion_transformer_dropdown,
347
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
348
+ prompt_textbox, negative_prompt_textbox,
349
+ sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
350
+ base_resolution, generation_method, length_slider, cfg_scale_slider,
351
+ start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox,
352
+ ref_image = None, enable_teacache = None, teacache_threshold = None, num_skip_start_steps = None,
353
+ teacache_offload = None, cfg_skip_ratio = None, enable_riflex = None, riflex_k = None,
354
+ ):
355
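+ # Read each provided media file and base64-encode it so it can travel inside the JSON payload.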
+ if start_image is not None:
356
+ with open(start_image, 'rb') as file:
357
+ file_content = file.read()
358
+ start_image_encoded_content = base64.b64encode(file_content)
359
+ start_image = start_image_encoded_content.decode('utf-8')
360
+
361
+ if end_image is not None:
362
+ with open(end_image, 'rb') as file:
363
+ file_content = file.read()
364
+ end_image_encoded_content = base64.b64encode(file_content)
365
+ end_image = end_image_encoded_content.decode('utf-8')
366
+
367
+ if validation_video is not None:
368
+ with open(validation_video, 'rb') as file:
369
+ file_content = file.read()
370
+ validation_video_encoded_content = base64.b64encode(file_content)
371
+ validation_video = validation_video_encoded_content.decode('utf-8')
372
+
373
+ if validation_video_mask is not None:
374
+ with open(validation_video_mask, 'rb') as file:
375
+ file_content = file.read()
376
+ validation_video_mask_encoded_content = base64.b64encode(file_content)
377
+ validation_video_mask = validation_video_mask_encoded_content.decode('utf-8')
378
+
379
+ if ref_image is not None:
380
+ with open(ref_image, 'rb') as file:
381
+ file_content = file.read()
382
+ ref_image_encoded_content = base64.b64encode(file_content)
383
+ ref_image = ref_image_encoded_content.decode('utf-8')
384
+
385
+ datas = {
386
+ "base_model_path": base_model_dropdown,
387
+ "lora_model_path": lora_model_dropdown,
388
+ "lora_alpha_slider": lora_alpha_slider,
389
+ "prompt_textbox": prompt_textbox,
390
+ "negative_prompt_textbox": negative_prompt_textbox,
391
+ "sampler_dropdown": sampler_dropdown,
392
+ "sample_step_slider": sample_step_slider,
393
+ "resize_method": resize_method,
394
+ "width_slider": width_slider,
395
+ "height_slider": height_slider,
396
+ "base_resolution": base_resolution,
397
+ "generation_method": generation_method,
398
+ "length_slider": length_slider,
399
+ "cfg_scale_slider": cfg_scale_slider,
400
+ "start_image": start_image,
401
+ "end_image": end_image,
402
+ "validation_video": validation_video,
403
+ "validation_video_mask": validation_video_mask,
404
+ "denoise_strength": denoise_strength,
405
+ "seed_textbox": seed_textbox,
406
+
407
+ "ref_image": ref_image,
408
+ "enable_teacache": enable_teacache,
409
+ "teacache_threshold": teacache_threshold,
410
+ "num_skip_start_steps": num_skip_start_steps,
411
+ "teacache_offload": teacache_offload,
412
+ "cfg_skip_ratio": cfg_skip_ratio,
413
+ "enable_riflex": enable_riflex,
414
+ "riflex_k": riflex_k,
415
+ }
416
+
417
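+ # Authenticate with the EAS token from the environment and POST the request to the remote inference endpoint.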
+ session = requests.session()
418
+ session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
419
+
420
+ response = session.post(url=f'{os.environ.get("EAS_URL")}/videox_fun/infer_forward', json=datas, timeout=300)
421
+
422
+ outputs = response.json()
423
+ return outputs
424
+
425
+
426
+ class Fun_Controller_Client:
427
+ def __init__(self, scheduler_dict, savedir_sample):
428
+ self.basedir = os.getcwd()
429
+ if savedir_sample is None:
430
+ self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
431
+ else:
432
+ self.savedir_sample = savedir_sample
433
+ os.makedirs(self.savedir_sample, exist_ok=True)
434
+
435
+ self.scheduler_dict = scheduler_dict
436
+
437
+ def generate(
438
+ self,
439
+ diffusion_transformer_dropdown,
440
+ base_model_dropdown,
441
+ lora_model_dropdown,
442
+ lora_alpha_slider,
443
+ prompt_textbox,
444
+ negative_prompt_textbox,
445
+ sampler_dropdown,
446
+ sample_step_slider,
447
+ resize_method,
448
+ width_slider,
449
+ height_slider,
450
+ base_resolution,
451
+ generation_method,
452
+ length_slider,
453
+ cfg_scale_slider,
454
+ start_image,
455
+ end_image,
456
+ validation_video,
457
+ validation_video_mask,
458
+ denoise_strength,
459
+ seed_textbox,
460
+ ref_image = None,
461
+ enable_teacache = None,
462
+ teacache_threshold = None,
463
+ num_skip_start_steps = None,
464
+ teacache_offload = None,
465
+ cfg_skip_ratio = None,
466
+ enable_riflex = None,
467
+ riflex_k = None,
468
+ ):
469
+ is_image = True if generation_method == "Image Generation" else False
470
+
471
+ outputs = post_to_host(
472
+ diffusion_transformer_dropdown,
473
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
474
+ prompt_textbox, negative_prompt_textbox,
475
+ sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
476
+ base_resolution, generation_method, length_slider, cfg_scale_slider,
477
+ start_image, end_image, validation_video, validation_video_mask, denoise_strength,
478
+ seed_textbox, ref_image = ref_image, enable_teacache = enable_teacache, teacache_threshold = teacache_threshold,
479
+ num_skip_start_steps = num_skip_start_steps, teacache_offload = teacache_offload,
480
+ cfg_skip_ratio = cfg_skip_ratio, enable_riflex = enable_riflex, riflex_k = riflex_k,
481
+ )
482
+
483
+ try:
484
+ base64_encoding = outputs["base64_encoding"]
485
+ except:
486
+ return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"]
487
+
488
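+ # The host returns the generated media as a base64 string; decode it and persist it as a PNG or MP4 under savedir_sample.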
+ decoded_data = base64.b64decode(base64_encoding)
489
+
490
+ if not os.path.exists(self.savedir_sample):
491
+ os.makedirs(self.savedir_sample, exist_ok=True)
492
+ md5_hash = hashlib.md5(decoded_data).hexdigest()
493
+
494
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
495
+ prefix = str(index).zfill(8)
496
+
497
+ if is_image or length_slider == 1:
498
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png")
499
+ print(f"Saving to {save_sample_path}")
500
+ with open(save_sample_path, "wb") as file:
501
+ file.write(decoded_data)
502
+ if gradio_version_is_above_4:
503
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
504
+ else:
505
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
506
+ else:
507
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4")
508
+ print(f"Saving to {save_sample_path}")
509
+ with open(save_sample_path, "wb") as file:
510
+ file.write(decoded_data)
511
+ if gradio_version_is_above_4:
512
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
513
+ else:
514
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
videox_fun/ui/ui.py ADDED
@@ -0,0 +1,366 @@
1
+ import random
2
+
3
+ import gradio as gr
4
+
5
+
6
+ def create_model_type(visible):
7
+ gr.Markdown(
8
+ """
9
+ ### Model Type.
10
+ """,
11
+ visible=visible,
12
+ )
13
+ with gr.Row():
14
+ model_type = gr.Dropdown(
15
+ label="The model type of the model",
16
+ choices=["Inpaint", "Control"],
17
+ value="Inpaint",
18
+ visible=visible,
19
+ interactive=True,
20
+ )
21
+ return model_type
22
+
23
+ def create_fake_model_type(visible):
24
+ gr.Markdown(
25
+ """
26
+ ### Model Type.
27
+ """,
28
+ visible=visible,
29
+ )
30
+ with gr.Row():
31
+ model_type = gr.Dropdown(
32
+ label="The model type of the model",
33
+ choices=["Inpaint", "Control"],
34
+ value="Inpaint",
35
+ interactive=False,
36
+ visible=visible,
37
+ )
38
+ return model_type
39
+
40
+ def create_model_checkpoints(controller, visible, default_model="none"):
41
+ gr.Markdown(
42
+ """
43
+ ### Model checkpoints.
44
+ """
45
+ )
46
+ with gr.Row(visible=visible):
47
+ diffusion_transformer_dropdown = gr.Dropdown(
48
+ label="Pretrained Model Path",
49
+ choices=list(set(controller.diffusion_transformer_list + [default_model])),
50
+ value=default_model,
51
+ interactive=True,
52
+ )
53
+ diffusion_transformer_dropdown.change(
54
+ fn=controller.update_diffusion_transformer,
55
+ inputs=[diffusion_transformer_dropdown],
56
+ outputs=[diffusion_transformer_dropdown]
57
+ )
58
+
59
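+ # The refresh button re-scans the model directory and updates the dropdown choices in place.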
+ diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
60
+ def refresh_diffusion_transformer():
61
+ controller.refresh_diffusion_transformer()
62
+ return gr.update(choices=controller.diffusion_transformer_list)
63
+ diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
64
+
65
+ return diffusion_transformer_dropdown, diffusion_transformer_refresh_button
66
+
67
+ def create_fake_model_checkpoints(model_name, visible):
68
+ gr.Markdown(
69
+ """
70
+ ### Model checkpoints.
71
+ """
72
+ )
73
+ with gr.Row(visible=visible):
74
+ diffusion_transformer_dropdown = gr.Dropdown(
75
+ label="Pretrained Model Path",
76
+ choices=[model_name],
77
+ value=model_name,
78
+ interactive=False,
79
+ )
80
+ return diffusion_transformer_dropdown
81
+
82
+ def create_finetune_models_checkpoints(controller, visible, add_checkpoint_2=False, default_lora="none"):
83
+ with gr.Row(visible=visible):
84
+ base_model_dropdown = gr.Dropdown(
85
+ label="Select base Dreambooth model",
86
+ choices=["none"] + controller.personalized_model_list,
87
+ value="none",
88
+ interactive=True,
89
+ )
90
+ if add_checkpoint_2:
91
+ base_model_2_dropdown = gr.Dropdown(
92
+ label="Select base Dreambooth model",
93
+ choices=["none"] + controller.personalized_model_list,
94
+ value="none",
95
+ interactive=True,
96
+ )
97
+
98
+ lora_model_dropdown = gr.Dropdown(
99
+ label="Select LoRA model",
100
+ choices=list(set(["none"] + controller.personalized_model_list + [default_lora])),
101
+ value=default_lora,
102
+ interactive=True,
103
+ )
104
+ if add_checkpoint_2:
105
+ lora_model_2_dropdown = gr.Dropdown(
106
+ label="Select LoRA model",
107
+ choices=["none"] + controller.personalized_model_list,
108
+ value="none",
109
+ interactive=True,
110
+ )
111
+
112
+ lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.55, minimum=0, maximum=2, interactive=True)
113
+
114
+ personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
115
+ def update_personalized_model():
116
+ controller.refresh_personalized_model()
117
+ return [
118
+ gr.update(choices=controller.personalized_model_list),
119
+ gr.update(choices=["none"] + controller.personalized_model_list)
120
+ ]
121
+ personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
122
+
123
+ if not add_checkpoint_2:
124
+ return base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button
125
+ else:
126
+ return [base_model_dropdown, base_model_2_dropdown], [lora_model_dropdown, lora_model_2_dropdown], \
127
+ lora_alpha_slider, personalized_refresh_button
128
+
129
+ def create_fake_finetune_models_checkpoints(visible):
130
+ with gr.Row():
131
+ base_model_dropdown = gr.Dropdown(
132
+ label="Select base Dreambooth model",
133
+ choices=["none"],
134
+ value="none",
135
+ interactive=False,
136
+ visible=False
137
+ )
138
+ with gr.Column(visible=False):
139
+ gr.Markdown(
140
+ """
141
+ ### Minimalism is an example portrait LoRA, triggered by specific prompt words. More details can be found on the [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora).
142
+ """
143
+ )
144
+ with gr.Row():
145
+ lora_model_dropdown = gr.Dropdown(
146
+ label="Select LoRA model",
147
+ choices=["none"],
148
+ value="none",
149
+ interactive=True,
150
+ )
151
+
152
+ lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.55, minimum=0, maximum=2, interactive=True)
153
+
154
+ return base_model_dropdown, lora_model_dropdown, lora_alpha_slider
155
+
156
+ def create_teacache_params(
157
+ enable_teacache = True,
158
+ teacache_threshold = 0.10,
159
+ num_skip_start_steps = 1,
160
+ teacache_offload = False,
161
+ ):
162
+ enable_teacache = gr.Checkbox(label="Enable TeaCache", value=enable_teacache)
163
+ teacache_threshold = gr.Slider(0.00, 0.25, value=teacache_threshold, step=0.01, label="TeaCache Threshold")
164
+ num_skip_start_steps = gr.Slider(0, 10, value=num_skip_start_steps, step=5, label="Number of Skip Start Steps")
165
+ teacache_offload = gr.Checkbox(label="Offload TeaCache to CPU", value=teacache_offload)
166
+ return enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload
167
+
168
+ def create_cfg_skip_params(
169
+ cfg_skip_ratio = 0
170
+ ):
171
+ cfg_skip_ratio = gr.Slider(0.00, 0.50, value=cfg_skip_ratio, step=0.01, label="CFG Skip Ratio", visible=False)
172
+ return cfg_skip_ratio
173
+
174
+ def create_cfg_riflex_k(
175
+ enable_riflex = False,
176
+ riflex_k = 6
177
+ ):
178
+ enable_riflex = gr.Checkbox(label="Enable Riflex", value=enable_riflex, visible=False)
179
+ riflex_k = gr.Slider(0, 10, value=riflex_k, step=1, label="Riflex Intrinsic Frequency Index", visible=False)
180
+ return enable_riflex, riflex_k
181
+
182
+ def create_prompts(
183
+ prompt="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
184
+ negative_prompt="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
185
+ ):
186
+ gr.Markdown(
187
+ """
188
+ ### Configs for Generation.
189
+ """
190
+ )
191
+
192
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2, value=prompt)
193
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value=negative_prompt)
194
+ return prompt_textbox, negative_prompt_textbox
195
+
196
+ def create_samplers(controller, maximum_step=100):
197
+ with gr.Row():
198
+ sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(controller.scheduler_dict.keys()), value=list(controller.scheduler_dict.keys())[0])
199
+ sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=maximum_step, step=1)
200
+
201
+ return sampler_dropdown, sample_step_slider
202
+
203
+ def create_height_width(default_height, default_width, maximum_height, maximum_width):
204
+ resize_method = gr.Radio(
205
+ ["Generate by", "Resize according to Reference"],
206
+ value="Generate by",
207
+ show_label=False,
208
+ )
209
+ width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16)
210
+ height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16)
211
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False)
212
+
213
+ return resize_method, width_slider, height_slider, base_resolution
214
+
215
+ def create_fake_height_width(default_height, default_width, maximum_height, maximum_width):
216
+ resize_method = gr.Radio(
217
+ ["Generate by", "Resize according to Reference"],
218
+ value="Generate by",
219
+ show_label=False,
220
+ )
221
+ width_slider = gr.Slider(label="Width", value=default_width, minimum=128, maximum=maximum_width, step=16, interactive=False)
222
+ height_slider = gr.Slider(label="Height", value=default_height, minimum=128, maximum=maximum_height, step=16, interactive=False)
223
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 640, 768, 896, 960, 1024], interactive=False, visible=False)
224
+
225
+ return resize_method, width_slider, height_slider, base_resolution
226
+
227
+ def create_generation_methods_and_video_length(
228
+ generation_method_options,
229
+ default_video_length,
230
+ maximum_video_length
231
+ ):
232
+ with gr.Group():
233
+ generation_method = gr.Radio(
234
+ generation_method_options,
235
+ value="Video Generation",
236
+ show_label=False,
237
+ visible=False
238
+ )
239
+ with gr.Row():
240
+ length_slider = gr.Slider(label="Animation length", value=default_video_length, minimum=1, maximum=maximum_video_length, step=4, visible=False)
241
+ overlap_video_length = gr.Slider(label="Overlap length", value=4, minimum=1, maximum=4, step=1, visible=False)
242
+ partial_video_length = gr.Slider(label="Partial video generation length", value=25, minimum=5, maximum=maximum_video_length, step=4, visible=False)
243
+
244
+ return generation_method, length_slider, overlap_video_length, partial_video_length
245
+
246
+ def create_generation_method(source_method_options, prompt_textbox, support_end_image=True, support_ref_image=False, default_video=None, video_examples=None):
247
+ default_method = source_method_options[0] if source_method_options else "Text to Video"
248
+ source_method = gr.Radio(
249
+ source_method_options,
250
+ value=default_method,
251
+ show_label=False,
252
+ )
253
+ with gr.Column(visible = (default_method == "Image to Video")) as image_to_video_col:
254
+ start_image = gr.Image(
255
+ label="The image at the beginning of the video", show_label=True,
256
+ elem_id="i2v_start", sources="upload", type="filepath",
257
+ )
258
+
259
+ template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
260
+ def select_template(evt: gr.SelectData):
261
+ text = {
262
+ "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.",
263
+ "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.",
264
+ "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.",
265
+ "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.",
266
+ "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.",
267
+ }[template_gallery_path[evt.index]]
268
+ return template_gallery_path[evt.index], text
269
+
270
+ template_gallery = gr.Gallery(
271
+ template_gallery_path,
272
+ columns=5, rows=1,
273
+ height=140,
274
+ allow_preview=False,
275
+ container=False,
276
+ label="Template Examples",
277
+ )
278
+ template_gallery.select(select_template, None, [start_image, prompt_textbox])
279
+
280
+ with gr.Accordion("The image at the ending of the video", open=False, visible=support_end_image):
281
+ end_image = gr.Image(label="The image at the ending of the video", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
282
+
283
+ with gr.Column(visible = (default_method == "Video to Video")) as video_to_video_col:
284
+ with gr.Row():
285
+ validation_video = gr.Video(
286
+ label="The video to convert", show_label=True,
287
+ elem_id="v2v", sources=["upload"], value=default_video,
288
+ )
289
+ if video_examples:
290
+ gr.Examples(
291
+ examples=video_examples,
292
+ inputs=[validation_video, prompt_textbox] if len(video_examples[0]) > 1 else validation_video,
293
+ label="Video Examples"
294
+ )
295
+
296
+ # Removed the mask accordion per request (the user said "the mask is not needed"); a hidden placeholder component is kept instead.
297
+ # validation_video_mask = gr.Image(
298
+ # label="The mask of the video to inpaint",
299
+ # show_label=False, elem_id="v2v_mask", sources="upload", type="filepath",
300
+ # visible=False
301
+ # )
302
+ validation_video_mask = gr.Image(visible=False, value=None)
303
+
304
+ # Denoise strength default 1.0, hidden
305
+ denoise_strength = gr.Slider(label="Denoise strength", value=1.00, minimum=0.10, maximum=1.00, step=0.01, visible=False)
306
+
307
+ with gr.Column(visible = False) as control_video_col:
308
+ gr.Markdown(
309
+ """
310
+ Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
311
+ """
312
+ )
313
+ control_video = gr.Video(
314
+ label="The control video", show_label=True,
315
+ elem_id="v2v_control", sources="upload",
316
+ )
317
+ ref_image = gr.Image(
318
+ label="The reference image for control video", show_label=True,
319
+ elem_id="ref_image", sources="upload", type="filepath", visible=support_ref_image
320
+ )
321
+ return image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image
322
+
323
+ def create_cfg_and_seedbox(gradio_version_is_above_4):
324
+ # cfg default 6, hidden
325
+ cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20, visible=False)
326
+
327
+ with gr.Row():
328
+ seed_textbox = gr.Textbox(label="Seed", value=43)
329
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
330
+ seed_button.click(
331
+ fn=lambda: gr.Textbox(value=random.randint(1, 10**8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 10**8)),
332
+ inputs=[],
333
+ outputs=[seed_textbox]
334
+ )
335
+ return cfg_scale_slider, seed_textbox, seed_button
336
+
337
+ def create_ui_outputs():
338
+ with gr.Column():
339
+ result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
340
+ result_video = gr.Video(label="Generated Animation", interactive=False)
341
+ infer_progress = gr.Textbox(
342
+ label="Generation Info",
343
+ value="No task currently",
344
+ interactive=False
345
+ )
346
+ return result_image, result_video, infer_progress
347
+
348
+ def create_config(controller):
349
+ gr.Markdown(
350
+ """
351
+ ### Config Path (配置文件路径)
352
+ """
353
+ )
354
+ with gr.Row():
355
+ config_dropdown = gr.Dropdown(
356
+ label="Config Path",
357
+ choices=controller.config_list,
358
+ value=controller.config_path,
359
+ interactive=True,
360
+ )
361
+ config_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
362
+ def refresh_config():
363
+ controller.refresh_config()
364
+ return gr.update(choices=controller.config_list)
365
+ config_refresh_button.click(fn=refresh_config, inputs=[], outputs=[config_dropdown])
366
+ return config_dropdown, config_refresh_button
videox_fun/ui/wan2_2_fun_ui.py ADDED
@@ -0,0 +1,803 @@
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import os
4
+ import random
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+ from PIL import Image
12
+ from safetensors import safe_open
13
+
14
+ from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
15
+ from ..dist import set_multi_gpus_devices, shard_model
16
+ from ..models import (AutoencoderKLWan, AutoencoderKLWan3_8, AutoTokenizer,
17
+ CLIPModel, Wan2_2Transformer3DModel, WanT5EncoderModel)
18
+ from ..models.cache_utils import get_teacache_coefficients
19
+ from ..pipeline import Wan2_2FunControlPipeline, Wan2_2FunPipeline, Wan2_2FunInpaintPipeline
20
+ from ..utils.fp8_optimization import (convert_model_weight_to_float8,
21
+ convert_weight_dtype_wrapper,
22
+ replace_parameters_by_name)
23
+ from ..utils.lora_utils import merge_lora, unmerge_lora
24
+ from ..utils.utils import (filter_kwargs, get_image_latent,
25
+ get_image_to_video_latent,
26
+ get_video_to_video_latent, save_videos_grid, timer)
27
+ from .controller import (Fun_Controller, Fun_Controller_Client,
28
+ all_cheduler_dict, css, ddpm_scheduler_dict,
29
+ flow_scheduler_dict, gradio_version,
30
+ gradio_version_is_above_4)
31
+ from .ui import (create_cfg_and_seedbox, create_cfg_riflex_k,
32
+ create_cfg_skip_params, create_config,
33
+ create_fake_finetune_models_checkpoints,
34
+ create_fake_height_width, create_fake_model_checkpoints,
35
+ create_fake_model_type, create_finetune_models_checkpoints,
36
+ create_generation_method,
37
+ create_generation_methods_and_video_length,
38
+ create_height_width, create_model_checkpoints,
39
+ create_model_type, create_prompts, create_samplers,
40
+ create_teacache_params, create_ui_outputs)
41
+
42
+
43
+ class Wan2_2_Fun_Controller(Fun_Controller):
44
+ def update_diffusion_transformer(self, diffusion_transformer_dropdown):
45
+ print(f"Update diffusion transformer: {diffusion_transformer_dropdown}")
46
+ self.model_name = diffusion_transformer_dropdown
47
+ self.diffusion_transformer_dropdown = diffusion_transformer_dropdown
48
+ if diffusion_transformer_dropdown == "none":
49
+ return gr.update()
50
+ Chosen_AutoencoderKL = {
51
+ "AutoencoderKLWan": AutoencoderKLWan,
52
+ "AutoencoderKLWan3_8": AutoencoderKLWan3_8
53
+ }[self.config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')]
54
+ self.vae = Chosen_AutoencoderKL.from_pretrained(
55
+ os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')),
56
+ additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']),
57
+ ).to(self.weight_dtype)
58
+
59
+ # Get Transformer
60
+ self.transformer = Wan2_2Transformer3DModel.from_pretrained(
61
+ os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer')),
62
+ transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']),
63
+ low_cpu_mem_usage=True,
64
+ torch_dtype=self.weight_dtype,
65
+ )
66
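+ # Wan2.2 "moe" checkpoints ship a second high-noise transformer alongside the low-noise one; for "single" checkpoints transformer_2 stays None.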
+ if self.config['transformer_additional_kwargs'].get('transformer_combination_type', 'single') == "moe":
67
+ self.transformer_2 = Wan2_2Transformer3DModel.from_pretrained(
68
+ os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer')),
69
+ transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']),
70
+ low_cpu_mem_usage=True,
71
+ torch_dtype=self.weight_dtype,
72
+ )
73
+ else:
74
+ self.transformer_2 = None
75
+
76
+ # Get Tokenizer
77
+ self.tokenizer = AutoTokenizer.from_pretrained(
78
+ os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
79
+ )
80
+
81
+ # Get Text encoder
82
+ self.text_encoder = WanT5EncoderModel.from_pretrained(
83
+ os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
84
+ additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']),
85
+ low_cpu_mem_usage=True,
86
+ torch_dtype=self.weight_dtype,
87
+ )
88
+ self.text_encoder = self.text_encoder.eval()
89
+
90
+ Chosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]]
91
+ self.scheduler = Chosen_Scheduler(
92
+ **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs']))
93
+ )
94
+
95
+ # Get pipeline
96
+ if self.model_type == "Inpaint":
97
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
98
+ self.pipeline = Wan2_2FunInpaintPipeline(
99
+ vae=self.vae,
100
+ tokenizer=self.tokenizer,
101
+ text_encoder=self.text_encoder,
102
+ transformer=self.transformer,
103
+ transformer_2=self.transformer_2,
104
+ scheduler=self.scheduler,
105
+ )
106
+ else:
107
+ self.pipeline = Wan2_2FunPipeline(
108
+ vae=self.vae,
109
+ tokenizer=self.tokenizer,
110
+ text_encoder=self.text_encoder,
111
+ transformer=self.transformer,
112
+ transformer_2=self.transformer_2,
113
+ scheduler=self.scheduler,
114
+ )
115
+ else:
116
+ self.pipeline = Wan2_2FunControlPipeline(
117
+ vae=self.vae,
118
+ tokenizer=self.tokenizer,
119
+ text_encoder=self.text_encoder,
120
+ transformer=self.transformer,
121
+ transformer_2=self.transformer_2,
122
+ scheduler=self.scheduler,
123
+ )
124
+
125
+ if self.ulysses_degree > 1 or self.ring_degree > 1:
126
+ from functools import partial
127
+ self.transformer.enable_multi_gpus_inference()
128
+ if self.transformer_2 is not None:
129
+ self.transformer_2.enable_multi_gpus_inference()
130
+ if self.fsdp_dit:
131
+ shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype)
132
+ self.pipeline.transformer = shard_fn(self.pipeline.transformer)
133
+ if self.transformer_2 is not None:
134
+ self.pipeline.transformer_2 = shard_fn(self.pipeline.transformer_2)
135
+ print("Add FSDP DIT")
136
+ if self.fsdp_text_encoder:
137
+ shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype)
138
+ self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder)
139
+ print("Add FSDP TEXT ENCODER")
140
+
141
+ if self.compile_dit:
142
+ for i in range(len(self.pipeline.transformer.blocks)):
143
+ self.pipeline.transformer.blocks[i] = torch.compile(self.pipeline.transformer.blocks[i])
144
+ if self.transformer_2 is not None:
145
+ for i in range(len(self.pipeline.transformer_2.blocks)):
146
+ self.pipeline.transformer_2.blocks[i] = torch.compile(self.pipeline.transformer_2.blocks[i])
147
+ print("Add Compile")
148
+
149
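+ # Pick the memory strategy: sequential/model CPU offload moves weights off the GPU between uses, while the qfloat8 variants additionally convert transformer weights to float8.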
+ if self.GPU_memory_mode == "sequential_cpu_offload":
150
+ replace_parameters_by_name(self.transformer, ["modulation",], device=self.device)
151
+ self.transformer.freqs = self.transformer.freqs.to(device=self.device)
152
+ if self.transformer_2 is not None:
153
+ replace_parameters_by_name(self.transformer_2, ["modulation",], device=self.device)
154
+ self.transformer_2.freqs = self.transformer_2.freqs.to(device=self.device)
155
+ self.pipeline.enable_sequential_cpu_offload(device=self.device)
156
+ elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
157
+ convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device)
158
+ convert_weight_dtype_wrapper(self.transformer, self.weight_dtype)
159
+ if self.transformer_2 is not None:
160
+ convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device)
161
+ convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype)
162
+ self.pipeline.enable_model_cpu_offload(device=self.device)
163
+ elif self.GPU_memory_mode == "model_cpu_offload":
164
+ self.pipeline.enable_model_cpu_offload(device=self.device)
165
+ elif self.GPU_memory_mode == "model_full_load_and_qfloat8":
166
+ convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device)
167
+ convert_weight_dtype_wrapper(self.transformer, self.weight_dtype)
168
+ if self.transformer_2 is not None:
169
+ convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device)
170
+ convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype)
171
+ self.pipeline.to(self.device)
172
+ else:
173
+ self.pipeline.to(self.device)
174
+ print("Update diffusion transformer done")
175
+ return gr.update()
176
+
177
+ @timer
178
+ def generate(
179
+ self,
180
+ diffusion_transformer_dropdown,
181
+ base_model_dropdown,
182
+ lora_model_dropdown,
183
+ lora_alpha_slider,
184
+ prompt_textbox,
185
+ negative_prompt_textbox,
186
+ sampler_dropdown,
187
+ sample_step_slider,
188
+ resize_method,
189
+ width_slider,
190
+ height_slider,
191
+ base_resolution,
192
+ generation_method,
193
+ length_slider,
194
+ overlap_video_length,
195
+ partial_video_length,
196
+ cfg_scale_slider,
197
+ start_image,
198
+ end_image,
199
+ validation_video,
200
+ validation_video_mask,
201
+ control_video,
202
+ denoise_strength,
203
+ seed_textbox,
204
+ ref_image = None,
205
+ enable_teacache = None,
206
+ teacache_threshold = None,
207
+ num_skip_start_steps = None,
208
+ teacache_offload = None,
209
+ cfg_skip_ratio = None,
210
+ enable_riflex = None,
211
+ riflex_k = None,
212
+ base_model_2_dropdown=None,
213
+ lora_model_2_dropdown=None,
214
+ fps = None,
215
+ is_api = False,
216
+ ):
217
+ self.clear_cache()
218
+
219
+ print(f"Input checking.")
220
+ _, comment = self.input_check(
221
+ resize_method, generation_method, start_image, end_image, validation_video, control_video, is_api
222
+ )
223
+ print(f"Input checking down")
224
+ if comment != "OK":
225
+ return "", comment
226
+ is_image = True if generation_method == "Image Generation" else False
227
+
228
+ if self.base_model_path != base_model_dropdown:
229
+ self.update_base_model(base_model_dropdown)
230
+ if self.base_model_2_path != base_model_2_dropdown:
231
+ self.update_base_model(base_model_2_dropdown, is_checkpoint_2=True)
232
+
233
+ if self.lora_model_path != lora_model_dropdown:
234
+ self.update_lora_model(lora_model_dropdown)
235
+ if self.lora_model_2_path != lora_model_2_dropdown:
236
+ self.update_lora_model(lora_model_2_dropdown, is_checkpoint_2=True)
237
+
238
+ print(f"Load scheduler.")
239
+ scheduler_config = self.pipeline.scheduler.config
240
+ if sampler_dropdown == "Flow_Unipc" or sampler_dropdown == "Flow_DPM++":
241
+ scheduler_config['shift'] = 1
242
+ self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)
243
+ print(f"Load scheduler down.")
244
+
245
+ if resize_method == "Resize according to Reference":
246
+ print(f"Calculate height and width according to Reference.")
247
+ height_slider, width_slider = self.get_height_width_from_reference(
248
+ base_resolution, start_image, validation_video, control_video,
249
+ )
250
+
251
+ if self.lora_model_path != "none":
252
+ print(f"Merge Lora.")
253
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
254
+ if self.transformer_2 is not None:
255
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_2_path, multiplier=lora_alpha_slider, sub_transformer_name="transformer_2")
256
+ print(f"Merge Lora done.")
257
+
258
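+ # TeaCache skips redundant transformer computation across similar denoising steps; the coefficients are model-specific, and transformer_2 shares the cache state of the first transformer.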
+ coefficients = get_teacache_coefficients(self.diffusion_transformer_dropdown) if enable_teacache else None
259
+ if coefficients is not None:
260
+ print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.")
261
+ self.pipeline.transformer.enable_teacache(
262
+ coefficients, sample_step_slider, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload
263
+ )
264
+ if self.transformer_2 is not None:
265
+ self.pipeline.transformer_2.share_teacache(self.pipeline.transformer)
266
+ else:
267
+ print(f"Disable TeaCache.")
268
+ self.pipeline.transformer.disable_teacache()
269
+ if self.transformer_2 is not None:
270
+ self.pipeline.transformer_2.disable_teacache()
271
+
272
+ if cfg_skip_ratio is not None and cfg_skip_ratio >= 0:
273
+ print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.")
274
+ self.pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, sample_step_slider)
275
+ if self.transformer_2 is not None:
276
+ self.pipeline.transformer_2.share_cfg_skip(self.pipeline.transformer)
277
+
278
+ print(f"Generate seed.")
279
+ if seed_textbox != "" and int(seed_textbox) != -1: torch.manual_seed(int(seed_textbox))
280
+ else: seed_textbox = np.random.randint(0, 1e10)
281
+ generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox))
282
+ print(f"Generate seed done.")
283
+
284
+ if fps is None:
285
+ fps = 16
286
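+ # boundary is the denoising-timestep ratio at which the pipeline hands off from the high-noise transformer_2 to the low-noise transformer.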
+ boundary = self.config['transformer_additional_kwargs'].get('boundary', 0.875)
287
+
288
+ if enable_riflex:
289
+ print(f"Enable riflex")
290
+ latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1
291
+ self.pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1)
292
+ if self.transformer_2 is not None:
293
+ self.pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1)
294
+
295
+ try:
296
+ print(f"Generation.")
297
+ if self.model_type == "Inpaint":
298
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
299
+ if validation_video is not None:
300
+ input_video, input_video_mask, _, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps)
301
+ else:
302
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
303
+
304
+ sample = self.pipeline(
305
+ prompt_textbox,
306
+ negative_prompt = negative_prompt_textbox,
307
+ num_inference_steps = sample_step_slider,
308
+ guidance_scale = cfg_scale_slider,
309
+ width = width_slider,
310
+ height = height_slider,
311
+ num_frames = length_slider if not is_image else 1,
312
+ generator = generator,
313
+
314
+ video = input_video,
315
+ mask_video = input_video_mask,
316
+ boundary = boundary
317
+ ).videos
318
+ else:
319
+ sample = self.pipeline(
320
+ prompt_textbox,
321
+ negative_prompt = negative_prompt_textbox,
322
+ num_inference_steps = sample_step_slider,
323
+ guidance_scale = cfg_scale_slider,
324
+ width = width_slider,
325
+ height = height_slider,
326
+ num_frames = length_slider if not is_image else 1,
327
+ generator = generator,
328
+ boundary = boundary
329
+ ).videos
330
+ else:
331
+ inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
332
+
333
+ if ref_image is not None:
334
+ ref_image = get_image_latent(ref_image, sample_size=(height_slider, width_slider))
335
+
336
+ input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps, ref_image=None)
337
+
338
+ sample = self.pipeline(
339
+ prompt_textbox,
340
+ negative_prompt = negative_prompt_textbox,
341
+ num_inference_steps = sample_step_slider,
342
+ guidance_scale = cfg_scale_slider,
343
+ width = width_slider,
344
+ height = height_slider,
345
+ num_frames = length_slider if not is_image else 1,
346
+ generator = generator,
347
+
348
+ video = inpaint_video,
349
+ mask_video = inpaint_video_mask,
350
+ control_video = input_video,
351
+ ref_image = ref_image,
352
+ boundary = boundary,
353
+ ).videos
354
+ print(f"Generation done.")
355
+ except Exception as e:
356
+ self.auto_model_clear_cache(self.pipeline.transformer)
357
+ self.auto_model_clear_cache(self.pipeline.text_encoder)
358
+ self.auto_model_clear_cache(self.pipeline.vae)
359
+ self.clear_cache()
360
+
361
+ print(f"Error. error information is {str(e)}")
362
+ if self.lora_model_path != "none":
363
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
364
+ if is_api:
365
+ return "", f"Error. error information is {str(e)}"
366
+ else:
367
+ return gr.update(), gr.update(), f"Error. error information is {str(e)}"
368
+
369
+ self.clear_cache()
370
+ # lora part
371
+ if self.lora_model_path != "none":
372
+ print(f"Unmerge Lora.")
373
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
374
+ print(f"Unmerge Lora done.")
375
+
376
+ print(f"Saving outputs.")
377
+ save_sample_path = self.save_outputs(
378
+ is_image, length_slider, sample, fps=fps
379
+ )
380
+ print(f"Saving outputs done.")
381
+
382
+ if is_image or length_slider == 1:
383
+ if is_api:
384
+ return save_sample_path, "Success"
385
+ else:
386
+ if gradio_version_is_above_4:
387
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
388
+ else:
389
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
390
+ else:
391
+ if is_api:
392
+ return save_sample_path, "Success"
393
+ else:
394
+ if gradio_version_is_above_4:
395
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
396
+ else:
397
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
398
+
399
+ Wan2_2_Fun_Controller_Host = Wan2_2_Fun_Controller
400
+ Wan2_2_Fun_Controller_Client = Fun_Controller_Client
401
+
402
+ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype, savedir_sample=None):
403
+ controller = Wan2_2_Fun_Controller(
404
+ GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
405
+ config_path=config_path, compile_dit=compile_dit,
406
+ weight_dtype=weight_dtype, savedir_sample=savedir_sample,
407
+ )
408
+
409
+ with gr.Blocks(css=css) as demo:
410
+ gr.Markdown(
411
+ """
412
+ # Wan2.2-Fun:
413
+
414
+ A Wan model with more flexible generation conditions, capable of producing videos at different resolutions, around 5 seconds long at 16 fps (1 to 81 frames), as well as image-to-video generation.
415
+
416
+ [Github](https://github.com/aigc-apps/VideoX-Fun/)
417
+ """
418
+ )
419
+ with gr.Column(variant="panel"):
420
+ config_dropdown, config_refresh_button = create_config(controller)
421
+ model_type = create_model_type(visible=True)
422
+ diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \
423
+ create_model_checkpoints(controller, visible=True)
424
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \
425
+ create_finetune_models_checkpoints(controller, visible=True, add_checkpoint_2=True)
426
+ base_model_dropdown, base_model_2_dropdown = base_model_dropdown
427
+ lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown
428
+
429
+ with gr.Row():
430
+ enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \
431
+ create_teacache_params(True, 0.10, 1, False)
432
+ cfg_skip_ratio = create_cfg_skip_params(0)
433
+ enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)
434
+
435
+ with gr.Column(variant="panel"):
436
+ prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
437
+
438
+ with gr.Row():
439
+ with gr.Column():
440
+ sampler_dropdown, sample_step_slider = create_samplers(controller)
441
+
442
+ resize_method, width_slider, height_slider, base_resolution = create_height_width(
443
+ default_height = 480, default_width = 832, maximum_height = 1344,
444
+ maximum_width = 1344,
445
+ )
446
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
447
+ create_generation_methods_and_video_length(
448
+ ["Video Generation", "Image Generation"],
449
+ default_video_length=81,
450
+ maximum_video_length=161,
451
+ )
452
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
453
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True
454
+ )
455
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
456
+
457
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
458
+
459
+ result_image, result_video, infer_progress = create_ui_outputs()
460
+
461
+ config_dropdown.change(
462
+ fn=controller.update_config,
463
+ inputs=[config_dropdown],
464
+ outputs=[]
465
+ )
466
+
467
+ model_type.change(
468
+ fn=controller.update_model_type,
469
+ inputs=[model_type],
470
+ outputs=[]
471
+ )
472
+
473
+ def upload_generation_method(generation_method):
474
+ if generation_method == "Video Generation":
475
+ return [gr.update(visible=True, maximum=161, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)]
476
+ elif generation_method == "Image Generation":
477
+ return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)]
478
+ else:
479
+ return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
480
+ generation_method.change(
481
+ upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
482
+ )
483
+
484
+ def upload_source_method(source_method):
485
+ if source_method == "Text to Video (文本到视频)":
486
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
487
+ elif source_method == "Image to Video (图片到视频)":
488
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
489
+ elif source_method == "Video to Video (视频到视频)":
490
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
491
+ else:
492
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
493
+ source_method.change(
494
+ upload_source_method, source_method, [
495
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
496
+ validation_video, validation_video_mask, control_video
497
+ ]
498
+ )
499
+
500
+ def upload_resize_method(resize_method):
501
+ if resize_method == "Generate by":
502
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
503
+ else:
504
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
505
+ resize_method.change(
506
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
507
+ )
508
+
509
+ generate_button.click(
510
+ fn=controller.generate,
511
+ inputs=[
512
+ diffusion_transformer_dropdown,
513
+ base_model_dropdown,
514
+ lora_model_dropdown,
515
+ lora_alpha_slider,
516
+ prompt_textbox,
517
+ negative_prompt_textbox,
518
+ sampler_dropdown,
519
+ sample_step_slider,
520
+ resize_method,
521
+ width_slider,
522
+ height_slider,
523
+ base_resolution,
524
+ generation_method,
525
+ length_slider,
526
+ overlap_video_length,
527
+ partial_video_length,
528
+ cfg_scale_slider,
529
+ start_image,
530
+ end_image,
531
+ validation_video,
532
+ validation_video_mask,
533
+ control_video,
534
+ denoise_strength,
535
+ seed_textbox,
536
+ ref_image,
537
+ enable_teacache,
538
+ teacache_threshold,
539
+ num_skip_start_steps,
540
+ teacache_offload,
541
+ cfg_skip_ratio,
542
+ enable_riflex,
543
+ riflex_k,
544
+ base_model_2_dropdown,
545
+ lora_model_2_dropdown
546
+ ],
547
+ outputs=[result_image, result_video, infer_progress]
548
+ )
549
+ return demo, controller
550
+
551
+ def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, compile_dit, weight_dtype, savedir_sample=None):
552
+ controller = Wan2_2_Fun_Controller_Host(
553
+ GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type,
554
+ config_path=config_path, compile_dit=compile_dit,
555
+ weight_dtype=weight_dtype, savedir_sample=savedir_sample,
556
+ )
557
+
558
+ with gr.Blocks(css=css) as demo:
559
+ gr.Markdown(
560
+ """
561
+ # Wan2.2-Fun:
562
+
563
+ A Wan model with more flexible generation conditions, capable of producing videos at different resolutions, around 5 seconds long at 16 fps (1 to 81 frames), as well as image-to-video generation.
564
+
565
+ [Github](https://github.com/aigc-apps/VideoX-Fun/)
566
+ """
567
+ )
568
+ with gr.Column(variant="panel"):
569
+ model_type = create_fake_model_type(visible=False)
570
+ diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
571
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \
572
+ create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True)
573
+ base_model_dropdown, base_model_2_dropdown = base_model_dropdown
574
+ lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown
575
+
576
+ with gr.Row():
577
+ enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \
578
+ create_teacache_params(True, 0.10, 1, False)
579
+ cfg_skip_ratio = create_cfg_skip_params(0)
580
+ enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)
581
+
582
+ with gr.Column(variant="panel"):
583
+ prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
584
+
585
+ with gr.Row():
586
+ with gr.Column():
587
+ sampler_dropdown, sample_step_slider = create_samplers(controller)
588
+
589
+ resize_method, width_slider, height_slider, base_resolution = create_height_width(
590
+ default_height = 480, default_width = 832, maximum_height = 1344,
591
+ maximum_width = 1344,
592
+ )
593
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
594
+ create_generation_methods_and_video_length(
595
+ ["Video Generation", "Image Generation"],
596
+ default_video_length=81,
597
+ maximum_video_length=161,
598
+ )
599
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
600
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True
601
+ )
602
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
603
+
604
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
605
+
606
+ result_image, result_video, infer_progress = create_ui_outputs()
607
+
608
+ def upload_generation_method(generation_method):
609
+ if generation_method == "Video Generation":
610
+ return gr.update(visible=True, minimum=1, maximum=161, value=81, interactive=True)
611
+ elif generation_method == "Image Generation":
612
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
613
+ generation_method.change(
614
+ upload_generation_method, generation_method, [length_slider]
615
+ )
616
+
617
+ def upload_source_method(source_method):
618
+ if source_method == "Text to Video (文本到视频)":
619
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
620
+ elif source_method == "Image to Video (图片到视频)":
621
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
622
+ elif source_method == "Video to Video (视频到视频)":
623
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
624
+ else:
625
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
626
+ source_method.change(
627
+ upload_source_method, source_method, [
628
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
629
+ validation_video, validation_video_mask, control_video
630
+ ]
631
+ )
632
+
633
+ def upload_resize_method(resize_method):
634
+ if resize_method == "Generate by":
635
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
636
+ else:
637
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
638
+ resize_method.change(
639
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
640
+ )
641
+
642
+ generate_button.click(
643
+ fn=controller.generate,
644
+ inputs=[
645
+ diffusion_transformer_dropdown,
646
+ base_model_dropdown,
647
+ lora_model_dropdown,
648
+ lora_alpha_slider,
649
+ prompt_textbox,
650
+ negative_prompt_textbox,
651
+ sampler_dropdown,
652
+ sample_step_slider,
653
+ resize_method,
654
+ width_slider,
655
+ height_slider,
656
+ base_resolution,
657
+ generation_method,
658
+ length_slider,
659
+ overlap_video_length,
660
+ partial_video_length,
661
+ cfg_scale_slider,
662
+ start_image,
663
+ end_image,
664
+ validation_video,
665
+ validation_video_mask,
666
+ control_video,
667
+ denoise_strength,
668
+ seed_textbox,
669
+ ref_image,
670
+ enable_teacache,
671
+ teacache_threshold,
672
+ num_skip_start_steps,
673
+ teacache_offload,
674
+ cfg_skip_ratio,
675
+ enable_riflex,
676
+ riflex_k,
677
+ base_model_2_dropdown,
678
+ lora_model_2_dropdown
679
+ ],
680
+ outputs=[result_image, result_video, infer_progress]
681
+ )
682
+ return demo, controller
683
+
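Both `ui_host` and `ui_client` reuse the same visibility-toggling idiom: each `.change()` handler (`upload_source_method`, `upload_resize_method`, `upload_generation_method`) returns one `gr.update(...)` per output component, using `visible=` to show or hide a column and `value=None` to clear a stale upload. A minimal sketch of that idiom follows; the component names in it are hypothetical, not taken from this file.

```python
import gradio as gr

with gr.Blocks() as sketch_demo:
    source = gr.Radio(["Image", "Video"], value="Image", label="Source")
    with gr.Column(visible=True) as image_col:
        start_img = gr.Image(label="Start image")
    with gr.Column(visible=False) as video_col:
        val_video = gr.Video(label="Validation video")

    def toggle_source(choice):
        # One gr.update per output, in the same order as the outputs list:
        # visible=... shows/hides a column, value=None clears a stale upload.
        if choice == "Image":
            return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(value=None)]
        return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update()]

    source.change(toggle_source, source, [image_col, video_col, start_img, val_video])

# sketch_demo.launch()
```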
684
+ def ui_client(scheduler_dict, model_name, savedir_sample=None):
685
+ controller = Wan2_2_Fun_Controller_Client(scheduler_dict, savedir_sample)
686
+
687
+ with gr.Blocks(css=css) as demo:
688
+ gr.Markdown(
689
+ """
690
+ # Wan2.2-Fun:
691
+
692
+ A Wan model with more flexible generation conditions, capable of producing videos at different resolutions, around 5 seconds long at 16 fps (1 to 81 frames), as well as videos generated from images.
693
+
694
+ [Github](https://github.com/aigc-apps/VideoX-Fun/)
695
+ """
696
+ )
697
+ with gr.Column(variant="panel"):
698
+ diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
699
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \
700
+ create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True)
701
+ base_model_dropdown, base_model_2_dropdown = base_model_dropdown
702
+ lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown
703
+
704
+ with gr.Row():
705
+ enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \
706
+ create_teacache_params(True, 0.10, 1, False)
707
+ cfg_skip_ratio = create_cfg_skip_params(0)
708
+ enable_riflex, riflex_k = create_cfg_riflex_k(False, 6)
709
+
710
+ with gr.Column(variant="panel"):
711
+ prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
712
+
713
+ with gr.Row():
714
+ with gr.Column():
715
+ sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50)
716
+
717
+ resize_method, width_slider, height_slider, base_resolution = create_fake_height_width(
718
+ default_height = 480, default_width = 832, maximum_height = 1344,
719
+ maximum_width = 1344,
720
+ )
721
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
722
+ create_generation_methods_and_video_length(
723
+ ["Video Generation", "Image Generation"],
724
+ default_video_length=81,
725
+ maximum_video_length=161,
726
+ )
727
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method(
728
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
729
+ )
730
+
731
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
732
+
733
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
734
+
735
+ result_image, result_video, infer_progress = create_ui_outputs()
736
+
737
+ def upload_generation_method(generation_method):
738
+ if generation_method == "Video Generation":
739
+ return gr.update(visible=True, minimum=5, maximum=161, value=49, interactive=True)
740
+ elif generation_method == "Image Generation":
741
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
742
+ generation_method.change(
743
+ upload_generation_method, generation_method, [length_slider]
744
+ )
745
+
746
+ def upload_source_method(source_method):
747
+ if source_method == "Text to Video (文本到视频)":
748
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
749
+ elif source_method == "Image to Video (图片到视频)":
750
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
751
+ else:
752
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
753
+ source_method.change(
754
+ upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
755
+ )
756
+
757
+ def upload_resize_method(resize_method):
758
+ if resize_method == "Generate by":
759
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
760
+ else:
761
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
762
+ resize_method.change(
763
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
764
+ )
765
+
766
+ generate_button.click(
767
+ fn=controller.generate,
768
+ inputs=[
769
+ diffusion_transformer_dropdown,
770
+ base_model_dropdown,
771
+ lora_model_dropdown,
772
+ lora_alpha_slider,
773
+ prompt_textbox,
774
+ negative_prompt_textbox,
775
+ sampler_dropdown,
776
+ sample_step_slider,
777
+ resize_method,
778
+ width_slider,
779
+ height_slider,
780
+ base_resolution,
781
+ generation_method,
782
+ length_slider,
783
+ cfg_scale_slider,
784
+ start_image,
785
+ end_image,
786
+ validation_video,
787
+ validation_video_mask,
788
+ denoise_strength,
789
+ seed_textbox,
790
+ ref_image,
791
+ enable_teacache,
792
+ teacache_threshold,
793
+ num_skip_start_steps,
794
+ teacache_offload,
795
+ cfg_skip_ratio,
796
+ enable_riflex,
797
+ riflex_k,
798
+ base_model_2_dropdown,
799
+ lora_model_2_dropdown
800
+ ],
801
+ outputs=[result_image, result_video, infer_progress]
802
+ )
803
+ return demo, controller
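All three builders end with `return demo, controller`, so a launcher script is expected to unpack the Gradio Blocks object and serve it. The sketch below shows one plausible way to do that for the client UI; the scheduler mapping, checkpoint path, save directory, and port are assumptions for illustration and are not taken from this commit's app.py.

```python
# Hypothetical launcher for the client UI defined above. The argument values
# (scheduler mapping, model path, save directory, port) are assumptions.
from diffusers import FlowMatchEulerDiscreteScheduler
from videox_fun.ui.wan2_2_fun_ui import ui_client

scheduler_dict = {"Flow": FlowMatchEulerDiscreteScheduler}  # assumed sampler mapping

demo, controller = ui_client(
    scheduler_dict=scheduler_dict,
    model_name="models/Diffusion_Transformer/Wan2.2-Fun-A14B-InP",  # hypothetical path
    savedir_sample="samples",
)
demo.launch(server_name="0.0.0.0", server_port=7860, inbrowser=True)
```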