haodongli commited on
Commit
4b35c4e
·
1 Parent(s): 2cfedc8
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +34 -0
  2. README.md +15 -0
  3. app.py +77 -0
  4. ckpt/model_config.yaml +7 -0
  5. configs/img_config/data_diode_all.yaml +8 -0
  6. configs/img_config/data_eth3d.yaml +11 -0
  7. configs/img_config/data_kitti_eigen_test.yaml +8 -0
  8. configs/img_config/data_nyu_test.yaml +8 -0
  9. configs/img_config/data_scannet_val.yaml +6 -0
  10. configs/scannetv1_test.txt +312 -0
  11. configs/vid_config/img_sintel.yaml +3 -0
  12. configs/vid_config/vid_bonn.yaml +2 -0
  13. configs/vid_config/vid_kitti.yaml +2 -0
  14. configs/vid_config/vid_scannet.yaml +3 -0
  15. configs/vid_config/vid_sintel.yaml +3 -0
  16. diffsynth/__init__.py +4 -0
  17. diffsynth/configs/__init__.py +0 -0
  18. diffsynth/configs/model_config.py +705 -0
  19. diffsynth/data/__init__.py +1 -0
  20. diffsynth/data/video.py +244 -0
  21. diffsynth/distributed/__init__.py +0 -0
  22. diffsynth/distributed/xdit_context_parallel.py +129 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/downloader.py +116 -0
  25. diffsynth/models/model_manager.py +416 -0
  26. diffsynth/models/tiler.py +234 -0
  27. diffsynth/models/utils.py +185 -0
  28. diffsynth/models/wan_video_camera_controller.py +221 -0
  29. diffsynth/models/wan_video_dit.py +974 -0
  30. diffsynth/models/wan_video_image_encoder.py +902 -0
  31. diffsynth/models/wan_video_motion_controller.py +44 -0
  32. diffsynth/models/wan_video_text_encoder.py +269 -0
  33. diffsynth/models/wan_video_vace.py +113 -0
  34. diffsynth/models/wan_video_vae.py +828 -0
  35. diffsynth/pipelines/__init__.py +1 -0
  36. diffsynth/pipelines/wan_video_new_determine.py +1730 -0
  37. diffsynth/schedulers/__init__.py +3 -0
  38. diffsynth/schedulers/continuous_ode.py +59 -0
  39. diffsynth/schedulers/ddim.py +105 -0
  40. diffsynth/schedulers/flow_match.py +116 -0
  41. diffsynth/util/alignment.py +131 -0
  42. diffsynth/util/depth_transform.py +98 -0
  43. diffsynth/util/metric.py +337 -0
  44. diffsynth/util/normal_utils.py +78 -0
  45. diffsynth/util/seed_all.py +33 -0
  46. diffsynth/vram_management/__init__.py +2 -0
  47. diffsynth/vram_management/gradient_checkpointing.py +34 -0
  48. diffsynth/vram_management/layers.py +167 -0
  49. examples/__init__.py +0 -0
  50. examples/dataset/__init__.py +17 -0
.gitignore ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ output/
2
+ models/
3
+ ._____temp/
4
+ .vscode/
5
+ *__pycache__/*
6
+ *.pyc
7
+ video/
8
+ *.safetensors
9
+ *.pth
10
+ omni/
11
+ pcd/
12
+
13
+ models/
14
+ models_ms/
15
+ !diffsynth/models**
16
+ diffsynth/models/__pycache__/
17
+ ckpt/DVD/
18
+ outputs/
19
+ inference_results/
20
+ *.mp4
21
+ ckpt/DVD
22
+ .msc
23
+ .mv
24
+ ckpt/.cache
25
+ ckpt/.gitattributes
26
+ ckpt/README.md
27
+ overlap_plots/
28
+ test_script/test_from_trained_all_vid_test.py
29
+ test_script/test_single_video_batch.py
30
+ DVD.egg-info/
31
+ infer_bash/video_test.sh
32
+ ckpt/test/
33
+ !demo/robot_navi.mp4
34
+ !demo/drone.mp4
README.md CHANGED
@@ -5,10 +5,25 @@ colorFrom: green
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 6.9.0
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  short_description: Official demo of DVD (https://dvd-project.github.io/)
 
 
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 6.9.0
8
+ python_version: 3.10.20
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
12
  short_description: Official demo of DVD (https://dvd-project.github.io/)
13
+ tags:
14
+ - video diffusion
15
+ - video depth estimation
16
  ---
17
 
18
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
19
+
20
+ If you find our work useful in your research, please consider citing our paper🌹:
21
+
22
+ ```
23
+ @article{zhang2026dvd,
24
+ title={DVD: Deterministic Video Depth Estimation with Generative Priors},
25
+ author={Zhang, Hongfei and Chen, Harold Haodong and Liao, Chenfei and He, Jing and Zhang, Zixin and Li, Haodong and Liang, Yihao and Chen, Kanghao and Ren, Bin and Zheng, Xu and Yang, Shuai and Zhou, Kun and Li, Yinchuan and Sebe, Nicu and Chen, Ying-Cong},
26
+ journal={arXiv preprint arXiv:2603.12250},
27
+ year={2026}
28
+ }
29
+ ```
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# `spaces` must be imported before anything that touches CUDA,
# so the Hugging Face ZeroGPU hooks are installed first.
import spaces  # must be first!
import os
from pathlib import Path

# Keep Gradio's temp/cache directory inside the repo so uploads and
# rendered outputs land in a predictable, writable location.
REPO_ROOT = Path(__file__).resolve().parent
GRADIO_TMP = REPO_ROOT / ".gradio_cache"
GRADIO_TMP.mkdir(parents=True, exist_ok=True)

os.environ["GRADIO_TEMP_DIR"] = str(GRADIO_TMP)
print(f"Gradio temp/cache dir: {GRADIO_TMP}")

import torch
from argparse import Namespace
import subprocess
# Star-import brings in load_model / load_video_data / predict_depth /
# save_results / OmegaConf used below.
from test_script.test_single_video import *

import gradio as gr

device = "cuda" if torch.cuda.is_available() else "cpu"
yaml_args = OmegaConf.load(f"{REPO_ROOT}/ckpt/model_config.yaml")
# Lazily constructed inside the GPU-decorated callback on first request.
pipeline = None
22
+
23
+
24
@spaces.GPU
def fn(input_video):
    """Run DVD depth estimation on one uploaded video.

    Parameters
    ----------
    input_video : str
        Filesystem path to the uploaded video (supplied by ``gr.Video``).

    Returns
    -------
    str
        Path to the rendered depth video, handed back to the output
        ``gr.Video`` component.
    """
    # Only `pipeline` is assigned here; `yaml_args` and `device` are read-only
    # module globals, so they need no `global` declaration.
    global pipeline
    if pipeline is None:
        # First request: fetch the checkpoint if absent, then build the model.
        if not os.path.exists(f"{REPO_ROOT}/ckpt/model.safetensors"):
            subprocess.run(["bash", f"{REPO_ROOT}/infer_bash/download_ckpt.sh"], check=True)
        pipeline = load_model(f"{REPO_ROOT}/ckpt", yaml_args)

    # Decode and resize the clip to the model's working resolution.
    input_tensor, orig_size, origin_fps = load_video_data(Namespace(
        input_video=input_video,
        height=480,
        width=640,
    ))
    # Sliding-window inference over the frame sequence.
    depth = predict_depth(pipeline, input_tensor, orig_size, Namespace(
        window_size=81,
        overlap=21
    ))
    # Encode the colorized depth frames back to a video file.
    output_video = save_results(depth, origin_fps, Namespace(
        input_video=input_video,
        output_dir=REPO_ROOT,
        grayscale=False
    ))

    return output_video
49
+
50
+
51
if __name__ == "__main__":
    # Assemble the demo UI: one video in, one depth video out.
    demo = gr.Interface(
        fn=fn,
        title="DVD: Deterministic Video Depth Estimation with Generative Priors",
        description="""
        <strong>Please consider starring <span style="color: orange">&#9733;</span> our <a href="https://github.com/EnVision-Research/DVD" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this demo useful!</strong>
        """,
        inputs=[gr.Video(label="Input Video", autoplay=True)],
        outputs=[gr.Video(label="Output Video", autoplay=True)],
        examples=[
            [f"{REPO_ROOT}/demo/drone.mp4"],
            [f"{REPO_ROOT}/demo/robot_navi.mp4"]
        ]
    )

    # Serialize requests: the pipeline holds one set of GPU weights.
    demo.queue(default_concurrency_limit=1)
    demo.launch(
        # server_name="0.0.0.0",
        # server_port=1324,
    )
ckpt/model_config.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ model_id_with_origin_paths: Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth
2
+ trainable_models: dit
3
+ mode: regression
4
+ denoise_step: 0.5
5
+ training_target: x
6
+ lora_base_model: dit
7
+ lora_rank: 512
configs/img_config/data_diode_all.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ name: diode
2
+ disp_name: diode_val_all
3
+ # dir: diode
4
+ # dataset_dir: diode
5
+ # filename_ls_path: data_split/diode/diode_val_all_filename_list.txt
6
+ processing_res: 640
7
+ dir: diode
8
+ filename: data_split/diode/diode_val_all_filename_list.txt
configs/img_config/data_eth3d.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: eth3d
2
+ disp_name: eth3d_full
3
+ # dataset_dir: eth3d
4
+ # dir: eth3d/eth3d.tar
5
+ # filename_ls_path: data_split/eth3d/eth3d_filename_list.txt
6
+ dir: eth3d
7
+ # dir: eth3d/eth3d.tar
8
+ filename: data_split/eth3d/eth3d_filename_list.txt
9
+ resize_to_hw: [480, 720]
10
+ # processing_res: 768
11
+ # alignment_max_res: 1024
configs/img_config/data_kitti_eigen_test.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ name: kitti
2
+ disp_name: kitti_eigen_test_full
3
+ # dataset_dir: kitti
4
+ # filename_ls_path: data_split/kitti/eigen_test_files_with_gt.txt
5
+ kitti_bm_crop: true
6
+ valid_mask_crop: eigen
7
+ dir: kitti
8
+ filename: data_split/kitti/eigen_test_files_with_gt.txt
configs/img_config/data_nyu_test.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ name: nyu_v2
2
+ disp_name: nyu_test_full
3
+ # dataset_dir: nyuv2
4
+ # filename_ls_path: data_split/nyu/labeled/filename_list_test.txt
5
+ eigen_valid_mask: true
6
+
7
+ dir: nyuv2
8
+ filename: data_split/nyu/labeled/filename_list_test.txt
configs/img_config/data_scannet_val.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ name: scannet
2
+ disp_name: scannet_val_800_1
3
+ # dataset_dir: scannet
4
+ # filename_ls_path: data_split/scannet/scannet_val_sampled_list_800_1.txt
5
+ dir: scannet
6
+ filename: data_split/scannet/scannet_val_sampled_list_800_1.txt
configs/scannetv1_test.txt ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scene0568_00
2
+ scene0568_01
3
+ scene0568_02
4
+ scene0304_00
5
+ scene0488_00
6
+ scene0488_01
7
+ scene0412_00
8
+ scene0412_01
9
+ scene0217_00
10
+ scene0019_00
11
+ scene0019_01
12
+ scene0414_00
13
+ scene0575_00
14
+ scene0575_01
15
+ scene0575_02
16
+ scene0426_00
17
+ scene0426_01
18
+ scene0426_02
19
+ scene0426_03
20
+ scene0549_00
21
+ scene0549_01
22
+ scene0578_00
23
+ scene0578_01
24
+ scene0578_02
25
+ scene0665_00
26
+ scene0665_01
27
+ scene0050_00
28
+ scene0050_01
29
+ scene0050_02
30
+ scene0257_00
31
+ scene0025_00
32
+ scene0025_01
33
+ scene0025_02
34
+ scene0583_00
35
+ scene0583_01
36
+ scene0583_02
37
+ scene0701_00
38
+ scene0701_01
39
+ scene0701_02
40
+ scene0580_00
41
+ scene0580_01
42
+ scene0565_00
43
+ scene0169_00
44
+ scene0169_01
45
+ scene0655_00
46
+ scene0655_01
47
+ scene0655_02
48
+ scene0063_00
49
+ scene0221_00
50
+ scene0221_01
51
+ scene0591_00
52
+ scene0591_01
53
+ scene0591_02
54
+ scene0678_00
55
+ scene0678_01
56
+ scene0678_02
57
+ scene0462_00
58
+ scene0427_00
59
+ scene0595_00
60
+ scene0193_00
61
+ scene0193_01
62
+ scene0164_00
63
+ scene0164_01
64
+ scene0164_02
65
+ scene0164_03
66
+ scene0598_00
67
+ scene0598_01
68
+ scene0598_02
69
+ scene0599_00
70
+ scene0599_01
71
+ scene0599_02
72
+ scene0328_00
73
+ scene0300_00
74
+ scene0300_01
75
+ scene0354_00
76
+ scene0458_00
77
+ scene0458_01
78
+ scene0423_00
79
+ scene0423_01
80
+ scene0423_02
81
+ scene0307_00
82
+ scene0307_01
83
+ scene0307_02
84
+ scene0606_00
85
+ scene0606_01
86
+ scene0606_02
87
+ scene0432_00
88
+ scene0432_01
89
+ scene0608_00
90
+ scene0608_01
91
+ scene0608_02
92
+ scene0651_00
93
+ scene0651_01
94
+ scene0651_02
95
+ scene0430_00
96
+ scene0430_01
97
+ scene0689_00
98
+ scene0357_00
99
+ scene0357_01
100
+ scene0574_00
101
+ scene0574_01
102
+ scene0574_02
103
+ scene0329_00
104
+ scene0329_01
105
+ scene0329_02
106
+ scene0153_00
107
+ scene0153_01
108
+ scene0616_00
109
+ scene0616_01
110
+ scene0671_00
111
+ scene0671_01
112
+ scene0618_00
113
+ scene0382_00
114
+ scene0382_01
115
+ scene0490_00
116
+ scene0621_00
117
+ scene0607_00
118
+ scene0607_01
119
+ scene0149_00
120
+ scene0695_00
121
+ scene0695_01
122
+ scene0695_02
123
+ scene0695_03
124
+ scene0389_00
125
+ scene0377_00
126
+ scene0377_01
127
+ scene0377_02
128
+ scene0342_00
129
+ scene0139_00
130
+ scene0629_00
131
+ scene0629_01
132
+ scene0629_02
133
+ scene0496_00
134
+ scene0633_00
135
+ scene0633_01
136
+ scene0518_00
137
+ scene0652_00
138
+ scene0406_00
139
+ scene0406_01
140
+ scene0406_02
141
+ scene0144_00
142
+ scene0144_01
143
+ scene0494_00
144
+ scene0278_00
145
+ scene0278_01
146
+ scene0316_00
147
+ scene0609_00
148
+ scene0609_01
149
+ scene0609_02
150
+ scene0609_03
151
+ scene0084_00
152
+ scene0084_01
153
+ scene0084_02
154
+ scene0696_00
155
+ scene0696_01
156
+ scene0696_02
157
+ scene0351_00
158
+ scene0351_01
159
+ scene0643_00
160
+ scene0644_00
161
+ scene0645_00
162
+ scene0645_01
163
+ scene0645_02
164
+ scene0081_00
165
+ scene0081_01
166
+ scene0081_02
167
+ scene0647_00
168
+ scene0647_01
169
+ scene0535_00
170
+ scene0353_00
171
+ scene0353_01
172
+ scene0353_02
173
+ scene0559_00
174
+ scene0559_01
175
+ scene0559_02
176
+ scene0593_00
177
+ scene0593_01
178
+ scene0246_00
179
+ scene0653_00
180
+ scene0653_01
181
+ scene0064_00
182
+ scene0064_01
183
+ scene0356_00
184
+ scene0356_01
185
+ scene0356_02
186
+ scene0030_00
187
+ scene0030_01
188
+ scene0030_02
189
+ scene0222_00
190
+ scene0222_01
191
+ scene0338_00
192
+ scene0338_01
193
+ scene0338_02
194
+ scene0378_00
195
+ scene0378_01
196
+ scene0378_02
197
+ scene0660_00
198
+ scene0553_00
199
+ scene0553_01
200
+ scene0553_02
201
+ scene0527_00
202
+ scene0663_00
203
+ scene0663_01
204
+ scene0663_02
205
+ scene0664_00
206
+ scene0664_01
207
+ scene0664_02
208
+ scene0334_00
209
+ scene0334_01
210
+ scene0334_02
211
+ scene0046_00
212
+ scene0046_01
213
+ scene0046_02
214
+ scene0203_00
215
+ scene0203_01
216
+ scene0203_02
217
+ scene0088_00
218
+ scene0088_01
219
+ scene0088_02
220
+ scene0088_03
221
+ scene0086_00
222
+ scene0086_01
223
+ scene0086_02
224
+ scene0670_00
225
+ scene0670_01
226
+ scene0256_00
227
+ scene0256_01
228
+ scene0256_02
229
+ scene0249_00
230
+ scene0441_00
231
+ scene0658_00
232
+ scene0704_00
233
+ scene0704_01
234
+ scene0187_00
235
+ scene0187_01
236
+ scene0131_00
237
+ scene0131_01
238
+ scene0131_02
239
+ scene0207_00
240
+ scene0207_01
241
+ scene0207_02
242
+ scene0461_00
243
+ scene0011_00
244
+ scene0011_01
245
+ scene0343_00
246
+ scene0251_00
247
+ scene0077_00
248
+ scene0077_01
249
+ scene0684_00
250
+ scene0684_01
251
+ scene0550_00
252
+ scene0686_00
253
+ scene0686_01
254
+ scene0686_02
255
+ scene0208_00
256
+ scene0500_00
257
+ scene0500_01
258
+ scene0552_00
259
+ scene0552_01
260
+ scene0648_00
261
+ scene0648_01
262
+ scene0435_00
263
+ scene0435_01
264
+ scene0435_02
265
+ scene0435_03
266
+ scene0690_00
267
+ scene0690_01
268
+ scene0693_00
269
+ scene0693_01
270
+ scene0693_02
271
+ scene0700_00
272
+ scene0700_01
273
+ scene0700_02
274
+ scene0699_00
275
+ scene0231_00
276
+ scene0231_01
277
+ scene0231_02
278
+ scene0697_00
279
+ scene0697_01
280
+ scene0697_02
281
+ scene0697_03
282
+ scene0474_00
283
+ scene0474_01
284
+ scene0474_02
285
+ scene0474_03
286
+ scene0474_04
287
+ scene0474_05
288
+ scene0355_00
289
+ scene0355_01
290
+ scene0146_00
291
+ scene0146_01
292
+ scene0146_02
293
+ scene0196_00
294
+ scene0702_00
295
+ scene0702_01
296
+ scene0702_02
297
+ scene0314_00
298
+ scene0277_00
299
+ scene0277_01
300
+ scene0277_02
301
+ scene0095_00
302
+ scene0095_01
303
+ scene0015_00
304
+ scene0100_00
305
+ scene0100_01
306
+ scene0100_02
307
+ scene0558_00
308
+ scene0558_01
309
+ scene0558_02
310
+ scene0685_00
311
+ scene0685_01
312
+ scene0685_02
configs/vid_config/img_sintel.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ name: sintel
2
+ dir: Sintel/training
3
+ stack_scene_depth: false
configs/vid_config/vid_bonn.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ name: bonn
2
+ dir: rgbd_bonn_dataset
configs/vid_config/vid_kitti.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ name: kitti
2
+ dir: kitti_depth
configs/vid_config/vid_scannet.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ name: scannet
2
+ dir: scannet
3
+ split_ls: 'configs/scannetv1_test.txt'
configs/vid_config/vid_sintel.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ name: sintel
2
+ dir: Sintel/training
3
+ stack_scene_depth: true
diffsynth/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .data import *
2
+ from .models import *
3
+ from .pipelines import *
4
+ from .schedulers import *
diffsynth/configs/__init__.py ADDED
File without changes
diffsynth/configs/model_config.py ADDED
@@ -0,0 +1,705 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import Literal, TypeAlias
2
+
3
+ from ..models.wan_video_dit import WanModel
4
+ from ..models.wan_video_image_encoder import WanImageEncoder
5
+ from ..models.wan_video_motion_controller import WanMotionControllerModel
6
+ from ..models.wan_video_text_encoder import WanTextEncoder
7
+ from ..models.wan_video_vace import VaceWanModel
8
+ from ..models.wan_video_vae import WanVideoVAE
9
+
10
# Configs for detecting model type automatically from a checkpoint's state dict.
# Format of each entry:
#   (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
# Detection is keyed on the hash, so each hash should appear at most once.
model_loader_configs = [
    (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
    # NOTE(review): a second, identical entry for hash 6bfcfb3b342cb286ce886889d519a77e
    # was removed here — hash-based lookup made the duplicate unreachable.
    (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "70ddad9d3a133785da5ea371aae09504", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "26bde73488a92e64cc20b0a7485b9e5b", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "ac6a5aa74f4a0aab6f64eb9a72f19901", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
    (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
    (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
    (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
    (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
    (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
    (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
    (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
    (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
]
34
# Configs for detecting model type automatically from a Hugging Face `config.json`.
# Format of each entry:
#   (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
huggingface_model_loader_configs = [
    ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
    ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
    ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
    ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
    ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
    ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
    ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
    ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
    ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
    ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
    ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
    ("Qwen2_5_VLForConditionalGeneration", "diffsynth.models.qwenvl", "qwenvl", "Qwen25VL_7b_Embedder"),
]

# Configs for patch models, detected by state-dict hash.
# Format of each entry:
#   (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
patch_model_loader_configs = [
]
54
+
55
+ preset_models_on_huggingface = {
56
+ "HunyuanDiT": [
57
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
58
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
59
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
60
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
61
+ ],
62
+ "stable-video-diffusion-img2vid-xt": [
63
+ ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
64
+ ],
65
+ "ExVideo-SVD-128f-v1": [
66
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
67
+ ],
68
+ # Stable Diffusion
69
+ "StableDiffusion_v15": [
70
+ ("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
71
+ ],
72
+ "DreamShaper_8": [
73
+ ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
74
+ ],
75
+ # Textual Inversion
76
+ "TextualInversion_VeryBadImageNegative_v1.3": [
77
+ ("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
78
+ ],
79
+ # Stable Diffusion XL
80
+ "StableDiffusionXL_v1": [
81
+ ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
82
+ ],
83
+ "BluePencilXL_v200": [
84
+ ("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
85
+ ],
86
+ "StableDiffusionXL_Turbo": [
87
+ ("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
88
+ ],
89
+ # Stable Diffusion 3
90
+ "StableDiffusion3": [
91
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
92
+ ],
93
+ "StableDiffusion3_without_T5": [
94
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
95
+ ],
96
+ # ControlNet
97
+ "ControlNet_v11f1p_sd15_depth": [
98
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
99
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
100
+ ],
101
+ "ControlNet_v11p_sd15_softedge": [
102
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
103
+ ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
104
+ ],
105
+ "ControlNet_v11f1e_sd15_tile": [
106
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
107
+ ],
108
+ "ControlNet_v11p_sd15_lineart": [
109
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
110
+ ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
111
+ ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
112
+ ],
113
+ "ControlNet_union_sdxl_promax": [
114
+ ("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
115
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
116
+ ],
117
+ # AnimateDiff
118
+ "AnimateDiff_v2": [
119
+ ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
120
+ ],
121
+ "AnimateDiff_xl_beta": [
122
+ ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
123
+ ],
124
+
125
+ # Qwen Prompt
126
+ "QwenPrompt": [
127
+ ("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
128
+ ("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
129
+ ("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
130
+ ("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
131
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
132
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
133
+ ("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
134
+ ("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
135
+ ],
136
+ # Beautiful Prompt
137
+ "BeautifulPrompt": [
138
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
139
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
140
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
141
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
142
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
143
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
144
+ ],
145
+ # Omost prompt
146
+ "OmostPrompt":[
147
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
148
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
149
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
150
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
151
+ ("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
152
+ ("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
153
+ ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
154
+ ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
155
+ ],
156
+ # Translator
157
+ "opus-mt-zh-en": [
158
+ ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
159
+ ("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
160
+ ("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
161
+ ("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
162
+ ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
163
+ ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
164
+ ("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
165
+ ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
166
+ ],
167
+ # IP-Adapter
168
+ "IP-Adapter-SD": [
169
+ ("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
170
+ ("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
171
+ ],
172
+ "IP-Adapter-SDXL": [
173
+ ("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
174
+ ("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
175
+ ],
176
+ "SDXL-vae-fp16-fix": [
177
+ ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
178
+ ],
179
+ # Kolors
180
+ "Kolors": [
181
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
182
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
183
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
184
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
185
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
186
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
187
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
188
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
189
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
190
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
191
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
192
+ ],
193
+ # FLUX
194
+ "FLUX.1-dev": [
195
+ ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
196
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
197
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
198
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
199
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
200
+ ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
201
+ ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
202
+ ],
203
+ "InstantX/FLUX.1-dev-IP-Adapter": {
204
+ "file_list": [
205
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
206
+ ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
207
+ ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
208
+ ],
209
+ "load_path": [
210
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
211
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
212
+ ],
213
+ },
214
+ # RIFE
215
+ "RIFE": [
216
+ ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
217
+ ],
218
+ # CogVideo
219
+ "CogVideoX-5B": [
220
+ ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
221
+ ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
222
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
223
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
224
+ ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
225
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
226
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
227
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
228
+ ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
229
+ ],
230
+ # Stable Diffusion 3.5
231
+ "StableDiffusion3.5-large": [
232
+ ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
233
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
234
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
235
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
236
+ ],
237
+ }
238
+ preset_models_on_modelscope = {
239
+ # Hunyuan DiT
240
+ "HunyuanDiT": [
241
+ ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
242
+ ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
243
+ ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
244
+ ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
245
+ ],
246
+ # Stable Video Diffusion
247
+ "stable-video-diffusion-img2vid-xt": [
248
+ ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
249
+ ],
250
+ # ExVideo
251
+ "ExVideo-SVD-128f-v1": [
252
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
253
+ ],
254
+ "ExVideo-CogVideoX-LoRA-129f-v1": [
255
+ ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
256
+ ],
257
+ # Stable Diffusion
258
+ "StableDiffusion_v15": [
259
+ ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
260
+ ],
261
+ "DreamShaper_8": [
262
+ ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
263
+ ],
264
+ "AingDiffusion_v12": [
265
+ ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
266
+ ],
267
+ "Flat2DAnimerge_v45Sharp": [
268
+ ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
269
+ ],
270
+ # Textual Inversion
271
+ "TextualInversion_VeryBadImageNegative_v1.3": [
272
+ ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
273
+ ],
274
+ # Stable Diffusion XL
275
+ "StableDiffusionXL_v1": [
276
+ ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
277
+ ],
278
+ "BluePencilXL_v200": [
279
+ ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
280
+ ],
281
+ "StableDiffusionXL_Turbo": [
282
+ ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
283
+ ],
284
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
285
+ ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
286
+ ],
287
+ # Stable Diffusion 3
288
+ "StableDiffusion3": [
289
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
290
+ ],
291
+ "StableDiffusion3_without_T5": [
292
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
293
+ ],
294
+ # ControlNet
295
+ "ControlNet_v11f1p_sd15_depth": [
296
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
297
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
298
+ ],
299
+ "ControlNet_v11p_sd15_softedge": [
300
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
301
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
302
+ ],
303
+ "ControlNet_v11f1e_sd15_tile": [
304
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
305
+ ],
306
+ "ControlNet_v11p_sd15_lineart": [
307
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
308
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
309
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
310
+ ],
311
+ "ControlNet_union_sdxl_promax": [
312
+ ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
313
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
314
+ ],
315
+ "Annotators:Depth": [
316
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
317
+ ],
318
+ "Annotators:Softedge": [
319
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
320
+ ],
321
+ "Annotators:Lineart": [
322
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
323
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
324
+ ],
325
+ "Annotators:Normal": [
326
+ ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
327
+ ],
328
+ "Annotators:Openpose": [
329
+ ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
330
+ ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
331
+ ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
332
+ ],
333
+ # AnimateDiff
334
+ "AnimateDiff_v2": [
335
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
336
+ ],
337
+ "AnimateDiff_xl_beta": [
338
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
339
+ ],
340
+ # RIFE
341
+ "RIFE": [
342
+ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
343
+ ],
344
+ # Qwen Prompt
345
+ "QwenPrompt": {
346
+ "file_list": [
347
+ ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
348
+ ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
349
+ ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
350
+ ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
351
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
352
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
353
+ ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
354
+ ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
355
+ ],
356
+ "load_path": [
357
+ "models/QwenPrompt/qwen2-1.5b-instruct",
358
+ ],
359
+ },
360
+ # Beautiful Prompt
361
+ "BeautifulPrompt": {
362
+ "file_list": [
363
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
364
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
365
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
366
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
367
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
368
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
369
+ ],
370
+ "load_path": [
371
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
372
+ ],
373
+ },
374
+ # Omost prompt
375
+ "OmostPrompt": {
376
+ "file_list": [
377
+ ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
378
+ ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
379
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
380
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
381
+ ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
382
+ ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
383
+ ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
384
+ ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
385
+ ],
386
+ "load_path": [
387
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
388
+ ],
389
+ },
390
+ # Translator
391
+ "opus-mt-zh-en": {
392
+ "file_list": [
393
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
394
+ ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
395
+ ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
396
+ ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
397
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
398
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
399
+ ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
400
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
401
+ ],
402
+ "load_path": [
403
+ "models/translator/opus-mt-zh-en",
404
+ ],
405
+ },
406
+ # IP-Adapter
407
+ "IP-Adapter-SD": [
408
+ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
409
+ ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
410
+ ],
411
+ "IP-Adapter-SDXL": [
412
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
413
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
414
+ ],
415
+ # Kolors
416
+ "Kolors": {
417
+ "file_list": [
418
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
419
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
420
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
421
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
422
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
423
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
424
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
425
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
426
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
427
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
428
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
429
+ ],
430
+ "load_path": [
431
+ "models/kolors/Kolors/text_encoder",
432
+ "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
433
+ "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
434
+ ],
435
+ },
436
+ "SDXL-vae-fp16-fix": [
437
+ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
438
+ ],
439
+ # FLUX
440
+ "FLUX.1-dev": {
441
+ "file_list": [
442
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
443
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
444
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
445
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
446
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
447
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
448
+ ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
449
+ ],
450
+ "load_path": [
451
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
452
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
453
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
454
+ "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
455
+ ],
456
+ },
457
+ "FLUX.1-schnell": {
458
+ "file_list": [
459
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
460
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
461
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
462
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
463
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
464
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
465
+ ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
466
+ ],
467
+ "load_path": [
468
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
469
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
470
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
471
+ "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
472
+ ],
473
+ },
474
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
475
+ ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
476
+ ],
477
+ "jasperai/Flux.1-dev-Controlnet-Depth": [
478
+ ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
479
+ ],
480
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
481
+ ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
482
+ ],
483
+ "jasperai/Flux.1-dev-Controlnet-Upscaler": [
484
+ ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
485
+ ],
486
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
487
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
488
+ ],
489
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
490
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
491
+ ],
492
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
493
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
494
+ ],
495
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
496
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
497
+ ],
498
+ "InstantX/FLUX.1-dev-IP-Adapter": {
499
+ "file_list": [
500
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
501
+ ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
502
+ ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
503
+ ],
504
+ "load_path": [
505
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
506
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
507
+ ],
508
+ },
509
+ "InfiniteYou":{
510
+ "file_list":[
511
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
512
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors", "models/InfiniteYou/InfuseNetModel"),
513
+ ("ByteDance/InfiniteYou", "infu_flux_v1.0/aes_stage2/image_proj_model.bin", "models/InfiniteYou"),
514
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/1k3d68.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
515
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/2d106det.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
516
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/genderage.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
517
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/glintr100.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
518
+ ("ByteDance/InfiniteYou", "supports/insightface/models/antelopev2/scrfd_10g_bnkps.onnx", "models/InfiniteYou/insightface/models/antelopev2"),
519
+ ],
520
+ "load_path":[
521
+ [
522
+ "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00001-of-00002.safetensors",
523
+ "models/InfiniteYou/InfuseNetModel/diffusion_pytorch_model-00002-of-00002.safetensors"
524
+ ],
525
+ "models/InfiniteYou/image_proj_model.bin",
526
+ ],
527
+ },
528
+ # ESRGAN
529
+ "ESRGAN_x4": [
530
+ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
531
+ ],
532
+ # RIFE
533
+ "RIFE": [
534
+ ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
535
+ ],
536
+ # Omnigen
537
+ "OmniGen-v1": {
538
+ "file_list": [
539
+ ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
540
+ ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
541
+ ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
542
+ ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
543
+ ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
544
+ ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
545
+ ],
546
+ "load_path": [
547
+ "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
548
+ "models/OmniGen/OmniGen-v1/model.safetensors",
549
+ ]
550
+ },
551
+ # CogVideo
552
+ "CogVideoX-5B": {
553
+ "file_list": [
554
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
555
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
556
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
557
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
558
+ ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
559
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
560
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
561
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
562
+ ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
563
+ ],
564
+ "load_path": [
565
+ "models/CogVideo/CogVideoX-5b/text_encoder",
566
+ "models/CogVideo/CogVideoX-5b/transformer",
567
+ "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
568
+ ],
569
+ },
570
+ # Stable Diffusion 3.5
571
+ "StableDiffusion3.5-large": [
572
+ ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
573
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
574
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
575
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
576
+ ],
577
+ "StableDiffusion3.5-medium": [
578
+ ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
579
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
580
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
581
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
582
+ ],
583
+ "StableDiffusion3.5-large-turbo": [
584
+ ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
585
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
586
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
587
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
588
+ ],
589
+ "HunyuanVideo":{
590
+ "file_list": [
591
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
592
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
593
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
594
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
595
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
596
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
597
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
598
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
599
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
600
+ ],
601
+ "load_path": [
602
+ "models/HunyuanVideo/text_encoder/model.safetensors",
603
+ "models/HunyuanVideo/text_encoder_2",
604
+ "models/HunyuanVideo/vae/pytorch_model.pt",
605
+ "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
606
+ ],
607
+ },
608
+ "HunyuanVideoI2V":{
609
+ "file_list": [
610
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
611
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
612
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
613
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
614
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
615
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
616
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
617
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
618
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
619
+ ],
620
+ "load_path": [
621
+ "models/HunyuanVideoI2V/text_encoder/model.safetensors",
622
+ "models/HunyuanVideoI2V/text_encoder_2",
623
+ "models/HunyuanVideoI2V/vae/pytorch_model.pt",
624
+ "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
625
+ ],
626
+ },
627
+ "HunyuanVideo-fp8":{
628
+ "file_list": [
629
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
630
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
631
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
632
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
633
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
634
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
635
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
636
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
637
+ ("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
638
+ ],
639
+ "load_path": [
640
+ "models/HunyuanVideo/text_encoder/model.safetensors",
641
+ "models/HunyuanVideo/text_encoder_2",
642
+ "models/HunyuanVideo/vae/pytorch_model.pt",
643
+ "models/HunyuanVideo/transformers/model.fp8.safetensors"
644
+ ],
645
+ },
646
+ }
647
+ Preset_model_id: TypeAlias = Literal[
648
+ "HunyuanDiT",
649
+ "stable-video-diffusion-img2vid-xt",
650
+ "ExVideo-SVD-128f-v1",
651
+ "ExVideo-CogVideoX-LoRA-129f-v1",
652
+ "StableDiffusion_v15",
653
+ "DreamShaper_8",
654
+ "AingDiffusion_v12",
655
+ "Flat2DAnimerge_v45Sharp",
656
+ "TextualInversion_VeryBadImageNegative_v1.3",
657
+ "StableDiffusionXL_v1",
658
+ "BluePencilXL_v200",
659
+ "StableDiffusionXL_Turbo",
660
+ "ControlNet_v11f1p_sd15_depth",
661
+ "ControlNet_v11p_sd15_softedge",
662
+ "ControlNet_v11f1e_sd15_tile",
663
+ "ControlNet_v11p_sd15_lineart",
664
+ "AnimateDiff_v2",
665
+ "AnimateDiff_xl_beta",
666
+ "RIFE",
667
+ "BeautifulPrompt",
668
+ "opus-mt-zh-en",
669
+ "IP-Adapter-SD",
670
+ "IP-Adapter-SDXL",
671
+ "StableDiffusion3",
672
+ "StableDiffusion3_without_T5",
673
+ "Kolors",
674
+ "SDXL-vae-fp16-fix",
675
+ "ControlNet_union_sdxl_promax",
676
+ "FLUX.1-dev",
677
+ "FLUX.1-schnell",
678
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
679
+ "jasperai/Flux.1-dev-Controlnet-Depth",
680
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
681
+ "jasperai/Flux.1-dev-Controlnet-Upscaler",
682
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
683
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
684
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
685
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
686
+ "InstantX/FLUX.1-dev-IP-Adapter",
687
+ "InfiniteYou",
688
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
689
+ "QwenPrompt",
690
+ "OmostPrompt",
691
+ "ESRGAN_x4",
692
+ "RIFE",
693
+ "OmniGen-v1",
694
+ "CogVideoX-5B",
695
+ "Annotators:Depth",
696
+ "Annotators:Softedge",
697
+ "Annotators:Lineart",
698
+ "Annotators:Normal",
699
+ "Annotators:Openpose",
700
+ "StableDiffusion3.5-large",
701
+ "StableDiffusion3.5-medium",
702
+ "HunyuanVideo",
703
+ "HunyuanVideo-fp8",
704
+ "HunyuanVideoI2V",
705
+ ]
diffsynth/data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .video import VideoData, save_video, save_frames
diffsynth/data/video.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import imageio
4
+ import imageio_ffmpeg as ffmpeg
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from matplotlib import cm
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+
11
+
12
class LowMemoryVideo:
    """Lazily decode frames from a video file, one frame at a time."""

    def __init__(self, file_name):
        # Keep only the reader handle; frames are decoded on demand.
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        frame = np.array(self.reader.get_data(item))
        return Image.fromarray(frame).convert("RGB")

    def __del__(self):
        # Release the underlying decoder when the wrapper is collected.
        self.reader.close()
24
+
25
+
26
def split_file_name(file_name):
    """Split a file name into a tuple of characters and embedded integers.

    Example: "img12.png" -> ("i", "m", "g", 12, ".", "p", "n", "g").
    Sorting by this key compares numeric runs by value, so "img2"
    precedes "img10".
    """
    parts = []
    pending = -1  # -1 means no digit run is in progress
    for ch in file_name:
        if "0" <= ch <= "9":
            # Extend (or start) the current digit run.
            pending = (0 if pending == -1 else pending) * 10 + int(ch)
        else:
            if pending != -1:
                parts.append(pending)
                pending = -1
            parts.append(ch)
    if pending != -1:
        # Flush a trailing digit run.
        parts.append(pending)
    return tuple(parts)
43
+
44
+
45
def search_for_images(folder):
    """Return full paths of .jpg/.png files in `folder`, numerically sorted."""
    names = [n for n in os.listdir(folder) if n.endswith((".jpg", ".png"))]
    # Sort by the (chars, ints) key first, then by the raw name to break ties,
    # mirroring a sort over (split_file_name(name), name) tuples.
    names.sort(key=lambda n: (split_file_name(n), n))
    return [os.path.join(folder, n) for n in names]
54
+
55
+
56
class LowMemoryImageFolder:
    """Lazily load images from a folder; only the paths are kept in memory."""

    def __init__(self, folder, file_list=None):
        if file_list is None:
            # Discover images in the folder, numerically sorted.
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, name) for name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        # Decode on access; nothing is cached.
        return Image.open(self.file_list[item]).convert("RGB")

    def __del__(self):
        pass
73
+
74
+
75
def crop_and_resize(image, height, width):
    """Center-crop `image` to the target aspect ratio, then resize.

    Equal margins are removed from the dimension that is too long, so the
    result is exactly (height, width) without distortion.
    """
    pixels = np.array(image)
    src_h, src_w, _ = pixels.shape
    if src_h / src_w < height / width:
        # Source is wider than the target aspect ratio: crop the width.
        crop_w = int(src_h / height * width)
        x0 = (src_w - crop_w) // 2
        pixels = pixels[:, x0: x0 + crop_w]
    else:
        # Source is taller than the target aspect ratio: crop the height.
        crop_h = int(src_w / width * height)
        y0 = (src_h - crop_h) // 2
        pixels = pixels[y0: y0 + crop_h, :]
    return Image.fromarray(pixels).resize((width, height))
89
+
90
+
91
class VideoData:
    """Unified frame source backed by either a video file or an image folder.

    Frames are loaded lazily and optionally center-cropped/resized to a
    fixed (height, width).
    """

    def __init__(
        self, video_file=None, image_folder=None, height=None, width=None, **kwargs
    ):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None  # optional user-imposed frame-count cap
        self.set_shape(height, width)

    def raw_data(self):
        """Materialize every frame into a list (may use a lot of memory)."""
        return [self[i] for i in range(len(self))]

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        return len(self.data) if self.length is None else self.length

    def shape(self):
        """Return (height, width) of the output frames."""
        if self.height is not None and self.width is not None:
            return self.height, self.width
        # Fall back to the native size of the first frame.
        # Fix: frames are PIL images, which expose .size (width, height),
        # not a numpy-style .shape attribute — the old code raised
        # AttributeError here.
        width, height = self[0].size
        return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        width, height = frame.size
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = crop_and_resize(frame, self.height, self.width)
        return frame

    def __del__(self):
        pass

    def save_images(self, folder):
        """Write every frame to `folder` as 0.png, 1.png, ..."""
        os.makedirs(folder, exist_ok=True)
        for i in tqdm(range(len(self)), desc="Saving images"):
            self[i].save(os.path.join(folder, f"{i}.png"))
148
+
149
+
150
def save_video_ffmpeg(frames, save_path, fps):
    """Encode `frames` (T, H, W, C uint8) to `save_path` via imageio-ffmpeg.

    Float inputs are assumed to be in [0, 1] and are scaled to uint8.
    C must be 1, 3 or 4.
    """
    frames = np.array(frames)
    if frames.dtype != np.uint8:
        frames = (frames * 255).clip(0, 255).astype(np.uint8)
    T, H, W, C = frames.shape
    assert C in [1, 3, 4]

    writer = ffmpeg.write_frames(
        save_path,
        (W, H),
        fps=fps,
        quality=9,
        macro_block_size=None,  # avoid padding frames to a multiple of 16
    )
    # Fix: write_frames returns a generator that MUST be primed with
    # send(None) before the first frame, otherwise sending raises TypeError.
    writer.send(None)
    for frame in frames:
        writer.send(frame)

    writer.close()
169
+
170
+
171
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None, grayscale=True):
    """Write `frames` to a video file.

    When grayscale is True, frames are written as-is (floats in [0, 1]
    are scaled to uint8, trailing singleton channels are squeezed).
    When grayscale is False, each frame is treated as a single-channel
    map (e.g. depth) and colorized with the 'Spectral_r' colormap via a
    precomputed 256-entry lookup table.

    Fix: removed a large block of dead, commented-out alternative
    implementation that trailed this function.
    """
    writer = imageio.get_writer(
        save_path, fps=fps, quality=quality, macro_block_size=1, ffmpeg_params=ffmpeg_params
    )
    if not grayscale:
        # Build the colormap LUT once; indexing it per frame is far cheaper
        # than calling the colormap on every pixel array.
        cmap = plt.get_cmap('Spectral_r')
        lut = (cmap(np.linspace(0, 1, 256))[:, :3] * 255).astype(np.uint8)

    for frame in frames:
        frame = np.array(frame)

        if not grayscale:
            # Collapse multi-channel input to its first channel.
            # NOTE(review): assumes channels are redundant (depth-like data
            # replicated across channels) — confirm for true RGB inputs.
            if frame.ndim == 3:
                if frame.shape[-1] >= 3:
                    frame = frame[..., 0]
                elif frame.shape[-1] == 1:
                    frame = frame[..., 0]

            # Floats are assumed in [0, 1]; integers are clipped to [0, 255].
            if frame.dtype in [np.float32, np.float64]:
                indices = (frame * 255).clip(0, 255).astype(np.uint8)
            else:
                indices = frame.clip(0, 255).astype(np.uint8)

            frame_out = lut[indices]

        else:
            if frame.dtype in [np.float32, np.float64]:
                frame_out = (frame * 255).clip(0, 255).astype(np.uint8)
            else:
                frame_out = frame.astype(np.uint8)

            # Squeeze a trailing singleton channel so imageio writes grayscale.
            if frame_out.ndim == 3 and frame_out.shape[-1] == 1:
                frame_out = frame_out[..., 0]

        writer.append_data(frame_out)

    writer.close()
237
+
238
+
239
def save_frames(frames, save_path):
    """Save each frame as <save_path>/<index>.png.

    Fix: the original body contained the save line three times, writing
    every frame to disk three times over.
    """
    os.makedirs(save_path, exist_ok=True)
    for i, frame in enumerate(tqdm(frames, desc="Saving images")):
        frame.save(os.path.join(save_path, f"{i}.png"))
diffsynth/distributed/__init__.py ADDED
File without changes
diffsynth/distributed/xdit_context_parallel.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Optional
3
+ from einops import rearrange
4
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
5
+ get_sequence_parallel_world_size,
6
+ get_sp_group)
7
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
8
+
9
def sinusoidal_embedding_1d(dim, position):
    """Classic sinusoidal position embedding: [cos | sin] halves, each dim/2 wide."""
    half = dim // 2
    # Frequencies 10000^(-i/half), computed in float64 for precision.
    inv_freq = torch.pow(
        10000,
        -torch.arange(half, dtype=torch.float64, device=position.device).div(half),
    )
    angles = torch.outer(position.type(torch.float64), inv_freq)
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    return embedding.to(position.dtype)
14
+
15
def pad_freqs(original_tensor, target_len):
    """Pad a RoPE frequency tensor along dim 0 with ones up to target_len.

    Ones act as the identity for the complex rotation, so padded positions
    leave activations unchanged.
    """
    seq_len, s1, s2 = original_tensor.shape
    padding = torch.ones(
        target_len - seq_len,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device,
    )
    return torch.cat([original_tensor, padding], dim=0)
26
+
27
def rope_apply(x, freqs, num_heads):
    """Apply rotary position embedding to this rank's sequence shard.

    x: (b, s_local, n*d) activations for the local sequence-parallel shard.
    freqs: complex rotation factors; sliced to this rank's positions.
    Returns a tensor with the same shape and dtype as x.
    """
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    s_per_rank = x.shape[1]

    # View consecutive channel pairs as complex numbers (float64 for precision).
    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))

    # Pad the table with ones (identity rotation) so every rank can slice a
    # full-size table, then take this rank's contiguous chunk of positions.
    sp_size = get_sequence_parallel_world_size()
    sp_rank = get_sequence_parallel_rank()
    freqs = pad_freqs(freqs, s_per_rank * sp_size)
    freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]

    x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
    return x_out.to(x.dtype)
41
+
42
def usp_dit_forward(self,
                    x: torch.Tensor,
                    timestep: torch.Tensor,
                    context: torch.Tensor,
                    clip_feature: Optional[torch.Tensor] = None,
                    y: Optional[torch.Tensor] = None,
                    use_gradient_checkpointing: bool = False,
                    use_gradient_checkpointing_offload: bool = False,
                    **kwargs,
                    ):
    """Sequence-parallel (xDiT/USP) forward for the Wan video DiT.

    Splits the patchified token sequence across sequence-parallel ranks,
    runs the transformer blocks on the local shard only, then all-gathers
    before unpatchifying.
    """
    # Timestep embedding and the 6 per-block modulation vectors.
    t = self.time_embedding(
        sinusoidal_embedding_1d(self.freq_dim, timestep))
    t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
    context = self.text_embedding(context)

    if self.has_image_input:
        # Image conditioning: concatenate latents along channels and prepend
        # CLIP image tokens to the text context.
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embdding = self.img_emb(clip_feature)
        context = torch.cat([clip_embdding, context], dim=1)

    x, (f, h, w) = self.patchify(x)

    # 3D RoPE table: per-axis frequencies broadcast over (f, h, w), flattened
    # to one entry per token.
    freqs = torch.cat([
        self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    def create_custom_forward(module):
        # Wrapper so torch.utils.checkpoint can re-invoke the block.
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    # Context Parallel: keep only this rank's chunk of the token sequence.
    x = torch.chunk(
        x, get_sequence_parallel_world_size(),
        dim=1)[get_sequence_parallel_rank()]

    for block in self.blocks:
        if self.training and use_gradient_checkpointing:
            if use_gradient_checkpointing_offload:
                # Additionally park checkpointed activations on the CPU.
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            else:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
        else:
            x = block(x, context, t_mod, freqs)

    x = self.head(x, t)

    # Context Parallel: reassemble the full sequence across ranks.
    x = get_sp_group().all_gather(x, dim=1)

    # unpatchify
    x = self.unpatchify(x, (f, h, w))
    return x
106
+
107
+
108
def usp_attn_forward(self, x, freqs):
    """Sequence-parallel self-attention via xFuser long-context attention.

    q/k receive rotary embeddings for this rank's positions; the
    distributed attention kernel handles the cross-rank exchange.
    """
    q = self.norm_q(self.q(x))
    k = self.norm_k(self.k(x))
    v = self.v(x)

    q = rope_apply(q, freqs, self.num_heads)
    k = rope_apply(k, freqs, self.num_heads)
    # Split the fused head dimension into (heads, head_dim) for the kernel.
    q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
    k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
    v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)

    x = xFuserLongContextAttention()(
        None,
        query=q,
        key=k,
        value=v,
    )
    x = x.flatten(2)

    # Drop the large projections before the output matmul to cap peak memory.
    del q, k, v
    torch.cuda.empty_cache()
    return self.o(x)
diffsynth/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model_manager import *
diffsynth/models/downloader.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from typing import List
4
+
5
+ from huggingface_hub import hf_hub_download
6
+ from modelscope import snapshot_download
7
+ from typing_extensions import Literal, TypeAlias
8
+
9
+ from ..configs.model_config import (Preset_model_id,
10
+ preset_models_on_huggingface,
11
+ preset_models_on_modelscope)
12
+
13
+
14
def download_from_modelscope(model_id, origin_file_path, local_dir):
    """Download one file of a ModelScope repo into local_dir, flattened.

    Skips the download when a file of the same name already exists in
    local_dir.  snapshot_download recreates the repo's subdirectory layout,
    so the file is moved up into local_dir and the nested tree removed.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    if file_name in os.listdir(local_dir):
        print(f" {file_name} has been already in {local_dir}.")
    else:
        print(f" Start downloading {os.path.join(local_dir, file_name)}")
        snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
        downloaded_file_path = os.path.join(local_dir, origin_file_path)
        target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
        if downloaded_file_path != target_file_path:
            # Flatten: move the file up and drop the now-empty repo subtree.
            shutil.move(downloaded_file_path, target_file_path)
            shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
27
+
28
+
29
def download_from_huggingface(model_id, origin_file_path, local_dir):
    """Download one file of a HuggingFace repo into local_dir, flattened.

    Mirrors download_from_modelscope: skips existing files, then moves the
    downloaded file out of the recreated repo subdirectory.
    """
    os.makedirs(local_dir, exist_ok=True)
    file_name = os.path.basename(origin_file_path)
    if file_name in os.listdir(local_dir):
        print(f" {file_name} has been already in {local_dir}.")
    else:
        print(f" Start downloading {os.path.join(local_dir, file_name)}")
        hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
        downloaded_file_path = os.path.join(local_dir, origin_file_path)
        target_file_path = os.path.join(local_dir, file_name)
        if downloaded_file_path != target_file_path:
            # Flatten: move the file up and drop the now-empty repo subtree.
            shutil.move(downloaded_file_path, target_file_path)
            shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
42
+
43
+
44
# Supported download sources.
Preset_model_website: TypeAlias = Literal[
    "HuggingFace",
    "ModelScope",
]
# Maps each website to its table of preset model files.
website_to_preset_models = {
    "HuggingFace": preset_models_on_huggingface,
    "ModelScope": preset_models_on_modelscope,
}
# Maps each website to the function that downloads a single file from it.
website_to_download_fn = {
    "HuggingFace": download_from_huggingface,
    "ModelScope": download_from_modelscope,
}
56
+
57
+
58
def download_customized_models(
    model_id,
    origin_file_path,
    local_dir,
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    """Download one user-specified file, trying websites in priority order.

    Returns the list of local file paths that ended up on disk.
    NOTE(review): the default priority list is a mutable default argument;
    it is never mutated here, but confirm callers do not modify it.
    """
    downloaded_files = []
    for website in downloading_priority:
        # Check if the file is downloaded.
        file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
        if file_to_download in downloaded_files:
            continue
        # Download
        website_to_download_fn[website](model_id, origin_file_path, local_dir)
        if os.path.basename(origin_file_path) in os.listdir(local_dir):
            downloaded_files.append(file_to_download)
    return downloaded_files
75
+
76
+
77
def download_models(
    model_id_list: List[Preset_model_id] = [],
    downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
    """Download every preset model in model_id_list.

    For each model the websites are tried in priority order; the first
    website that yields at least one new file wins.  Returns the list of
    local files (or the preset's explicit "load_path" entries) to load.
    """
    print(f"Downloading models: {model_id_list}")
    downloaded_files = []
    load_files = []

    for model_id in model_id_list:
        for website in downloading_priority:
            if model_id in website_to_preset_models[website]:

                # Parse model metadata: either a bare file list or a dict
                # with "file_list" (and optionally "load_path").
                model_metadata = website_to_preset_models[website][model_id]
                if isinstance(model_metadata, list):
                    file_data = model_metadata
                else:
                    file_data = model_metadata.get("file_list", [])

                # Try downloading the model from this website.
                model_files = []
                # Fix: the loop variable was previously named `model_id`,
                # shadowing the preset id from the outer loop.
                for repo_id, origin_file_path, local_dir in file_data:
                    # Check if the file is downloaded.
                    file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
                    if file_to_download in downloaded_files:
                        continue
                    # Download
                    website_to_download_fn[website](repo_id, origin_file_path, local_dir)
                    if os.path.basename(origin_file_path) in os.listdir(local_dir):
                        downloaded_files.append(file_to_download)
                        model_files.append(file_to_download)

                # If the model is successfully downloaded, stop trying other sites.
                if len(model_files) > 0:
                    if isinstance(model_metadata, dict) and "load_path" in model_metadata:
                        model_files = model_metadata["load_path"]
                    load_files.extend(model_files)
                    break

    return load_files
diffsynth/models/model_manager.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import json
3
+ import os
4
+ from typing import List
5
+
6
+ import torch
7
+
8
+ from ..configs.model_config import (huggingface_model_loader_configs,
9
+ model_loader_configs,
10
+ patch_model_loader_configs)
11
+ from .downloader import (Preset_model_id, Preset_model_website,
12
+ download_customized_models, download_models)
13
+ from .utils import (hash_state_dict_keys, init_weights_on_device,
14
+ load_state_dict, split_state_dict_with_prefix)
15
+
16
+
17
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
    """Instantiate each model class from a single-file state dict.

    model_resource selects the converter ("civitai" or "diffusers").
    Returns (loaded_model_names, loaded_models).
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        print(f" model_name: {model_name} model_class: {model_class.__name__}")
        state_dict_converter = model_class.state_dict_converter()
        if model_resource == "civitai":
            state_dict_results = state_dict_converter.from_civitai(state_dict)
        elif model_resource == "diffusers":
            state_dict_results = state_dict_converter.from_diffusers(state_dict)
        else:
            # Fix: an unknown resource previously left state_dict_results
            # unbound and crashed later with a confusing NameError.
            raise ValueError(f"Unsupported model resource: {model_resource}")
        # The converter may return (state_dict, extra_kwargs) or just a state dict.
        if isinstance(state_dict_results, tuple):
            model_state_dict, extra_kwargs = state_dict_results
            print(f" This model is initialized with extra kwargs: {extra_kwargs}")
        else:
            model_state_dict, extra_kwargs = state_dict_results, {}
        # Fix: compute the dtype per model instead of overwriting torch_dtype,
        # which leaked float32 into every subsequent model once one model
        # requested an upcast.
        model_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
        # Build the skeleton on the meta device, then materialize the weights.
        with init_weights_on_device():
            model = model_class(**extra_kwargs)
        if hasattr(model, "eval"):
            model = model.eval()
        model.load_state_dict(model_state_dict, assign=True)
        model = model.to(dtype=model_dtype, device=device)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models
41
+
42
+
43
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
    """Load each model class from a HuggingFace-style folder via from_pretrained.

    Standard float dtypes are passed straight to from_pretrained; other
    dtypes (e.g. fp8) are applied after loading.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
        else:
            model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
        if torch_dtype == torch.float16 and hasattr(model, "half"):
            model = model.half()
        try:
            model = model.to(device=device)
        except Exception:
            # Fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.  Best effort: some models
            # cannot be moved to the target device.
            pass
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models
59
+
60
+
61
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
    """Build a patched variant of base_model.

    Starts from the base model's weights, then overlays the patch
    checkpoint's weights on top; the returned model replaces the base.
    """
    print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
    base_state_dict = base_model.state_dict()
    # Move the base model off the accelerator before instantiating the patch
    # model, to avoid holding two full copies in device memory.
    base_model.to("cpu")
    del base_model
    model = model_class(**extra_kwargs)
    # strict=False: the patched architecture may add or remove parameters.
    model.load_state_dict(base_state_dict, strict=False)
    model.load_state_dict(state_dict, strict=False)
    model.to(dtype=torch_dtype, device=device)
    return model
71
+
72
+
73
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
    """Patch every model held by model_manager whose name matches.

    Each matching base model is removed from the manager and a freshly
    patched replacement is returned.  The while/for-else loop rescans after
    every removal, because popping invalidates the indices being iterated.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        while True:
            for model_id in range(len(model_manager.model)):
                base_model_name = model_manager.model_name[model_id]
                if base_model_name == model_name:
                    base_model_path = model_manager.model_path[model_id]
                    base_model = model_manager.model[model_id]
                    print(f" Adding patch model to {base_model_name} ({base_model_path})")
                    patched_model = load_single_patch_model_from_single_file(
                        state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
                    loaded_model_names.append(base_model_name)
                    loaded_models.append(patched_model)
                    # Drop the stale base-model entry; restart the scan.
                    model_manager.model.pop(model_id)
                    model_manager.model_path.pop(model_id)
                    model_manager.model_name.pop(model_id)
                    break
            else:
                # for-else: no match found in a full scan -> done with this name.
                break
    return loaded_model_names, loaded_models
94
+
95
+
96
+
97
class ModelDetectorTemplate:
    """Interface for model detectors: match() a checkpoint, then load() it."""

    def __init__(self):
        pass

    def match(self, file_path="", state_dict={}):
        """The base detector never matches anything."""
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """The base detector loads nothing: ([], [])."""
        return [], []
106
+
107
+
108
+
109
class ModelDetectorFromSingleFile:
    """Detect and load models stored as one checkpoint file, identified by a
    hash of the state-dict keys (optionally including parameter shapes)."""

    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        """Register a model under its key hashes (shape-aware hash is mandatory)."""
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)

    def match(self, file_path="", state_dict={}):
        """Return True if the checkpoint's key hash matches a registered model."""
        if isinstance(file_path, str) and os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return True
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the model(s) matching the checkpoint.

        Returns ([], []) when nothing matches.
        """
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Strict matching: keys AND parameter shapes must agree.
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            return load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)

        # Loose matching: keys only (shapes may differ; the state_dict_converter
        # will adapt the model architecture).
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
            return load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)

        # Fix: the original fell through to `return loaded_model_names,
        # loaded_models` with both names unbound, raising NameError whenever
        # no hash matched.
        return [], []
157
+
158
+
159
+
160
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
    """Detect models inside one file that bundles several state dicts,
    distinguished by their key prefixes."""

    def __init__(self, model_loader_configs=[]):
        super().__init__(model_loader_configs)

    def match(self, file_path="", state_dict={}):
        """Match when any prefix-split component matches a registered model."""
        if isinstance(file_path, str) and os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        # Fix: load the state dict from the file when it was not provided,
        # matching match() and the sibling detectors.
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        # Split the state_dict and load from each component.
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        valid_state_dict = {}
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                valid_state_dict.update(sub_state_dict)
        if super().match(file_path, valid_state_dict):
            # All matching components merge into one recognizable state dict.
            loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
        else:
            # Otherwise load each matching component separately.
            loaded_model_names, loaded_models = [], []
            for sub_state_dict in splited_state_dict:
                if super().match(file_path, sub_state_dict):
                    # Fix: the original passed valid_state_dict here, loading
                    # the same merged (unmatched) dict once per component
                    # instead of the component itself.
                    loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
                    loaded_model_names += loaded_model_names_
                    loaded_models += loaded_models_
        return loaded_model_names, loaded_models
194
+
195
+
196
+
197
class ModelDetectorFromHuggingfaceFolder:
    """Detect Diffusers/Transformers-style model folders via their config.json."""

    def __init__(self, model_loader_configs=[]):
        self.architecture_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
        """Register which library/class handles a given architecture string."""
        self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)

    def match(self, file_path="", state_dict={}):
        """A folder matches when its config.json declares an architecture."""
        if not isinstance(file_path, str) or os.path.isfile(file_path):
            return False
        if "config.json" not in os.listdir(file_path):
            return False
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        return "architectures" in config or "_class_name" in config

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Instantiate every architecture declared by the folder's config.json."""
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        loaded_model_names, loaded_models = [], []
        if "architectures" in config:
            architectures = config["architectures"]
        else:
            architectures = [config["_class_name"]]
        for architecture in architectures:
            huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
            if redirected_architecture is not None:
                # Some architectures are aliased to a different class name.
                architecture = redirected_architecture
            model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
            loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
235
+
236
+
237
+
238
class ModelDetectorFromPatchedSingleFile:
    """Detect checkpoints that patch an already-loaded base model,
    identified by a shape-aware hash of the state-dict keys."""

    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
        # Register the patch model under its shape-aware key hash.
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)

    def match(self, file_path="", state_dict={}):
        """Return True if the file's shape-aware key hash names a patch model."""
        # Only single files (not folders) can be patch checkpoints.
        if not isinstance(file_path, str) or os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
        """Patch the matching base model(s) currently held by model_manager."""
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        loaded_model_names, loaded_models = [], []
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
                state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
274
+
275
+
276
+
277
class ModelManager:
    """Registry that downloads, detects, and loads model checkpoints.

    Loaded models are stored in three parallel lists (``model``,
    ``model_path``, ``model_name``) so that :meth:`fetch_model` can look a
    model up by name and, optionally, by the file it came from.
    """

    def __init__(
        self,
        torch_dtype=torch.float16,
        device="cuda",
        model_id_list: List[Preset_model_id] = None,
        downloading_priority: List[Preset_model_website] = None,
        file_path_list: List[str] = None,
    ):
        """Create the manager and eagerly load any requested models.

        Args:
            torch_dtype: dtype models are cast to when loaded.
            device: device models are placed on when loaded.
            model_id_list: preset model ids to download before loading.
            downloading_priority: website order for downloads; defaults to
                ``["HuggingFace", "ModelScope"]``.
            file_path_list: extra local checkpoint paths to load.

        Note: list defaults are ``None`` rather than mutable literals to
        avoid the shared-mutable-default pitfall.
        """
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = []
        self.model_path = []
        self.model_name = []
        model_id_list = [] if model_id_list is None else model_id_list
        if downloading_priority is None:
            downloading_priority = ["HuggingFace", "ModelScope"]
        file_path_list = [] if file_path_list is None else file_path_list
        downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
        # Detectors are tried in order by load_model(); first match wins.
        self.model_detector = [
            ModelDetectorFromSingleFile(model_loader_configs),
            ModelDetectorFromSplitedSingleFile(model_loader_configs),
            ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
            ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
        ]
        self.load_models(downloaded_files + file_path_list)


    def _register_models(self, model_names, models, file_path):
        # Append each loaded model to the parallel registry lists.
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)


    def load_model_from_single_file(self, file_path="", state_dict=None, model_names=None, model_classes=None, model_resource=None):
        """Load model(s) from a single checkpoint file (or a pre-read state dict)."""
        if not state_dict:
            state_dict = load_state_dict(file_path)
        # Calls the module-level loader of the same name.
        model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
        self._register_models(model_names, models, file_path)
        print(f" The following models are loaded: {model_names}.")


    def load_model_from_huggingface_folder(self, file_path="", model_names=None, model_classes=None):
        """Load model(s) from a Huggingface-format folder."""
        model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
        self._register_models(model_names, models, file_path)
        print(f" The following models are loaded: {model_names}.")


    def load_patch_model_from_single_file(self, file_path="", state_dict=None, model_names=None, model_classes=None, extra_kwargs=None):
        """Load "patch" model(s) that wrap an already-loaded base model."""
        print(f"Loading patch models from file: {file_path}")
        model_names, models = load_patch_model_from_single_file(
            {} if state_dict is None else state_dict,
            model_names, model_classes,
            {} if extra_kwargs is None else extra_kwargs,
            self, self.torch_dtype, self.device)
        self._register_models(model_names, models, file_path)
        print(f" The following patched models are loaded: {model_names}.")


    def load_lora(self, file_path="", state_dict=None, lora_alpha=1.0):
        """Merge LoRA weights into whichever loaded model a LoRA loader matches.

        ``file_path`` may be a single path or a list of paths.
        """
        if state_dict is None:
            state_dict = {}
        if isinstance(file_path, list):
            for file_path_ in file_path:
                self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
            return
        print(f"Loading LoRA models from file: {file_path}")
        is_loaded = False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
            for lora in get_lora_loaders():
                match_results = lora.match(model, state_dict)
                if match_results is not None:
                    print(f" Adding LoRA to {model_name} ({model_path}).")
                    lora_prefix, model_resource = match_results
                    lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
                    is_loaded = True
                    break
        if not is_loaded:
            print(f" Cannot load LoRA: {file_path}")


    def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
        """Detect the model type of ``file_path`` (file, file list, or folder)
        and load it via the first matching detector."""
        if device is None: device = self.device
        if torch_dtype is None: torch_dtype = self.torch_dtype
        if isinstance(file_path, list):
            # A split checkpoint: merge all shards into one state dict.
            state_dict = {}
            for path in file_path:
                state_dict.update(load_state_dict(path))
        elif os.path.isfile(file_path):
            state_dict = load_state_dict(file_path)
        else:
            # A folder: detectors that need a state dict will not match.
            state_dict = None
        for model_detector in self.model_detector:
            if model_detector.match(file_path, state_dict):
                model_names, models = model_detector.load(
                    file_path, state_dict,
                    device=device, torch_dtype=torch_dtype,
                    allowed_model_names=model_names, model_manager=self
                )
                self._register_models(model_names, models, file_path)
                print(f" The following models are loaded: {model_names}.")
                break
        else:
            print(f" We cannot detect the model type. No models are loaded.")


    def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
        """Load every path in ``file_path_list`` via :meth:`load_model`."""
        for file_path in file_path_list:
            self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)


    def fetch_model(self, model_name, file_path=None, require_model_path=False):
        """Return the first loaded model named ``model_name``.

        Args:
            model_name: registry name to search for.
            file_path: if given, only models loaded from this path match.
            require_model_path: if True, return ``(model, path)`` instead.

        Returns:
            The model (or ``(model, path)``), or ``None`` if nothing matches.
        """
        fetched_models = []
        fetched_model_paths = []
        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
            if file_path is not None and file_path != model_path:
                continue
            if model_name == model_name_:
                fetched_models.append(model)
                fetched_model_paths.append(model_path)
        if len(fetched_models) == 0:
            print(f"No {model_name} models available.")
            return None
        if len(fetched_models) > 1:
            print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
        if require_model_path:
            return fetched_models[0], fetched_model_paths[0]
        else:
            return fetched_models[0]


    def to(self, device):
        """Move every loaded model to ``device``."""
        for model in self.model:
            model.to(device)
416
+
diffsynth/models/tiler.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange, repeat
3
+
4
+
5
class TileWorker:
    """Applies a 2D model to a large image tile-by-tile.

    The input is unfolded into overlapping square tiles, ``forward_fn`` is run
    on batches of tiles, and the outputs are folded back with a feathered
    (border-faded) mask so overlapping regions blend smoothly.
    """

    def __init__(self):
        pass


    def mask(self, height, width, border_width):
        # Create a mask with shape (height, width).
        # The centre area is filled with 1, and the border line is filled with values in range (0, 1].
        x = torch.arange(height).repeat(width, 1).T
        y = torch.arange(width).repeat(height, 1)
        # Distance (in pixels, 1-based) to the nearest edge, then rescaled so
        # the outermost border_width pixels ramp linearly from ~0 to 1.
        mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
        mask = (mask / border_width).clip(0, 1)
        return mask


    def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
        # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
        batch_size, channel, _, _ = model_input.shape
        model_input = model_input.to(device=tile_device, dtype=tile_dtype)
        unfold_operator = torch.nn.Unfold(
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride)
        )
        # Unfold flattens each tile into a column; view restores the square shape.
        model_input = unfold_operator(model_input)
        model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))

        return model_input


    def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
        # Call y=forward_fn(x) for each tile
        tile_num = model_input.shape[-1]
        model_output_stack = []

        for tile_id in range(0, tile_num, tile_batch_size):

            # process input: take up to tile_batch_size tiles and merge the
            # tile axis into the batch axis before running the model.
            tile_id_ = min(tile_id + tile_batch_size, tile_num)
            x = model_input[:, :, :, :, tile_id: tile_id_]
            x = x.to(device=inference_device, dtype=inference_dtype)
            x = rearrange(x, "b c h w n -> (n b) c h w")

            # process output: split the tile axis back out and park tiles on
            # tile_device (typically CPU) to bound peak memory.
            y = forward_fn(x)
            y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
            y = y.to(device=tile_device, dtype=tile_dtype)
            model_output_stack.append(y)

        model_output = torch.concat(model_output_stack, dim=-1)
        return model_output


    def io_scale(self, model_output, tile_size):
        # Determine the size modification happened in forward_fn
        # We only consider the same scale on height and width.
        io_scale = model_output.shape[2] / tile_size
        return io_scale


    def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
        # The reversed function of tile: weight each tile by the feathered
        # mask, fold the tiles back, and normalise by the folded mask so
        # overlapping contributions average to 1.
        mask = self.mask(tile_size, tile_size, border_width)
        mask = mask.to(device=tile_device, dtype=tile_dtype)
        mask = rearrange(mask, "h w -> 1 1 h w 1")
        model_output = model_output * mask

        fold_operator = torch.nn.Fold(
            output_size=(height, width),
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride)
        )
        mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
        model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
        model_output = fold_operator(model_output) / fold_operator(mask)

        return model_output


    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
        """Run forward_fn over model_input tile-by-tile and blend the results.

        forward_fn may change the spatial resolution uniformly (e.g. a VAE);
        the scale is inferred from the output and applied to the fold geometry.
        """
        # Prepare
        inference_device, inference_dtype = model_input.device, model_input.dtype
        height, width = model_input.shape[2], model_input.shape[3]
        border_width = int(tile_stride*0.5) if border_width is None else border_width

        # tile
        model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)

        # inference
        model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)

        # resize: rescale the fold geometry by the model's io scale.
        io_scale = self.io_scale(model_output, tile_size)
        height, width = int(height*io_scale), int(width*io_scale)
        tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
        border_width = int(border_width*io_scale)

        # untile
        model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)

        # Done!
        model_output = model_output.to(device=inference_device, dtype=inference_dtype)
        return model_output
107
+
108
+
109
+
110
class FastTileWorker:
    """Tile-by-tile 2D processing where forward_fn receives tile coordinates.

    Unlike TileWorker, the input is never materialised per-tile here:
    ``forward_fn(hl, hr, wl, wr)`` is expected to produce the output for that
    crop itself, and the results are accumulated with feathered masks.
    """

    def __init__(self):
        pass


    def build_mask(self, data, is_bound):
        """Feathered blend mask for one tile.

        ``is_bound`` marks (top, bottom, left, right) edges that touch the
        image boundary; those sides get full weight instead of a fade.
        """
        _, _, H, W = data.shape
        h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
        w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
        border_width = (H + W) // 4
        pad = torch.ones_like(h) * border_width
        # Per-pixel distance to the nearest non-boundary edge, clipped so the
        # mask ramps from 1/border_width up to 1 inside the tile.
        mask = torch.stack([
            pad if is_bound[0] else h + 1,
            pad if is_bound[1] else H - h,
            pad if is_bound[2] else w + 1,
            pad if is_bound[3] else W - w
        ]).min(dim=0).values
        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
        mask = rearrange(mask, "H W -> 1 H W")
        return mask


    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
        """Accumulate forward_fn outputs over a grid of overlapping tiles.

        NOTE(review): assumes forward_fn's output matches the input crop's
        spatial size — no io-scale handling here, unlike TileWorker.
        """
        # Prepare
        B, C, H, W = model_input.shape
        border_width = int(tile_stride*0.5) if border_width is None else border_width
        weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)
        values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)

        # Split tasks
        tasks = []
        for h in range(0, H, tile_stride):
            for w in range(0, W, tile_stride):
                # Skip tiles fully covered by the previous (clamped) tile.
                if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
                    continue
                h_, w_ = h + tile_size, w + tile_size
                # Clamp the last tile in each direction to the image border.
                if h_ > H: h, h_ = H - tile_size, H
                if w_ > W: w, w_ = W - tile_size, W
                tasks.append((h, h_, w, w_))

        # Run
        for hl, hr, wl, wr in tasks:
            # Forward
            hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device)

            mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
            values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
            weight[:, :, hl:hr, wl:wr] += mask
        # Normalise by the accumulated mask weight to blend overlaps.
        values /= weight
        return values
161
+
162
+
163
+
164
class TileWorker2Dto3D:
    """
    Process 3D tensors, but only enable TileWorker on 2D.
    """
    def __init__(self):
        pass


    def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
        """Feathered 3D blend mask; ``is_bound`` marks (t0, t1, h0, h1, w0, w1)
        faces that touch the volume boundary and therefore get full weight."""
        t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
        h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
        w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
        border_width = (H + W) // 4 if border_width is None else border_width
        pad = torch.ones_like(h) * border_width
        # Minimum distance to any non-boundary face, ramped to (0, 1].
        mask = torch.stack([
            pad if is_bound[0] else t + 1,
            pad if is_bound[1] else T - t,
            pad if is_bound[2] else h + 1,
            pad if is_bound[3] else H - h,
            pad if is_bound[4] else w + 1,
            pad if is_bound[5] else W - w
        ]).min(dim=0).values
        mask = mask.clip(1, border_width)
        mask = (mask / border_width).to(dtype=dtype, device=device)
        mask = rearrange(mask, "T H W -> 1 1 T H W")
        return mask


    def tiled_forward(
        self,
        forward_fn,
        model_input,
        tile_size, tile_stride,
        tile_device="cpu", tile_dtype=torch.float32,
        computation_device="cuda", computation_dtype=torch.float32,
        border_width=None, scales=[1, 1, 1, 1],
        progress_bar=lambda x:x
    ):
        """Tile a (B, C, T, H, W) tensor over H/W only and blend forward_fn's
        outputs.

        ``scales`` gives the (C, T, H, W) size ratios of output vs input, so
        the accumulators can be allocated at the output resolution.
        """
        B, C, T, H, W = model_input.shape
        scale_C, scale_T, scale_H, scale_W = scales
        tile_size_H, tile_size_W = tile_size
        tile_stride_H, tile_stride_W = tile_stride

        # Output-resolution accumulators: weighted values and total weight.
        value = torch.zeros((B, int(C*scale_C), int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
        weight = torch.zeros((1, 1, int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)

        # Split tasks
        tasks = []
        for h in range(0, H, tile_stride_H):
            for w in range(0, W, tile_stride_W):
                # Skip tiles fully covered by the previous (clamped) tile.
                if (h-tile_stride_H >= 0 and h-tile_stride_H+tile_size_H >= H) or (w-tile_stride_W >= 0 and w-tile_stride_W+tile_size_W >= W):
                    continue
                h_, w_ = h + tile_size_H, w + tile_size_W
                # Clamp the last tile in each direction to the border.
                if h_ > H: h, h_ = max(H - tile_size_H, 0), H
                if w_ > W: w, w_ = max(W - tile_size_W, 0), W
                tasks.append((h, h_, w, w_))

        # Run
        for hl, hr, wl, wr in progress_bar(tasks):
            # Time is never tiled, so both T faces count as boundaries.
            mask = self.build_mask(
                int(T*scale_T), int((hr-hl)*scale_H), int((wr-wl)*scale_W),
                tile_dtype, tile_device,
                is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W),
                border_width=border_width
            )
            grid_input = model_input[:, :, :, hl:hr, wl:wr].to(dtype=computation_dtype, device=computation_device)
            grid_output = forward_fn(grid_input).to(dtype=tile_dtype, device=tile_device)
            value[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += grid_output * mask
            weight[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += mask
        # Normalise overlapping contributions by their accumulated weight.
        value = value / weight
        return value
diffsynth/models/utils.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ from contextlib import contextmanager
4
+
5
+ import torch
6
+ from safetensors import safe_open
7
+
8
+
9
@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
    """Context manager that makes newly registered parameters (and optionally
    buffers / tensor constructors) land on ``device`` — typically the meta
    device, so huge models can be instantiated without allocating weights.

    Temporarily monkey-patches ``torch.nn.Module.register_parameter`` (and
    ``register_buffer`` / ``torch.empty`` etc. when ``include_buffers``) and
    restores the originals on exit.
    """

    old_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = torch.nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        # Register normally, then replace the stored parameter with a copy
        # moved to the target device, preserving requires_grad.
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):
        # Force the device kwarg of torch constructors (empty/zeros/ones/full).
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        # Always restore the patched globals, even if the body raised.
        torch.nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
57
+
58
def load_state_dict_from_folder(file_path, torch_dtype=None):
    """Merge the state dicts of every checkpoint file directly inside ``file_path``."""
    checkpoint_extensions = ("safetensors", "bin", "ckpt", "pth", "pt")
    merged = {}
    for file_name in os.listdir(file_path):
        if "." not in file_name:
            continue
        if file_name.split(".")[-1] not in checkpoint_extensions:
            continue
        merged.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
    return merged
66
+
67
+
68
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
    """Dispatch to the safetensors or torch loader based on the file extension."""
    loader = (
        load_state_dict_from_safetensors
        if file_path.endswith(".safetensors")
        else load_state_dict_from_bin
    )
    return loader(file_path, torch_dtype=torch_dtype, device=device)
73
+
74
+
75
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
    """Read every tensor from a .safetensors file, optionally casting to ``torch_dtype``."""
    state_dict = {}
    with safe_open(file_path, framework="pt", device=device) as handle:
        for key in handle.keys():
            tensor = handle.get_tensor(key)
            state_dict[key] = tensor if torch_dtype is None else tensor.to(torch_dtype)
    return state_dict
83
+
84
+
85
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
    """Load a torch-serialized checkpoint, optionally casting tensor values.

    Non-tensor entries are left untouched; ``weights_only=True`` keeps
    deserialization restricted to safe types.
    """
    state_dict = torch.load(file_path, map_location=device, weights_only=True)
    if torch_dtype is not None:
        for key, value in state_dict.items():
            if isinstance(value, torch.Tensor):
                state_dict[key] = value.to(torch_dtype)
    return state_dict
92
+
93
+
94
def search_for_embeddings(state_dict):
    """Collect every tensor in a (possibly nested) state dict, depth-first."""
    embeddings = []
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            embeddings.append(value)
        elif isinstance(value, dict):
            embeddings.extend(search_for_embeddings(value))
    return embeddings
102
+
103
+
104
def search_parameter(param, state_dict):
    """Return the key of the first state_dict entry numerically matching ``param``.

    Entries with a different element count are skipped; same-shape entries are
    compared directly, otherwise the flattened values are compared. Returns
    ``None`` when nothing matches within the 1e-3 distance tolerance.
    """
    for name, candidate in state_dict.items():
        if param.numel() != candidate.numel():
            continue
        if param.shape == candidate.shape:
            if torch.dist(param, candidate) < 1e-3:
                return name
        elif torch.dist(param.flatten(), candidate.flatten()) < 1e-3:
            return name
    return None
114
+
115
+
116
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    """Print a source-key -> target-key mapping for value-identical parameters.

    With ``split_qkv`` a source tensor whose leading dimension divides by 3 is
    also tried as three stacked chunks (fused q/k/v weights). Unmatched target
    keys are reported at the end. Output is printed, not returned.
    """
    matched_keys = set()
    with torch.no_grad():
        for name, param in source_state_dict.items():
            rename = search_parameter(param, target_state_dict)
            if rename is not None:
                print(f'"{name}": "{rename}",')
                matched_keys.add(rename)
            elif split_qkv and len(param.shape) >= 1 and param.shape[0] % 3 == 0:
                length = param.shape[0] // 3
                rename = [
                    search_parameter(param[i * length: i * length + length], target_state_dict)
                    for i in range(3)
                ]
                if None not in rename:
                    print(f'"{name}": {rename},')
                    matched_keys.update(rename)
    for name in target_state_dict:
        if name not in matched_keys:
            print("Cannot find", name, target_state_dict[name].shape)
136
+
137
+
138
def search_for_files(folder, extensions):
    """Recursively collect file paths under ``folder`` ending with any of
    ``extensions``, visiting directory entries in sorted order."""
    found = []
    if os.path.isdir(folder):
        for entry in sorted(os.listdir(folder)):
            found.extend(search_for_files(os.path.join(folder, entry), extensions))
    elif os.path.isfile(folder) and folder.endswith(tuple(extensions)):
        found.append(folder)
    return found
149
+
150
+
151
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    """Flatten a (possibly nested) state dict's key structure into one sorted,
    comma-joined string; tensor keys optionally carry a ``key:d0_d1`` shape tag.
    Nested dicts are encoded as ``parent|<flattened children>``."""
    entries = []
    for key, value in state_dict.items():
        if not isinstance(key, str):
            continue
        if isinstance(value, torch.Tensor):
            if with_shape:
                shape_tag = "_".join(str(dim) for dim in value.shape)
                entries.append(key + ":" + shape_tag)
            entries.append(key)
        elif isinstance(value, dict):
            entries.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
    return ",".join(sorted(entries))
165
+
166
+
167
def split_state_dict_with_prefix(state_dict):
    """Partition a state dict into sub-dicts grouped by the first dotted
    component of each (string) key; groups appear in sorted-key order."""
    grouped = {}
    for key in sorted(k for k in state_dict if isinstance(k, str)):
        prefix = key.split(".")[0] if "." in key else key
        grouped.setdefault(prefix, []).append(key)
    return [{key: state_dict[key] for key in keys} for keys in grouped.values()]
180
+
181
+
182
def hash_state_dict_keys(state_dict, with_shape=True):
    """MD5 fingerprint of a state dict's key structure (used to detect model
    architectures without reading the weights)."""
    summary = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    return hashlib.md5(summary.encode(encoding="UTF-8")).hexdigest()
diffsynth/models/wan_video_camera_controller.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from einops import rearrange
5
+ import os
6
+ from typing_extensions import Literal
7
+
8
+
9
class SimpleAdapter(nn.Module):
    """Encodes per-frame conditioning maps (e.g. camera plucker embeddings)
    into a feature volume: pixel-unshuffle x8, strided conv, residual blocks.
    Input/output are 5D (batch, channel, frame, height, width)."""

    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
        super(SimpleAdapter, self).__init__()

        # Pixel Unshuffle: reduce spatial dimensions by a factor of 8
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)

        # Convolution: reduce spatial dimensions by a factor
        # of 2 (without overlap)
        # in_dim * 64 because PixelUnshuffle(8) multiplies channels by 8*8.
        self.conv = nn.Conv2d(in_dim * 64, out_dim,
                              kernel_size=kernel_size, stride=stride, padding=0)

        # Residual blocks for feature extraction
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
        )

    def forward(self, x):
        """Encode a (b, c, f, h, w) tensor; frames are processed independently."""
        # Reshape to merge the frame dimension into batch
        bs, c, f, h, w = x.size()
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)

        # Pixel Unshuffle operation
        x_unshuffled = self.pixel_unshuffle(x)

        # Convolution operation
        x_conv = self.conv(x_unshuffled)

        # Feature extraction with residual blocks
        out = self.residual_blocks(x_conv)

        # Reshape to restore original bf dimension
        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))

        # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
        out = out.permute(0, 2, 1, 3, 4)

        return out

    def process_camera_coordinates(
        self,
        direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
        length: int,
        height: int,
        width: int,
        speed: float = 1/54,
        origin=(0, 0.532139961, 0.946026558, 0.5, 0.5,
                0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
    ):
        """Generate a panning camera trajectory and convert it to a plucker
        embedding via the module-level helpers; does not use ``self``."""
        if origin is None:
            # Re-apply the default when callers explicitly pass None.
            origin = (0, 0.532139961, 0.946026558, 0.5, 0.5,
                      0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
        print(
            f"Generating camera coordinates with direction: {direction}, length: {length}, speed: {speed}, origin: {origin}")
        coordinates = generate_camera_coordinates(
            direction, length, speed, origin)
        print(f"Generated {len(coordinates)} camera coordinates.")
        plucker_embedding = process_pose_file(coordinates, width, height)
        print(
            f"Processed camera coordinates into plucker embedding with shape: {plucker_embedding.shape}")
        return plucker_embedding
70
+
71
+
72
class ResidualBlock(nn.Module):
    """A 3x3-conv residual unit: out = x + conv2(relu(conv1(x)))."""

    def __init__(self, dim):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        """Apply conv-relu-conv and add the input back (shape preserved)."""
        skip = x
        hidden = self.conv2(self.relu(self.conv1(x)))
        hidden += skip
        return hidden
85
+
86
+
87
class Camera(object):
    """Pinhole camera parsed from one pose-file row.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    ``entry`` layout: [_, fx, fy, cx, cy, _, _, <12 row-major values of the
    3x4 world-to-camera matrix>].
    """

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[1:5]
        # Promote the 3x4 world-to-camera matrix to homogeneous 4x4 form.
        pose = np.eye(4)
        pose[:3, :] = np.array(entry[7:]).reshape(3, 4)
        self.w2c_mat = pose
        self.c2w_mat = np.linalg.inv(pose)
102
+
103
+
104
def get_relative_pose(cam_params):
    """Re-express camera-to-world poses relative to the first camera.

    Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    The first pose becomes the (identity-like) target frame; every later pose
    is left-multiplied by ``target @ w2c_0``. Returns an (N, 4, 4) float32 array.
    """
    source_w2c = cam_params[0].w2c_mat
    cam_to_origin = 0
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    abs2rel = target_cam_c2w @ source_w2c
    relative_poses = [target_cam_c2w]
    relative_poses.extend(abs2rel @ cam.c2w_mat for cam in cam_params[1:])
    return np.array(relative_poses, dtype=np.float32)
121
+
122
+
123
def custom_meshgrid(*args):
    """Wrapper around torch.meshgrid pinned to 'ij' indexing (torch>=2.0.0 only)."""
    grids = torch.meshgrid(*args, indexing='ij')
    return grids
126
+
127
+
128
def ray_condition(K, c2w, H, W, device):
    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py

    Build per-pixel Plücker ray embeddings from intrinsics and poses.
    Returns a (B, V, H, W, 6) tensor of [ray_origin x ray_dir, ray_dir].
    """
    # c2w: B, V, 4, 4
    # K: B, V, 4

    B = K.shape[0]

    # Pixel-centre coordinates (+0.5) for every pixel of an HxW grid.
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B,V, 1

    # Unproject pixels to unit-depth camera-space directions, then normalise.
    zs = torch.ones_like(i)  # [B, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / \
        directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    # Rotate directions into world space; origins are the camera centres.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, 3, HW
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, 3, HW
    # c2w @ dirctions
    # Plücker coordinates: (origin x direction, direction).
    rays_dxo = torch.linalg.cross(rays_o, rays_d)
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    # plucker = plucker.permute(0, 1, 4, 2, 3)
    return plucker
163
+
164
+
165
def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
    """Convert raw pose-file rows into per-frame Plücker embeddings.

    Returns the raw rows unchanged when ``return_poses`` is True; otherwise a
    (V, H, W, 6) tensor built via Camera / get_relative_pose / ray_condition.
    """
    if return_poses:
        return cam_params
    else:
        cam_params = [Camera(cam_param) for cam_param in cam_params]

        sample_wh_ratio = width / height
        # Assuming placeholder ratios, change as needed
        pose_wh_ratio = original_pose_width / original_pose_height

        # Rescale the normalised focal lengths so the original pose aspect
        # ratio maps onto the requested sample resolution.
        if pose_wh_ratio > sample_wh_ratio:
            resized_ori_w = height * pose_wh_ratio
            for cam_param in cam_params:
                cam_param.fx = resized_ori_w * cam_param.fx / width
        else:
            resized_ori_h = width / pose_wh_ratio
            for cam_param in cam_params:
                cam_param.fy = resized_ori_h * cam_param.fy / height

        # Denormalise intrinsics to pixel units for the target resolution.
        intrinsic = np.asarray([[cam_param.fx * width,
                                 cam_param.fy * height,
                                 cam_param.cx * width,
                                 cam_param.cy * height]
                                for cam_param in cam_params], dtype=np.float32)

        K = torch.as_tensor(intrinsic)[None]  # [1, 1, 4]
        # Assuming this function is defined elsewhere
        c2ws = get_relative_pose(cam_params)
        c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
        plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[
            0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
        plucker_embedding = plucker_embedding[None]
        plucker_embedding = rearrange(
            plucker_embedding, "b f c h w -> b f h w c")[0]
        return plucker_embedding
200
+
201
+
202
def generate_camera_coordinates(
    direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
    length: int,
    speed: float = 1/54,
    origin=(0, 0.532139961, 0.946026558, 0.5, 0.5,
            0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
):
    """Build ``length`` pose rows panning from ``origin`` in ``direction``.

    Each step shifts the translation entries of the 3x4 pose: index 9
    (horizontal) and/or index 13 (vertical) by ``speed`` per frame.
    Returns a list of lists.
    """
    coordinates = [list(origin)]
    for _ in range(length - 1):
        coor = list(coordinates[-1])
        if "Left" in direction:
            coor[9] += speed
        if "Right" in direction:
            coor[9] -= speed
        if "Up" in direction:
            coor[13] += speed
        if "Down" in direction:
            coor[13] -= speed
        coordinates.append(coor)
    return coordinates
diffsynth/models/wan_video_dit.py ADDED
@@ -0,0 +1,974 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from diffusers.models.lora import LoRALinearLayer
8
+ from einops import rearrange
9
+
10
+ from .utils import hash_state_dict_keys
11
+ from .wan_video_camera_controller import SimpleAdapter
12
+
13
# Optional attention backends, probed once at import time. flash_attention()
# prefers them in this order: FlashAttention-3 > FlashAttention-2 >
# SageAttention > torch.nn.functional.scaled_dot_product_attention.
try:
    import flash_attn_interface

    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn

    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn

    SAGE_ATTN_AVAILABLE = True
    # Plain string literal (the original used an f-string with no placeholders).
    print("========= Using sage attention, please note that this is for inference speed up only!==========")
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False
34
+
35
+
36
def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    compatibility_mode=False,
):
    """Multi-head attention over ``(batch, seq, num_heads * head_dim)`` tensors.

    Dispatches to the fastest available backend (FlashAttention-3, then
    FlashAttention-2, then SageAttention, then PyTorch SDPA). With
    ``compatibility_mode=True`` the PyTorch SDPA path is forced. The returned
    tensor has the same ``(b, s, n*d)`` layout as the inputs.
    """

    def split_heads_first(t):
        # (b, s, n*d) -> (b, n, s, d): layout expected by SDPA / SageAttention.
        return rearrange(t, "b s (n d) -> b n s d", n=num_heads)

    def split_heads_last(t):
        # (b, s, n*d) -> (b, s, n, d): layout expected by flash-attn kernels.
        return rearrange(t, "b s (n d) -> b s n d", n=num_heads)

    if compatibility_mode:
        out = F.scaled_dot_product_attention(
            split_heads_first(q), split_heads_first(k), split_heads_first(v)
        )
        return rearrange(out, "b n s d -> b s (n d)", n=num_heads)

    if FLASH_ATTN_3_AVAILABLE:
        out = flash_attn_interface.flash_attn_func(
            split_heads_last(q), split_heads_last(k), split_heads_last(v)
        )
        # Some flash-attn-3 versions return (output, lse).
        if isinstance(out, tuple):
            out = out[0]
        return rearrange(out, "b s n d -> b s (n d)", n=num_heads)

    if FLASH_ATTN_2_AVAILABLE:
        out = flash_attn.flash_attn_func(
            split_heads_last(q), split_heads_last(k), split_heads_last(v)
        )
        return rearrange(out, "b s n d -> b s (n d)", n=num_heads)

    if SAGE_ATTN_AVAILABLE:
        out = sageattn(split_heads_first(q), split_heads_first(k), split_heads_first(v))
        return rearrange(out, "b n s d -> b s (n d)", n=num_heads)

    # Fallback: PyTorch's built-in scaled dot-product attention.
    out = F.scaled_dot_product_attention(
        split_heads_first(q), split_heads_first(k), split_heads_first(v)
    )
    return rearrange(out, "b n s d -> b s (n d)", n=num_heads)
76
+
77
+
78
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """AdaLN-style affine modulation: scale ``x`` around 1.0, then shift."""
    scaled = x * (scale + 1)
    return scaled + shift
80
+
81
+
82
def sinusoidal_embedding_1d(dim, position):
    """Classic transformer sinusoidal embedding of 1D positions.

    Returns a ``(len(position), dim)`` tensor: the first ``dim // 2`` columns
    are cosines, the rest sines, of ``position / 10000**(i / (dim // 2))``.
    Angles are computed in float64 and the result is cast to
    ``position.dtype``.
    """
    half = dim // 2
    exponent = torch.arange(half, dtype=torch.float64, device=position.device) / half
    inv_freq = torch.pow(10000, -exponent)
    angles = torch.outer(position.type(torch.float64), inv_freq)
    emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    return emb.to(position.dtype)
94
+
95
+
96
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    """Precompute 3D RoPE tables for the (frame, height, width) axes.

    The per-head channel budget ``dim`` is split three ways: height and width
    each receive ``dim // 3`` channels; the frame axis takes the remainder
    ``dim - 2 * (dim // 3)`` so all channels are used.
    """
    spatial_dim = dim // 3
    frame_dim = dim - 2 * spatial_dim
    return (
        precompute_freqs_cis(frame_dim, end, theta),
        precompute_freqs_cis(spatial_dim, end, theta),
        precompute_freqs_cis(spatial_dim, end, theta),
    )
102
+
103
+
104
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    """Precompute 1D RoPE rotations as a complex ``(end, dim // 2)`` table.

    Entry ``[p, i]`` equals ``exp(1j * p / theta**(2i / dim))`` (unit modulus).
    Computed in float64, so the result dtype is complex128.
    """
    even_channels = torch.arange(0, dim, 2)[: (dim // 2)].double()
    inv_freq = 1.0 / (theta ** (even_channels / dim))
    angles = torch.outer(torch.arange(end, device=inv_freq.device), inv_freq)
    # Unit-magnitude complex rotations for each (position, channel-pair).
    return torch.polar(torch.ones_like(angles), angles)
111
+
112
+
113
def rope_apply(x, freqs, num_heads):
    """Rotate q/k features with precomputed complex RoPE factors.

    ``x`` is ``(b, s, n*d)``; ``freqs`` is a complex tensor broadcastable
    against ``(b, s, n, d // 2)``. Adjacent channel pairs are interpreted as
    complex numbers, multiplied by ``freqs``, and flattened back; the result
    keeps ``x``'s shape and dtype.
    """
    heads = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    b, s, n = heads.shape[0], heads.shape[1], heads.shape[2]
    pairs = heads.to(torch.float64).reshape(b, s, n, -1, 2)
    rotated = torch.view_as_complex(pairs) * freqs
    return torch.view_as_real(rotated).flatten(2).to(x.dtype)
120
+
121
+
122
class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        # Divide by the RMS of the last dimension (eps for stability).
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize in float32, then restore the caller's dtype before scaling.
        original_dtype = x.dtype
        normalized = self.norm(x.float()).to(original_dtype)
        return normalized * self.weight
135
+
136
+
137
class AttentionModule(nn.Module):
    """nn.Module wrapper around the backend-dispatching ``flash_attention``."""

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, q, k, v):
        return flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
145
+
146
+
147
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normalized q/k and 3D RoPE."""

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x, freqs):
        # RoPE is applied to the normalized query/key projections only.
        query = rope_apply(self.norm_q(self.q(x)), freqs, self.num_heads)
        key = rope_apply(self.norm_k(self.k(x)), freqs, self.num_heads)
        value = self.v(x)
        return self.o(self.attn(query, key, value))
171
+
172
+
173
class CrossAttention(nn.Module):
    """Multi-head cross-attention against the text context.

    When ``has_image_input`` is True, the first 257 context tokens are treated
    as CLIP image embeddings and attended via a separate k/v projection; the
    image and text attention outputs are summed before the output projection.
    """

    def __init__(
        self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        if self.has_image_input:
            # Context layout: 257 CLIP image tokens followed by text tokens.
            img, ctx = y[:, :257], y[:, 257:]
        else:
            ctx = y
        query = self.norm_q(self.q(x))
        out = self.attn(query, self.norm_k(self.k(ctx)), self.v(ctx))
        if self.has_image_input:
            img_out = flash_attention(
                query,
                self.norm_k_img(self.k_img(img)),
                self.v_img(img),
                num_heads=self.num_heads,
            )
            out = out + img_out
        return self.o(out)
212
+
213
+
214
class SelfAttentionSeparate(nn.Module):
    """Self-attention with optional camera-pose conditioning.

    When a ``camera_pose_embedding`` is supplied, each of the q/k/v paths is
    corrected both before the base projection (input side) and after it
    (output side). With ``rank > 0`` the corrections are low-rank LoRA layers
    (zero-initialized by diffusers); with ``rank <= 0`` they are full
    ``nn.Linear`` layers that ``zero_init_linear`` zeroes, so a freshly added
    adapter is a no-op either way.

    BUGFIX: the original ``forward`` branch condition was inverted — it took
    the plain path when ``camera_pose_embedding`` was provided (silently
    ignoring it) and dereferenced ``None`` otherwise. The ``is None`` check
    now selects the plain path.
    """

    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, rank=64):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        if rank > 0:
            # Low-rank (LoRA) camera-conditioning corrections.
            self.q_zl_before = LoRALinearLayer(dim, dim, rank=rank)
            self.k_zl_before = LoRALinearLayer(dim, dim, rank=rank)
            self.v_zl_before = LoRALinearLayer(dim, dim, rank=rank)

            self.q_zl_after = LoRALinearLayer(dim, dim, rank=rank)
            self.k_zl_after = LoRALinearLayer(dim, dim, rank=rank)
            self.v_zl_after = LoRALinearLayer(dim, dim, rank=rank)
        else:
            # Full-rank fallback; zeroed below so it starts as a no-op.
            self.q_zl_before = nn.Linear(dim, dim)
            self.k_zl_before = nn.Linear(dim, dim)
            self.v_zl_before = nn.Linear(dim, dim)

            self.q_zl_after = nn.Linear(dim, dim)
            self.k_zl_after = nn.Linear(dim, dim)
            self.v_zl_after = nn.Linear(dim, dim)

        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)
        self.zero_init_linear()

    def zero_init_linear(self):
        """Zero the full-rank correction layers (LoRA layers are skipped —
        diffusers already zero-initializes their up-projection)."""
        layers_to_handle = [
            self.q_zl_before,
            self.k_zl_before,
            self.v_zl_before,
            self.q_zl_after,
            self.k_zl_after,
            self.v_zl_after,
        ]
        for _layer in layers_to_handle:
            if isinstance(_layer, nn.Linear):
                nn.init.zeros_(_layer.weight)
                if _layer.bias is not None:
                    nn.init.zeros_(_layer.bias)

    def forward(self, x, freqs, camera_pose_embedding=None):
        if camera_pose_embedding is None:
            # No camera conditioning: behave like plain SelfAttention.
            q = self.norm_q(self.q(x))
            k = self.norm_k(self.k(x))
            v = self.v(x)
        else:
            # Inject camera pose before and after each base projection.
            q = self.norm_q(
                self.q(x + self.q_zl_before(camera_pose_embedding))
                + self.q_zl_after(camera_pose_embedding)
            )
            k = self.norm_k(
                self.k(x + self.k_zl_before(camera_pose_embedding))
                + self.k_zl_after(camera_pose_embedding)
            )
            v = self.v(x + self.v_zl_before(camera_pose_embedding)) + self.v_zl_after(
                camera_pose_embedding
            )

        q = rope_apply(q, freqs, self.num_heads)
        k = rope_apply(k, freqs, self.num_heads)
        x = self.attn(q, k, v)
        return self.o(x)
291
+
292
+
293
class CrossAttentionSeparate(nn.Module):
    """Cross-attention with optional camera-pose conditioning on the query.

    Only the query path is corrected (before and after the base projection);
    keys/values come from the text/image context unchanged. Corrections are
    LoRA layers when ``rank > 0``, otherwise bias-free ``nn.Linear`` layers
    zeroed by ``zero_init_linear`` so the adapter starts as a no-op.

    BUGFIX: the original ``forward`` dereferenced ``camera_pose_embedding``
    unconditionally and crashed on the default ``None``; the plain query path
    (present in the original as commented-out code) is now restored for that
    case.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        eps: float = 1e-6,
        has_image_input: bool = False,
        rank=64,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        if rank > 0:
            # Low-rank (LoRA) query corrections.
            self.q_zl_before = LoRALinearLayer(dim, dim, rank=rank)
            self.q_zl_after = LoRALinearLayer(dim, dim, rank=rank)
        else:
            # Full-rank fallback; zeroed below so it starts as a no-op.
            self.q_zl_before = nn.Linear(dim, dim, bias=False)
            self.q_zl_after = nn.Linear(dim, dim, bias=False)

        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)
        self.has_image_input = has_image_input
        if has_image_input:
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads)
        self.zero_init_linear()

    def zero_init_linear(self):
        """Zero the full-rank correction layers (LoRA layers are skipped —
        diffusers already zero-initializes their up-projection)."""
        layers_to_handle = [
            self.q_zl_before,
            self.q_zl_after,
        ]
        for _layer in layers_to_handle:
            if isinstance(_layer, nn.Linear):
                nn.init.zeros_(_layer.weight)
                if _layer.bias is not None:
                    nn.init.zeros_(_layer.bias)

    def forward(self, x: torch.Tensor, y: torch.Tensor, camera_pose_embedding=None):
        if self.has_image_input:
            # Context layout: 257 CLIP image tokens followed by text tokens.
            img = y[:, :257]
            ctx = y[:, 257:]
        else:
            ctx = y

        if camera_pose_embedding is None:
            # No camera conditioning: behave like plain CrossAttention.
            q = self.norm_q(self.q(x))
        else:
            # Inject camera pose before and after the base query projection.
            q = self.norm_q(
                self.q(x + self.q_zl_before(camera_pose_embedding))
                + self.q_zl_after(camera_pose_embedding)
            )
        k = self.norm_k(self.k(ctx))
        v = self.v(ctx)

        x = self.attn(q, k, v)
        if self.has_image_input:
            k_img = self.norm_k_img(self.k_img(img))
            v_img = self.v_img(img)
            y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
            x = x + y
        return self.o(x)
369
+
370
+
371
class GateModule(nn.Module):
    """Gated residual connection: returns ``x + gate * residual``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        update = gate * residual
        return x + update
379
+
380
+
381
class DiTBlock(nn.Module):
    """Standard Wan DiT block.

    Pipeline: AdaLN-modulated self-attention (gated residual), then
    text/image cross-attention (plain residual), then an AdaLN-modulated
    feed-forward network (gated residual). ``t_mod`` carries the six
    per-timestep (shift, scale, gate) modulation terms.
    """

    def __init__(
        self,
        has_image_input: bool,
        dim: int,
        num_heads: int,
        ffn_dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim

        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(
            dim, num_heads, eps, has_image_input=has_image_input
        )
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ffn_dim, dim),
        )
        # Learned per-block offsets for the six AdaLN modulation terms.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs, camera_pose_embedding=None):
        # camera_pose_embedding is accepted for signature parity with
        # CameraDiTBlock but is unused here.
        mod = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(
            6, dim=1
        )

        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(attn_in, freqs))
        x = x + self.cross_attn(self.norm3(x), context)
        ffn_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        return self.gate(x, gate_mlp, self.ffn(ffn_in))
422
+
423
+
424
class CameraDiTBlock(nn.Module):
    """DiT block variant whose attention layers accept a camera-pose
    embedding (via SelfAttentionSeparate / CrossAttentionSeparate).

    Structure mirrors DiTBlock: AdaLN-modulated self-attention, text/image
    cross-attention, AdaLN-modulated FFN, with gated residuals.
    """

    def __init__(
        self,
        has_image_input: bool,
        dim: int,
        num_heads: int,
        ffn_dim: int,
        eps: float = 1e-6,
        camera_lora_rank=64,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim

        self.self_attn = SelfAttentionSeparate(
            dim, num_heads, eps, rank=camera_lora_rank
        )
        self.cross_attn = CrossAttentionSeparate(
            dim, num_heads, eps, has_image_input=has_image_input, rank=camera_lora_rank
        )
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ffn_dim, dim),
        )
        # Learned per-block offsets for the six AdaLN modulation terms.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs, camera_pose_embedding=None):
        mod = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(
            6, dim=1
        )

        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(
            x, gate_msa, self.self_attn(attn_in, freqs, camera_pose_embedding)
        )
        x = x + self.cross_attn(self.norm3(x), context, camera_pose_embedding)
        ffn_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        return self.gate(x, gate_mlp, self.ffn(ffn_in))
471
+
472
+
473
class MLP(torch.nn.Module):
    """LayerNorm-bounded two-layer projection (used for CLIP image features).

    When ``has_pos_emb`` is True, a learned ``(1, 514, 1280)`` positional
    table is added to the input before projection.
    """

    def __init__(self, in_dim, out_dim, has_pos_emb=False):
        super().__init__()
        layers = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)
        self.has_pos_emb = has_pos_emb
        if has_pos_emb:
            self.emb_pos = torch.nn.Parameter(torch.zeros((1, 514, 1280)))

    def forward(self, x):
        if not self.has_pos_emb:
            return self.proj(x)
        pos = self.emb_pos.to(dtype=x.dtype, device=x.device)
        return self.proj(x + pos)
491
+
492
+
493
class Head(nn.Module):
    """Final AdaLN-modulated projection from token features to patch values.

    Output feature size is ``out_dim * prod(patch_size)`` so each token can be
    unpatchified back into its ``patch_size`` voxel block.
    """

    def __init__(
        self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float
    ):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
        # Learned offsets for the two AdaLN terms (shift, scale).
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, t_mod):
        mod = self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod
        shift, scale = mod.chunk(2, dim=1)
        return self.head(self.norm(x) * (scale + 1) + shift)
510
+
511
+
512
class WanModel(torch.nn.Module):
    """Wan video diffusion transformer (DiT).

    Video latents are patchified by a 3D convolution into a token sequence,
    processed by ``num_layers`` DiT blocks conditioned on text embeddings, a
    sinusoidal timestep embedding, and 3D RoPE, then projected back to the
    latent video shape by ``Head`` + ``unpatchify``. Optional extras: CLIP
    image conditioning (``has_image_input``), a reference-image conv
    (``has_ref_conv``), and a camera control adapter (``add_control_adapter``).
    """

    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
        has_image_pos_emb: bool = False,
        has_ref_conv: bool = False,
        add_control_adapter: bool = False,
        in_dim_control_adapter: int = 24,
    ):
        super().__init__()
        self.dim = dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size

        # Non-overlapping 3D patchify: (b, in_dim, F, H, W) -> (b, dim, f, h, w).
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size
        )
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(
                approximate="tanh"), nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
        )
        # Expands the timestep embedding into the six AdaLN terms per block.
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))
        self.blocks = nn.ModuleList(
            [
                DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
                for _ in range(num_layers)
            ]
        )
        self.head = Head(dim, out_dim, patch_size, eps)
        head_dim = dim // num_heads
        # Precomputed complex RoPE tables for the (f, h, w) axes.
        self.freqs = precompute_freqs_cis_3d(head_dim)

        if has_image_input:
            self.img_emb = MLP(
                1280, dim, has_pos_emb=has_image_pos_emb
            )  # clip_feature_dim = 1280
        if has_ref_conv:
            self.ref_conv = nn.Conv2d(
                16, dim, kernel_size=(2, 2), stride=(2, 2))
        self.has_image_pos_emb = has_image_pos_emb
        self.has_ref_conv = has_ref_conv

        self.control_adapter = None
        self.add_control_adapter = add_control_adapter
        if add_control_adapter:
            self.control_adapter = SimpleAdapter(
                in_dim_control_adapter,
                dim,
                kernel_size=patch_size[1:],
                stride=patch_size[1:],
            )
        else:
            # Redundant with the assignment above; kept for fidelity.
            self.control_adapter = None

    def patchify(
        self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None
    ):
        """Patchify latents into tokens; returns (tokens, (f, h, w)) grid size."""
        x = self.patch_embedding(x)
        if (
            self.control_adapter is not None
            and control_camera_latents_input is not None
        ):
            y_camera = self.control_adapter(control_camera_latents_input)
            # NOTE(review): this comprehension turns x into a *list* of
            # per-sample tensors, so the .shape access below would fail unless
            # the commented line is restored — confirm the control-camera path.
            x = [u + v for u, v in zip(x, y_camera)]
            # x = x[0].unsqueeze(0)
        grid_size = x.shape[2:]
        x = rearrange(x, "b c f h w -> b (f h w) c").contiguous()
        return x, grid_size  # x, grid_size: (f, h, w)

    def unpatchify(self, x: torch.Tensor, grid_size: Tuple[int, int, int]):
        """Inverse of patchify: fold per-token patch values back into a
        (b, c, F, H, W) latent video."""
        return rearrange(
            x,
            "b (f h w) (x y z c) -> b c (f x) (h y) (w z)",
            f=grid_size[0],
            h=grid_size[1],
            w=grid_size[2],
            x=self.patch_size[0],
            y=self.patch_size[1],
            z=self.patch_size[2],
        )

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        clip_feature: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
        **kwargs,
    ):
        """Denoising forward pass.

        x: latent video (b, c, F, H, W); y: extra conditioning latents
        concatenated on the channel axis when image input is enabled;
        clip_feature: CLIP image tokens prepended to the text context.
        Extra **kwargs are accepted but ignored.
        """
        t = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)

        if self.has_image_input:
            # print(f"x,y shape: {x.shape}, {y.shape if y is not None else 'None'}")
            x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
            clip_embdding = self.img_emb(clip_feature)
            context = torch.cat([clip_embdding, context], dim=1)

        x, (f, h, w) = self.patchify(x)

        # Build the flattened 3D RoPE table for this token grid: each token at
        # (f, h, w) concatenates the per-axis rotation factors.
        freqs = (
            torch.cat(
                [
                    self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                    self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                    self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
                ],
                dim=-1,
            )
            .reshape(f * h * w, 1, -1)
            .to(x.device)
        )

        def create_custom_forward(module):
            # Wrapper so torch.utils.checkpoint can re-invoke the block.
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                if use_gradient_checkpointing_offload:
                    # Offload checkpointed activations to CPU to save VRAM.
                    with torch.autograd.graph.save_on_cpu():
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x,
                            context,
                            t_mod,
                            freqs,
                            use_reentrant=False,
                        )
                else:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        context,
                        t_mod,
                        freqs,
                        use_reentrant=False,
                    )
            else:
                x = block(x, context, t_mod, freqs)

        x = self.head(x, t)
        x = self.unpatchify(x, (f, h, w))

        # rgb_head is not created in __init__; presumably attached externally
        # by the pipeline — verify against callers.
        if hasattr(self, 'rgb_head'):
            # NOTE(review): at this point x has already been unpatchified to
            # (b, c, F, H, W), but Head/unpatchify expect a token sequence —
            # this likely should use the pre-unpatchify tokens; confirm.
            rgb = self.rgb_head(x, t)
            rgb = self.unpatchify(rgb, (f, h, w))

            return x, rgb

        return x

    @staticmethod
    def state_dict_converter():
        """Return the converter used by the model manager to load external
        checkpoints into this module's naming scheme."""
        return WanModelStateDictConverter()
689
+
690
+
691
class WanModelStateDictConverter:
    """Renames external WanModel checkpoints into this repo's layout and
    recognizes known checkpoints (by a hash of their key set) to recover the
    matching architecture config.

    Fixes vs. the original: ``hash_state_dict_keys`` is computed once per call
    instead of once per branch, and an unreachable duplicate ``elif`` for hash
    ``6bfcfb3b342cb286ce886889d519a77e`` (identical config) was removed.
    """

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Map Diffusers-layout weights to this layout.

        The block-0 entries in ``rename_dict`` act as templates: the same
        mapping is re-applied to every ``blocks.<i>.*`` key. Returns
        ``(renamed_state_dict, config)``; ``config`` is empty for
        unrecognized checkpoints.
        """
        rename_dict = {
            "blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
            "blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
            "blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
            "blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
            "blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
            "blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
            "blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
            "blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
            "blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
            "blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
            "blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
            "blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
            "blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
            "blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
            "blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
            "blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
            "blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
            "blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
            "blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
            "blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
            "blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
            "blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
            "blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
            "blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
            "blocks.0.norm2.bias": "blocks.0.norm3.bias",
            "blocks.0.norm2.weight": "blocks.0.norm3.weight",
            "blocks.0.scale_shift_table": "blocks.0.modulation",
            "condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
            "condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
            "condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
            "condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
            "condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
            "condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
            "condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
            "condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
            "condition_embedder.time_proj.bias": "time_projection.1.bias",
            "condition_embedder.time_proj.weight": "time_projection.1.weight",
            "patch_embedding.bias": "patch_embedding.bias",
            "patch_embedding.weight": "patch_embedding.weight",
            "scale_shift_table": "head.modulation",
            "proj_out.bias": "head.head.bias",
            "proj_out.weight": "head.head.weight",
        }
        state_dict_ = {}
        # Hash the key set once; it identifies the checkpoint architecture.
        keys_hash = hash_state_dict_keys(state_dict)
        print(
            f"hash_state_dict_keys(state_dict): {keys_hash}")
        for name, param in state_dict.items():
            if name in rename_dict:
                state_dict_[rename_dict[name]] = param
            else:
                # Rewrite "blocks.<i>.<suffix>" through the block-0 template,
                # then restore the original block index.
                name_ = ".".join(name.split(
                    ".")[:1] + ["0"] + name.split(".")[2:])
                if name_ in rename_dict:
                    name_ = rename_dict[name_]
                    name_ = ".".join(
                        name_.split(".")[:1]
                        + [name.split(".")[1]]
                        + name_.split(".")[2:]
                    )
                    state_dict_[name_] = param
        if keys_hash == "cb104773c6c2cb6df4f9529ad5c60d0b":
            # Wan 14B t2v (Diffusers layout).
            config = {
                "model_type": "t2v",
                "patch_size": (1, 2, 2),
                "text_len": 512,
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "window_size": (-1, -1),
                "qk_norm": True,
                "cross_attn_norm": True,
                "eps": 1e-6,
            }
        else:
            config = {}
        return state_dict_, config

    def from_civitai(self, state_dict):
        """Recognize Civitai-style checkpoints (VACE keys stripped) and return
        ``(state_dict, config)``; ``config`` is empty when unrecognized."""
        state_dict = {
            name: param
            for name, param in state_dict.items()
            if not name.startswith("vace")
        }
        # Hash the key set once; it identifies the checkpoint architecture.
        keys_hash = hash_state_dict_keys(state_dict)
        print(
            f"hash_state_dict_keys(state_dict): {keys_hash} from civitai"
        )

        if keys_hash == "9269f8db9040a9d860eaca435be61814":
            # 1.3B t2v
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 16,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6,
            }
        elif keys_hash == "aafcfd9672c3a2456dc46e1cb6e52c70":
            # 14B t2v
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
            }
        elif keys_hash == "6bfcfb3b342cb286ce886889d519a77e":
            # 14B i2v (this hash appeared twice in the original elif chain
            # with an identical config; the unreachable duplicate was removed)
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
            }
        elif keys_hash == "6d6ccde6845b95ad9114ab993d917893":
            # 1.3B i2v
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6,
            }
        elif keys_hash == "349723183fc063b2bfc10bb2835cf677":
            # 1.3B PAI control
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 48,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6,
            }
        elif keys_hash == "efa44cddf936c70abd0ea28b6cbe946c":
            # 14B PAI control
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 48,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
            }
        elif keys_hash == "3ef3b1f8e1dab83d5b71fd7b617f859f":
            # 14B i2v with image positional embedding
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
                "has_image_pos_emb": True,
            }
        elif keys_hash == "70ddad9d3a133785da5ea371aae09504":
            # 1.3B PAI control v1.1
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 48,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6,
                "has_ref_conv": True,
            }
        elif keys_hash == "26bde73488a92e64cc20b0a7485b9e5b":
            # 14B PAI control v1.1
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 48,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
                "has_ref_conv": True,
            }
        elif keys_hash == "ac6a5aa74f4a0aab6f64eb9a72f19901":
            # 1.3B PAI control-camera v1.1
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 32,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6,
                "has_ref_conv": False,
                "add_control_adapter": True,
                "in_dim_control_adapter": 24,
            }
        elif keys_hash == "b61c605c2adbd23124d152ed28e049ae":
            # 14B PAI control-camera v1.1
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 32,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6,
                "has_ref_conv": False,
                "add_control_adapter": True,
                "in_dim_control_adapter": 24,
            }
        else:
            config = {}
        return state_dict, config
diffsynth/models/wan_video_image_encoder.py ADDED
@@ -0,0 +1,902 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Concise re-implementation of
3
+ ``https://github.com/openai/CLIP'' and
4
+ ``https://github.com/mlfoundations/open_clip''.
5
+ """
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as T
11
+ from .wan_video_dit import flash_attention
12
+
13
+
14
class SelfAttention(nn.Module):
    """Multi-head self-attention with separate q/k/v/output projections.

    Dropout is applied both to the attention weights (training only) and to
    the projected output.
    """

    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # projections (creation order kept stable for reproducible init)
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t, batch, seq):
        # [B, L, C] -> [B, H, L, head_dim]
        return t.reshape(batch, seq, self.num_heads,
                         self.head_dim).permute(0, 2, 1, 3)

    def forward(self, x, mask):
        """
        x: [B, L, C].
        mask: attention mask accepted by scaled_dot_product_attention, or None.
        """
        batch, seq, channels = x.size()

        q = self._split_heads(self.q(x), batch, seq)
        k = self._split_heads(self.k(x), batch, seq)
        v = self._split_heads(self.v(x), batch, seq)

        # attention-weight dropout is active only in training mode
        drop_p = self.dropout.p if self.training else 0.0
        out = F.scaled_dot_product_attention(q, k, v, mask, drop_p)
        out = out.permute(0, 2, 1, 3).reshape(batch, seq, channels)

        out = self.o(out)
        return self.dropout(out)
51
+
52
+
53
class AttentionBlock(nn.Module):
    """Transformer encoder block: self-attention + 4x MLP, each residual.

    ``post_norm=True`` gives the BERT-style layout (normalize after each
    residual sum); otherwise the pre-norm layout is used.
    """

    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # sub-layers
        hidden = dim * 4
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim),
            nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        if self.post_norm:
            # normalize after each residual connection
            x = self.norm1(x + self.attn(x, mask))
            return self.norm2(x + self.ffn(x))
        # pre-norm: normalize the input of each sub-layer
        x = x + self.attn(self.norm1(x), mask)
        return x + self.ffn(self.norm2(x))
78
+
79
+
80
class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.

    Token, token-type and position embeddings feed a stack of
    ``AttentionBlock`` layers; padding positions (``pad_id``) are masked out
    of the attention.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([
            AttentionBlock(dim, num_heads, post_norm, dropout, eps)
            for _ in range(num_layers)
        ])

        # norm layer (applied to embeddings when post_norm, else to the output)
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.

        Returns [B, L, dim] hidden states.
        """
        b, s = ids.shape
        mask = ids.ne(self.pad_id).long()

        # embeddings; position ids are RoBERTa-style: pad_id offset plus a
        # cumulative count over non-padding tokens (padding stays at pad_id)
        x = self.token_embedding(ids) + \
            self.type_embedding(torch.zeros_like(ids)) + \
            self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks: additive attention mask, 0 for visible, dtype-min for padding
        mask = torch.where(
            mask.view(b, 1, 1, s).gt(0), 0.0,
            torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x
148
+
149
+
150
def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.

    Args:
        pretrained: load weights via ``sora.DOWNLOAD_TO_CACHE`` (meta-device
            init + ``assign=True``); otherwise initialize on ``device``.
        return_tokenizer: also build the matching HuggingFace tokenizer and
            return ``(model, tokenizer)``.
        device: target device for loading / initialization.
        **kwargs: overrides for the XLMRoberta config below.
    """
    # params
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(**kwargs)

    # init model
    if pretrained:
        from sora import DOWNLOAD_TO_CACHE

        # init a meta model (no memory allocated until the load below)
        with torch.device('meta'):
            model = XLMRoberta(**cfg)

        # load checkpoint
        model.load_state_dict(
            torch.load(
                DOWNLOAD_TO_CACHE('models/xlm_roberta/xlm_roberta_large.pth'),
                map_location=device),
            assign=True)
    else:
        # init a model on device
        with torch.device(device):
            model = XLMRoberta(**cfg)

    # init tokenizer
    if return_tokenizer:
        from sora.data import HuggingfaceTokenizer
        # Fix: XLMRoberta exposes ``max_seq_len``; the original read
        # ``model.text_len``, which does not exist on this class and raised
        # AttributeError whenever return_tokenizer was requested.
        tokenizer = HuggingfaceTokenizer(
            name='xlm-roberta-large',
            seq_len=model.max_seq_len,
            clean='whitespace')
        return model, tokenizer
    else:
        return model
200
+
201
+
202
+
203
def pos_interpolate(pos, seq_len):
    """Resize a [1, L, C] position-embedding table to ``seq_len`` positions.

    Leading non-grid tokens (e.g. the cls slot) are passed through unchanged;
    the square grid part is bicubically interpolated to the new grid size.
    Returns ``pos`` itself when no resizing is needed.
    """
    if pos.size(1) == seq_len:
        return pos

    src_grid = int(math.sqrt(pos.size(1)))
    tar_grid = int(math.sqrt(seq_len))
    num_extra = pos.size(1) - src_grid * src_grid

    # [1, G*G, C] -> [1, C, G, G] for 2-D interpolation
    grid = pos[:, num_extra:].float()
    grid = grid.reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid,
        size=(tar_grid, tar_grid),
        mode='bicubic',
        align_corners=False)
    grid = grid.flatten(2).transpose(1, 2)

    return torch.cat([pos[:, :num_extra], grid], dim=1)
220
+
221
+
222
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        return torch.sigmoid(1.702 * x) * x
226
+
227
+
228
class LayerNorm(nn.LayerNorm):
    """LayerNorm whose output is cast back to the input dtype."""

    def forward(self, x):
        normed = super().forward(x)
        return normed.type_as(x)
232
+
233
+
234
class SelfAttention(nn.Module):
    """Multi-head self-attention using a fused qkv projection; the attention
    math is delegated to ``flash_attention`` (compatibility mode).

    NOTE(review): ``causal`` and ``attn_dropout`` are stored but not passed to
    the attention call — confirm this is intentional.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 causal=False,
                 attn_dropout=0.0,
                 proj_dropout=0.0):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # fused qkv projection + output projection
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        x: [B, L, C] -> [B, L, C].
        """
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        out = flash_attention(
            q, k, v, num_heads=self.num_heads, compatibility_mode=True)
        out = self.proj(out)
        return F.dropout(out, self.proj_dropout, self.training)
269
+
270
+
271
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: ``silu(fc1(x)) * fc2(x)`` projected back by fc3."""

    def __init__(self, dim, mid_dim):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim

        # layers
        self.fc1 = nn.Linear(dim, mid_dim)
        self.fc2 = nn.Linear(dim, mid_dim)
        self.fc3 = nn.Linear(mid_dim, dim)

    def forward(self, x):
        gate = F.silu(self.fc1(x))
        return self.fc3(gate * self.fc2(x))
287
+
288
+
289
class AttentionBlock(nn.Module):
    """Transformer block (attention + MLP) with selectable residual layout
    (pre- or post-norm) and activation ('quick_gelu', 'gelu', 'swi_glu')."""

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 post_norm=False,
                 causal=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        # sub-layers
        self.norm1 = LayerNorm(dim, eps=norm_eps)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
                                  proj_dropout)
        self.norm2 = LayerNorm(dim, eps=norm_eps)
        hidden = int(dim * mlp_ratio)
        if activation == 'swi_glu':
            self.mlp = SwiGLU(dim, hidden)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(dim, hidden),
                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
                nn.Linear(hidden, dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        if self.post_norm:
            # normalize each sub-layer's output before adding the residual
            x = x + self.norm1(self.attn(x))
            return x + self.norm2(self.mlp(x))
        # pre-norm: normalize each sub-layer's input
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
331
+
332
+
333
class AttentionPool(nn.Module):
    """Single-query attention pooling: a learned cls query attends over all
    tokens, then a residual MLP refines the pooled vector.

    Returns one ``dim``-sized vector per sample.
    """

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 activation='gelu',
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # layers
        gain = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim, eps=norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        """
        x: [B, L, C].

        Returns [B, C]: the pooled representation.
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value; the single query comes from the learned
        # cls embedding and is broadcast across the batch
        q = self.to_q(self.cls_embedding).view(1, 1, n*d).expand(b, -1, -1)
        k, v = self.to_kv(x).chunk(2, dim=-1)

        # compute attention
        x = flash_attention(q, k, v, num_heads=self.num_heads, compatibility_mode=True)
        x = x.reshape(b, 1, c)

        # output projection
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)

        # residual MLP on the pooled token, then drop the length-1 dimension
        x = x + self.mlp(self.norm(x))
        return x[:, 0]
384
+
385
+
386
class VisionTransformer(nn.Module):
    """CLIP-style ViT image encoder.

    Patchifies the image with a strided Conv2d, optionally prepends a cls
    token, adds learned position embeddings, and runs a stack of
    ``AttentionBlock`` layers. ``forward`` returns token features; the
    pooling ``head`` is built here but not applied in ``forward``.
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 mlp_ratio=4,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 pool_type='token',
                 pre_norm=True,
                 post_norm=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        if image_size % patch_size != 0:
            print(
                '[WARNING] image_size is not divisible by patch_size',
                flush=True)
        assert pool_type in ('token', 'token_fc', 'attn_pool')
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size)**2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings; patch conv has no bias when a pre-norm follows it
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(
            3,
            dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm)
        if pool_type in ('token', 'token_fc'):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        # one extra position slot for the cls token when it exists
        self.pos_embedding = nn.Parameter(gain * torch.randn(
            1, self.num_patches +
            (1 if pool_type in ('token', 'token_fc') else 0), dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                           activation, attn_dropout, proj_dropout, norm_eps)
            for _ in range(num_layers)
        ])
        # NOTE(review): this overwrites the boolean ``self.post_norm`` stored
        # above with a module; harmless here since forward() never reads it,
        # but confusing.
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head (not applied in forward)
        if pool_type == 'token':
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
        elif pool_type == 'token_fc':
            self.head = nn.Linear(dim, out_dim)
        elif pool_type == 'attn_pool':
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
                                      proj_dropout, norm_eps)

    def forward(self, x, interpolation=False, use_31_block=False):
        """
        x: [B, 3, H, W].
        interpolation: resize the position table to the actual token count.
        use_31_block: stop one block before the end (penultimate features).

        Returns token features [B, L, dim].
        """
        b = x.size(0)

        # embeddings: patchify, flatten to tokens, prepend cls if configured
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ('token', 'token_fc'):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1).to(dtype=x.dtype, device=x.device), x], dim=1)
        if interpolation:
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        e = e.to(dtype=x.dtype, device=x.device)
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x
479
+
480
+
481
class CLIP(nn.Module):
    """Standard CLIP: a ViT image tower plus a text tower trained with a
    learned logit scale (and optional logit bias).

    NOTE(review): ``TextTransformer`` is not defined in this module, so
    instantiating CLIP here raises NameError unless it is provided elsewhere;
    this repo appears to use ``XLMRobertaCLIP`` instead.
    """

    def __init__(self,
                 embed_dim=512,
                 image_size=224,
                 patch_size=16,
                 vision_dim=768,
                 vision_mlp_ratio=4,
                 vision_heads=12,
                 vision_layers=12,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 vocab_size=49408,
                 text_len=77,
                 text_dim=512,
                 text_mlp_ratio=4,
                 text_heads=8,
                 text_layers=12,
                 text_causal=True,
                 text_pool='argmax',
                 text_head_bias=False,
                 logit_bias=None,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pool = vision_pool
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.vocab_size = vocab_size
        self.text_len = text_len
        self.text_dim = text_dim
        self.text_mlp_ratio = text_mlp_ratio
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_causal = text_causal
        self.text_pool = text_pool
        self.text_head_bias = text_head_bias
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        self.textual = TextTransformer(
            vocab_size=vocab_size,
            text_len=text_len,
            dim=text_dim,
            mlp_ratio=text_mlp_ratio,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            causal=text_causal,
            pool_type=text_pool,
            head_bias=text_head_bias,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        # temperature initialized to 1/0.07 as in the CLIP paper
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
        if logit_bias is not None:
            self.logit_bias = nn.Parameter(logit_bias * torch.ones([]))

        # initialize weights
        self.init_weights()

    def forward(self, imgs, txt_ids):
        """
        imgs: [B, 3, H, W] of torch.float32.
        - mean: [0.48145466, 0.4578275, 0.40821073]
        - std: [0.26862954, 0.26130258, 0.27577711]
        txt_ids: [B, L] of torch.long. Encoded by data.CLIPTokenizer.

        Returns (image_features, text_features).
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_ids)
        return xi, xt

    def init_weights(self):
        """Re-initialize embeddings and transformer weights with scaled
        normal distributions (depth-scaled for residual projections)."""
        # embeddings
        nn.init.normal_(self.textual.token_embedding.weight, std=0.02)
        nn.init.normal_(self.visual.patch_embedding.weight, std=0.1)

        # attentions
        for modality in ['visual', 'textual']:
            dim = self.vision_dim if modality == 'visual' else self.text_dim
            transformer = getattr(self, modality).transformer
            proj_gain = (1.0 / math.sqrt(dim)) * (
                1.0 / math.sqrt(2 * len(transformer)))
            attn_gain = 1.0 / math.sqrt(dim)
            mlp_gain = 1.0 / math.sqrt(2.0 * dim)
            for block in transformer:
                nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain)
                nn.init.normal_(block.attn.proj.weight, std=proj_gain)
                nn.init.normal_(block.mlp[0].weight, std=mlp_gain)
                nn.init.normal_(block.mlp[2].weight, std=proj_gain)

    def param_groups(self):
        """Two optimizer groups: norms/biases without weight decay, the rest
        with the optimizer's default decay."""
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups
615
+
616
+
617
class XLMRobertaWithHead(XLMRoberta):
    """XLM-RoBERTa encoder plus a two-layer projection head applied to the
    mask-aware mean-pooled sequence representation."""

    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # projection head: dim -> (dim + out_dim) // 2 -> out_dim, no biases
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
            nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # encode tokens with the base model
        hidden = super().forward(ids)

        # mean-pool over non-padding positions only
        pad_mask = ids.ne(self.pad_id).unsqueeze(-1).to(hidden)
        pooled = (hidden * pad_mask).sum(dim=1) / pad_mask.sum(dim=1)

        # project into the shared embedding space
        return self.head(pooled)
640
+
641
+
642
class XLMRobertaCLIP(nn.Module):
    """open-clip style CLIP with a ViT visual tower and (nominally) an
    XLM-RoBERTa text tower.

    NOTE(review): ``self.textual`` is set to None here, so ``forward`` would
    raise when it calls ``self.textual(txt_ids)``; in this repo only the
    ``visual`` tower is used (see WanImageEncoder).
    """

    def __init__(self,
                 embed_dim=1024,
                 image_size=224,
                 patch_size=14,
                 vision_dim=1280,
                 vision_mlp_ratio=4,
                 vision_heads=16,
                 vision_layers=32,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 activation='gelu',
                 vocab_size=250002,
                 max_text_len=514,
                 type_size=1,
                 pad_id=1,
                 text_dim=1024,
                 text_heads=16,
                 text_layers=24,
                 text_post_norm=True,
                 text_dropout=0.1,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_post_norm = text_post_norm
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        # text tower deliberately omitted; only the visual tower is needed
        self.textual = None
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_ids):
        """
        imgs: [B, 3, H, W] of torch.float32.
        - mean: [0.48145466, 0.4578275, 0.40821073]
        - std: [0.26862954, 0.26130258, 0.27577711]
        txt_ids: [B, L] of torch.long.
        Encoded by data.CLIPTokenizer.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_ids)
        return xi, xt

    def param_groups(self):
        """Two optimizer groups: norms/biases without weight decay, the rest
        with the default decay."""
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups
736
+
737
+
738
def _clip(pretrained=False,
          pretrained_name=None,
          model_cls=CLIP,
          return_transforms=False,
          return_tokenizer=False,
          tokenizer_padding='eos',
          dtype=torch.float32,
          device='cpu',
          **kwargs):
    """Factory for CLIP-family models.

    Builds ``model_cls(**kwargs)``, optionally loading pretrained weights,
    and optionally appends preprocessing transforms and/or a tokenizer to the
    returned tuple. Returns the bare model when nothing else was requested.
    """
    # init model
    if pretrained and pretrained_name:
        from sora import BUCKET, DOWNLOAD_TO_CACHE

        # init a meta model (no memory until load_state_dict(assign=True))
        with torch.device('meta'):
            model = model_cls(**kwargs)

        # checkpoint path; prefer an fp16/bf16 variant when available
        checkpoint = f'models/clip/{pretrained_name}'
        if dtype in (torch.float16, torch.bfloat16):
            suffix = '-' + {
                torch.float16: 'fp16',
                torch.bfloat16: 'bf16'
            }[dtype]
            # NOTE(review): ``object_exists`` is not defined in this module —
            # this branch raises NameError unless it is provided elsewhere.
            if object_exists(BUCKET, f'{checkpoint}{suffix}.pth'):
                checkpoint = f'{checkpoint}{suffix}'
        checkpoint += '.pth'

        # load
        model.load_state_dict(
            torch.load(DOWNLOAD_TO_CACHE(checkpoint), map_location=device),
            assign=True,
            strict=False)
    else:
        # init a model on device
        with torch.device(device):
            model = model_cls(**kwargs)

    # set device
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std
        # NOTE(review): assumes pretrained_name is a string here; passing
        # return_transforms=True with pretrained_name=None would raise.
        if 'siglip' in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([
            T.Resize((model.image_size, model.image_size),
                     interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std)
        ])
        output += (transforms,)

    # init tokenizer, picked by the pretrained model family
    if return_tokenizer:
        from sora import data
        if 'siglip' in pretrained_name.lower():
            tokenizer = data.HuggingfaceTokenizer(
                name=f'timm/{pretrained_name}',
                seq_len=model.text_len,
                clean='canonicalize')
        elif 'xlm' in pretrained_name.lower():
            tokenizer = data.HuggingfaceTokenizer(
                name='xlm-roberta-large',
                seq_len=model.max_text_len - 2,
                clean='whitespace')
        elif 'mba' in pretrained_name.lower():
            tokenizer = data.HuggingfaceTokenizer(
                name='facebook/xlm-roberta-xl',
                seq_len=model.max_text_len - 2,
                clean='whitespace')
        else:
            tokenizer = data.CLIPTokenizer(
                seq_len=model.text_len, padding=tokenizer_padding)
        output += (tokenizer,)
    return output[0] if len(output) == 1 else output
820
+
821
+
822
+ def clip_xlm_roberta_vit_h_14(
823
+ pretrained=False,
824
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
825
+ **kwargs):
826
+ cfg = dict(
827
+ embed_dim=1024,
828
+ image_size=224,
829
+ patch_size=14,
830
+ vision_dim=1280,
831
+ vision_mlp_ratio=4,
832
+ vision_heads=16,
833
+ vision_layers=32,
834
+ vision_pool='token',
835
+ activation='gelu',
836
+ vocab_size=250002,
837
+ max_text_len=514,
838
+ type_size=1,
839
+ pad_id=1,
840
+ text_dim=1024,
841
+ text_heads=16,
842
+ text_layers=24,
843
+ text_post_norm=True,
844
+ text_dropout=0.1,
845
+ attn_dropout=0.0,
846
+ proj_dropout=0.0,
847
+ embedding_dropout=0.0)
848
+ cfg.update(**kwargs)
849
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
850
+
851
+
852
class WanImageEncoder(torch.nn.Module):
    """Image encoder for Wan video models: the visual tower of the open-clip
    XLM-RoBERTa/ViT-H-14 CLIP, returning penultimate-block token features."""

    def __init__(self):
        super().__init__()
        # init model; weights are expected to be loaded later via the
        # state-dict converter below
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
            dtype=torch.float32,
            device="cpu")

    def encode_image(self, videos):
        """
        videos: iterable of 4-D tensors — bicubic F.interpolate requires
        [N, C, H, W]; values presumably in [-1, 1] given the rescale below
        (TODO confirm against callers).
        """
        # preprocess: resize every clip to the CLIP input resolution
        size = (self.model.image_size,) * 2
        videos = torch.cat([
            F.interpolate(
                u,
                size=size,
                mode='bicubic',
                align_corners=False) for u in videos
        ])
        # map [-1, 1] -> [0, 1] (mul_/add_ mutate the resized tensor in
        # place), then apply the final transform, which is the Normalize
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward through the visual tower in its parameter dtype, stopping
        # one block before the end (use_31_block)
        dtype = next(iter(self.model.visual.parameters())).dtype
        videos = videos.to(dtype)
        out = self.model.visual(videos, use_31_block=True)
        return out

    @staticmethod
    def state_dict_converter():
        # converter that adapts external checkpoints to this module's keys
        return WanImageEncoderStateDictConverter()
885
+
886
+
887
class WanImageEncoderStateDictConverter:
    """Key renamer for WanImageEncoder checkpoints: drops the text tower and
    nests everything else under the ``model.`` submodule."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # diffusers checkpoints already use matching keys
        return state_dict

    def from_civitai(self, state_dict):
        # skip "textual.*" weights (the text tower is never instantiated)
        # and prefix the remaining keys with "model."
        return {
            "model." + key: value
            for key, value in state_dict.items()
            if not key.startswith("textual.")
        }
902
+
diffsynth/models/wan_video_motion_controller.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .wan_video_dit import sinusoidal_embedding_1d
4
+
5
+
6
+
7
class WanMotionControllerModel(torch.nn.Module):
    """Maps a motion bucket id to a ``dim * 6`` modulation embedding via a
    sinusoidal encoding followed by a three-layer SiLU MLP."""

    def __init__(self, freq_dim=256, dim=1536):
        super().__init__()
        self.freq_dim = freq_dim
        self.linear = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim * 6),
        )

    def forward(self, motion_bucket_id):
        # scale the bucket id before the sinusoidal embedding
        sin_emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
        return self.linear(sin_emb)

    def init(self):
        # zero out the final layer so the controller starts as a no-op
        last = self.linear[-1]
        zeroed = {key: value * 0 for key, value in last.state_dict().items()}
        last.load_state_dict(zeroed)

    @staticmethod
    def state_dict_converter():
        return WanMotionControllerModelDictConverter()
32
+
33
+
34
+
35
class WanMotionControllerModelDictConverter:
    """Identity converter: motion-controller checkpoints need no renaming."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # keys already match the module layout
        return state_dict

    def from_civitai(self, state_dict):
        # keys already match the module layout
        return state_dict
44
+
diffsynth/models/wan_video_text_encoder.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def fp16_clamp(x):
9
+ if x.dtype == torch.float16 and torch.isinf(x).any():
10
+ clamp = torch.finfo(x.dtype).max - 1000
11
+ x = torch.clamp(x, min=-clamp, max=clamp)
12
+ return x
13
+
14
+
15
+ class GELU(nn.Module):
16
+
17
+ def forward(self, x):
18
+ return 0.5 * x * (1.0 + torch.tanh(
19
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
20
+
21
+
22
+ class T5LayerNorm(nn.Module):
23
+
24
+ def __init__(self, dim, eps=1e-6):
25
+ super(T5LayerNorm, self).__init__()
26
+ self.dim = dim
27
+ self.eps = eps
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+
30
+ def forward(self, x):
31
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
32
+ self.eps)
33
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
34
+ x = x.type_as(self.weight)
35
+ return self.weight * x
36
+
37
+
38
+ class T5Attention(nn.Module):
39
+
40
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
41
+ assert dim_attn % num_heads == 0
42
+ super(T5Attention, self).__init__()
43
+ self.dim = dim
44
+ self.dim_attn = dim_attn
45
+ self.num_heads = num_heads
46
+ self.head_dim = dim_attn // num_heads
47
+
48
+ # layers
49
+ self.q = nn.Linear(dim, dim_attn, bias=False)
50
+ self.k = nn.Linear(dim, dim_attn, bias=False)
51
+ self.v = nn.Linear(dim, dim_attn, bias=False)
52
+ self.o = nn.Linear(dim_attn, dim, bias=False)
53
+ self.dropout = nn.Dropout(dropout)
54
+
55
+ def forward(self, x, context=None, mask=None, pos_bias=None):
56
+ """
57
+ x: [B, L1, C].
58
+ context: [B, L2, C] or None.
59
+ mask: [B, L2] or [B, L1, L2] or None.
60
+ """
61
+ # check inputs
62
+ context = x if context is None else context
63
+ b, n, c = x.size(0), self.num_heads, self.head_dim
64
+
65
+ # compute query, key, value
66
+ q = self.q(x).view(b, -1, n, c)
67
+ k = self.k(context).view(b, -1, n, c)
68
+ v = self.v(context).view(b, -1, n, c)
69
+
70
+ # attention bias
71
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
72
+ if pos_bias is not None:
73
+ attn_bias += pos_bias
74
+ if mask is not None:
75
+ assert mask.ndim in [2, 3]
76
+ mask = mask.view(b, 1, 1,
77
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
78
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
79
+
80
+ # compute attention (T5 does not use scaling)
81
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
82
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
83
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
84
+
85
+ # output
86
+ x = x.reshape(b, -1, n * c)
87
+ x = self.o(x)
88
+ x = self.dropout(x)
89
+ return x
90
+
91
+
92
+ class T5FeedForward(nn.Module):
93
+
94
+ def __init__(self, dim, dim_ffn, dropout=0.1):
95
+ super(T5FeedForward, self).__init__()
96
+ self.dim = dim
97
+ self.dim_ffn = dim_ffn
98
+
99
+ # layers
100
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
101
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
102
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
103
+ self.dropout = nn.Dropout(dropout)
104
+
105
+ def forward(self, x):
106
+ x = self.fc1(x) * self.gate(x)
107
+ x = self.dropout(x)
108
+ x = self.fc2(x)
109
+ x = self.dropout(x)
110
+ return x
111
+
112
+
113
class T5SelfAttention(nn.Module):
    """One T5 encoder layer: pre-norm self-attention then a pre-norm gated
    feed-forward, each wrapped in a residual connection.

    When ``shared_pos`` is True the relative-position bias is supplied by
    the caller (a single table shared across all layers); otherwise the
    layer owns its own T5RelativeEmbedding.
    """

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        # Use the caller-supplied bias when shared, else compute per layer.
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.size(1), x.size(1))
        # fp16_clamp keeps the residual sums inside the fp16 value range.
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x
145
+
146
+
147
class T5RelativeEmbedding(nn.Module):
    """T5-style bucketed relative position bias.

    Each (query, key) offset maps to one of ``num_buckets`` learned
    per-head scalars: one bucket per offset for small distances, then
    logarithmically spaced buckets up to ``max_dist``.
    """

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        # rel_pos[i, j] = j - i, built directly on the embedding's device
        # to avoid a host/device round-trip.
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
            torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        # [Lq, Lk, N] -> [1, N, Lq, Lk] so it adds onto attention logits.
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
            0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess: bidirectional models split the buckets between the
        # two directions; unidirectional ones bucket only past offsets.
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # offsets below max_exact get one bucket each; larger offsets are
        # spread logarithmically over the remaining buckets, capped at the
        # last bucket for anything beyond max_dist.
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                     math.log(self.max_dist / max_exact) *
                                     (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets
191
+
192
def init_weights(m):
    # T5-style initialization: each projection is scaled by the inverse
    # square root of its fan-in so activations keep unit variance at depth.
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
207
+
208
+
209
class WanTextEncoder(torch.nn.Module):
    """T5-style encoder-only transformer used as the Wan text encoder.

    A token embedding followed by ``num_layers`` T5SelfAttention blocks
    and a final layer norm. With ``shared_pos`` True a single relative
    position table is shared across all layers; otherwise each block owns
    its own (the default here).
    """

    def __init__(self,
                 vocab=256384,
                 dim=4096,
                 dim_attn=4096,
                 dim_ffn=10240,
                 num_heads=64,
                 num_layers=24,
                 num_buckets=32,
                 shared_pos=False,
                 dropout=0.1):
        super(WanTextEncoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers; ``vocab`` may be a prebuilt nn.Embedding to share weights
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        """Encode token ids [B, L] (optional padding ``mask``) into
        hidden states [B, L, dim]."""
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x

    @staticmethod
    def state_dict_converter():
        return WanTextEncoderStateDictConverter()
259
+
260
+
261
class WanTextEncoderStateDictConverter:
    """Maps checkpoint state dicts onto WanTextEncoder parameter names.

    Both supported checkpoint layouts already match the module's naming
    scheme, so the conversions are identity passthroughs.
    """

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # Diffusers-format checkpoints need no renaming.
        return state_dict

    def from_civitai(self, state_dict):
        # Civitai-format checkpoints need no renaming.
        return state_dict
diffsynth/models/wan_video_vace.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .wan_video_dit import DiTBlock
3
+ from .utils import hash_state_dict_keys
4
+
5
class VaceWanAttentionBlock(DiTBlock):
    """A DiT block extended for VACE conditioning.

    Block 0 projects the raw conditioning stream into the DiT hidden
    space and adds the main hidden states ``x``; every block appends an
    ``after_proj`` skip tensor that is later injected back into the
    matching DiT layer.
    """

    def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
        super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
        self.block_id = block_id
        if block_id == 0:
            # only the first block maps the VACE embedding into x-space
            self.before_proj = torch.nn.Linear(self.dim, self.dim)
        self.after_proj = torch.nn.Linear(self.dim, self.dim)

    def forward(self, c, x, context, t_mod, freqs):
        # ``c`` is a stack [skip_0, ..., skip_{k-1}, current]: unpack the
        # running stream, run the DiT block on it, then re-stack with this
        # block's skip projection appended before the new running stream.
        if self.block_id == 0:
            c = self.before_proj(c) + x
            all_c = []
        else:
            all_c = list(torch.unbind(c))
            c = all_c.pop(-1)
        c = super().forward(c, context, t_mod, freqs)
        c_skip = self.after_proj(c)
        all_c += [c_skip, c]
        c = torch.stack(all_c)
        return c
25
+
26
+
27
class VaceWanModel(torch.nn.Module):
    """VACE conditioning branch for the Wan DiT.

    Patch-embeds the VACE context video and runs it through a stack of
    VaceWanAttentionBlock layers, returning one hint tensor per entry of
    ``vace_layers`` for injection into the matching DiT blocks.
    """

    def __init__(
        self,
        vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
        vace_in_dim=96,
        patch_size=(1, 2, 2),
        has_image_input=False,
        dim=1536,
        num_heads=12,
        ffn_dim=8960,
        eps=1e-6,
    ):
        super().__init__()
        self.vace_layers = vace_layers
        self.vace_in_dim = vace_in_dim
        # maps a DiT layer index -> position of its VACE hint
        self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}

        # vace blocks; block_id is the DiT layer index, so only the block
        # for layer 0 builds the ``before_proj`` input projection
        self.vace_blocks = torch.nn.ModuleList([
            VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
            for i in self.vace_layers
        ])

        # vace patch embeddings
        self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)

    def forward(
        self, x, vace_context, context, t_mod, freqs,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ):
        """Return per-layer hint tensors for the DiT, one per vace layer."""
        # Patch-embed each context video, flatten to token sequences, and
        # zero-pad every sequence up to the token length of ``x``.
        c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
        c = [u.flatten(2).transpose(1, 2) for u in c]
        c = torch.cat([
            torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],
                      dim=1) for u in c
        ])

        def create_custom_forward(module):
            # wrapper required by torch.utils.checkpoint
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.vace_blocks:
            if use_gradient_checkpointing_offload:
                # additionally offload checkpointed activations to CPU RAM
                with torch.autograd.graph.save_on_cpu():
                    c = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        c, x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                c = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    c, x, context, t_mod, freqs,
                    use_reentrant=False,
                )
            else:
                c = block(c, x, context, t_mod, freqs)
        # the final stack is [skip_0, ..., skip_{k-1}, running_stream];
        # only the skip tensors are returned as hints
        hints = torch.unbind(c)[:-1]
        return hints

    @staticmethod
    def state_dict_converter():
        return VaceWanModelDictConverter()
92
+
93
+
94
class VaceWanModelDictConverter:
    """Extracts VACE-specific weights from a checkpoint and infers config."""

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        # Keep only the tensors belonging to the VACE branch.
        vace_state_dict = {
            key: value
            for key, value in state_dict.items() if key.startswith("vace")
        }
        # Recognize known checkpoints by a hash of their key set.
        if hash_state_dict_keys(vace_state_dict) == '3b2726384e4f64837bdf216eea3f310d':
            # VACE 14B layout
            config = {
                "vace_layers": (0, 5, 10, 15, 20, 25, 30, 35),
                "vace_in_dim": 96,
                "patch_size": (1, 2, 2),
                "has_image_input": False,
                "dim": 5120,
                "num_heads": 40,
                "ffn_dim": 13824,
                "eps": 1e-06,
            }
        else:
            # Unknown checkpoint: fall back to the model's defaults.
            config = {}
        return vace_state_dict, config
diffsynth/models/wan_video_vae.py ADDED
@@ -0,0 +1,828 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange, repeat
5
+ from tqdm import tqdm
6
+
7
+ CACHE_T = 2
8
+
9
+
10
def check_is_instance(model, module_class):
    """isinstance() that also looks through a single ``.module`` wrapper
    attribute (e.g. DataParallel-style wrappers)."""
    unwrapped = model.module if hasattr(model, "module") else model
    return isinstance(model, module_class) or isinstance(unwrapped, module_class)
16
+
17
+
18
def block_causal_mask(x, block_size):
    """Build a block-causal attention mask for ``x``.

    Position i may attend to position j iff j's block index does not
    exceed i's, i.e. attention is causal at block granularity.

    Args:
        x: tensor of shape [B, N, S, *]; only its leading sizes and device
            are used.
        block_size: block length; S must be divisible by it.

    Returns:
        Bool tensor of shape [B, N, S, S], True where attention is allowed.
    """
    # params
    b, n, s = x.size(0), x.size(1), x.size(2)
    device = x.device
    assert s % block_size == 0

    # Vectorized form of the original per-block loop: compare the block
    # index of each (query, key) pair directly instead of filling slices.
    block_idx = torch.arange(s, device=device) // block_size
    mask = block_idx.unsqueeze(1) >= block_idx.unsqueeze(0)  # [S, S]
    # materialize the broadcast so callers may mutate the result safely
    return mask.expand(b, n, s, s).contiguous()
30
+
31
+
32
class CausalConv3d(nn.Conv3d):
    """3-D convolution that is causal along the temporal axis.

    Spatial padding stays symmetric while all temporal padding is applied
    on the left (past) side, so an output frame never sees future frames.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # F.pad order: (w_left, w_right, h_left, h_right, t_left, t_right)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        # disable the built-in padding; forward() pads manually instead
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        pad = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            # prepend cached frames from the previous chunk in place of
            # (part of) the zero padding
            x = torch.cat([cache_x.to(x.device), x], dim=2)
            pad[4] -= cache_x.shape[2]
        return super().forward(F.pad(x, pad))
52
+
53
+
54
class RMS_norm(nn.Module):
    """RMS normalization with a learned per-channel gain (optional bias).

    ``channel_first`` selects whether the channel axis is dim 1 (conv
    layouts) or the last dim; ``images`` controls how many trailing
    singleton dims the parameters broadcast over (two for image tensors,
    three for video tensors).
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        trailing = (1, 1) if images else (1, 1, 1)
        shape = (dim, *trailing) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        norm_dim = 1 if self.channel_first else -1
        # F.normalize divides by the L2 norm; multiplying by sqrt(dim)
        # turns that into RMS normalization.
        normalized = F.normalize(x, dim=norm_dim) * self.scale
        return normalized * self.gamma + self.bias
70
+
71
+
72
class Upsample(nn.Upsample):

    def forward(self, x):
        """Upsample in float32, then cast back to the input dtype.

        Works around missing bfloat16 support for nearest-neighbor
        interpolation on some backends.
        """
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
79
+
80
+
81
class Resample(nn.Module):
    """Spatial (and optionally temporal) up/down-sampling stage.

    Modes:
      - 'upsample2d'/'downsample2d': 2x spatial resample only.
      - 'upsample3d'/'downsample3d': additionally 2x temporal resample via
        a causal ``time_conv`` with streaming cache support.
      - 'none': identity.
    """

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            # doubles the channel count; the two halves are interleaved in
            # time in forward() to produce 2x temporal upsampling
            self.time_conv = CausalConv3d(dim,
                                          dim * 2, (3, 1, 1),
                                          padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            # temporal stride 2 halves the frame count
            self.time_conv = CausalConv3d(dim,
                                          dim, (3, 1, 1),
                                          stride=(2, 1, 1),
                                          padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # feat_cache/feat_idx implement chunked streaming: feat_cache[idx]
        # holds the trailing frames of the previous chunk for this layer's
        # time_conv; the sentinel string 'Rep' marks the very first chunk.
        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # first chunk: mark the slot and skip temporal caching
                    feat_cache[idx] = 'Rep'
                    feat_idx[0] += 1
                else:

                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] != 'Rep':
                        # cache last frame of last two chunk
                        cache_x = torch.cat([
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device), cache_x
                        ],
                                            dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] == 'Rep':
                        # first chunk is too short: left-pad with zeros
                        cache_x = torch.cat([
                            torch.zeros_like(cache_x).to(cache_x.device),
                            cache_x
                        ],
                                            dim=2)
                    if feat_cache[idx] == 'Rep':
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    # interleave the two channel halves along time:
                    # [B, 2C, T, H, W] -> [B, C, 2T, H, W]
                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                    3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        # spatial resampling runs per-frame in 2-D
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # first chunk: only record the cache, no temporal conv
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        # initialize the temporal conv as an identity over the last frame
        conv_weight = conv.weight
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        conv_weight.data[:, :, 1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        # initialize the doubling temporal conv so both output halves
        # reproduce the input (identity before training)
        conv_weight = conv.weight.data
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)
195
+
196
+
197
class ResidualBlock(nn.Module):
    """Pre-norm residual block of two causal 3-D convolutions, with a
    1x1x1 conv shortcut when the channel count changes."""

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE(review): the mutable default feat_idx=[0] is shared across
        # calls; callers are expected to always pass their own list when
        # caching is active — confirm no call site relies on the default.
        h = self.shortcut(x)
        for layer in self.residual:
            if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                # keep the trailing CACHE_T frames so the next chunk's
                # causal conv sees them instead of zero padding
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h
232
+
233
+
234
class AttentionBlock(nn.Module):
    """
    Single-head self-attention over the spatial tokens of each frame.

    NOTE(review): the docstring originally said "causal", but the
    block-causal attn_mask is commented out below, so this is full
    (unmasked) per-frame attention — confirm that is intended.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params so the block starts as identity
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        # x: [B, C, T, H, W]; attention runs independently per frame over
        # the H*W spatial tokens.
        identity = x
        b, c, t, h, w = x.size()
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.norm(x)
        # compute query, key, value: each [(B T), 1, H*W, C]
        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
            0, 1, 3, 2).contiguous().chunk(3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(
            q,
            k,
            v,
            # attn_mask=block_causal_mask(q, block_size=h * w)
        )
        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output projection plus residual connection
        x = self.proj(x)
        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
        return x + identity
273
+
274
+
275
class Encoder3d(nn.Module):
    """Video encoder: an init causal conv, staged residual blocks with
    2-D/3-D downsampling, mid attention, and a head projecting to
    ``z_dim`` channels."""

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions per stage: dim * [1, *dim_mult]
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0  # running spatial scale, used against attn_scales

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block (omitted after the final stage)
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[
                    i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
                                    AttentionBlock(out_dim),
                                    ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
                                  CausalConv3d(out_dim, z_dim, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # feat_cache/feat_idx enable chunked streaming: each causal conv
        # stores its trailing frames so consecutive chunks are seamless.
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        # downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # middle
        for layer in self.middle:
            if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # head
        for layer in self.head:
            if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x
376
+
377
+
378
class Decoder3d(nn.Module):
    """Video decoder mirroring Encoder3d: init conv, mid attention, staged
    residual blocks with 2-D/3-D upsampling, and an RGB output head."""

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions per stage, reversed relative to the encoder
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
                                    AttentionBlock(dims[0]),
                                    ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks; upsample2d/3d halves the
            # channel count, so later stages start with in_dim // 2
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block (omitted after the final stage)
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
                                  CausalConv3d(out_dim, 3, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # feat_cache/feat_idx enable chunked streaming decode (one latent
        # frame at a time); each causal conv caches its trailing frames.
        # conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        # middle
        for layer in self.middle:
            if check_is_instance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # head
        for layer in self.head:
            if check_is_instance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x
481
+
482
+
483
def count_conv3d(model):
    """Number of CausalConv3d modules in ``model``; this sizes the frame
    cache used by the chunked encode/decode paths."""
    return sum(
        1 for module in model.modules()
        if check_is_instance(module, CausalConv3d))
489
+
490
+
491
class VideoVAE_(nn.Module):
    """Causal 3-D video VAE: Encoder3d/Decoder3d plus 1x1x1 latent convs,
    with frame caching so video can be encoded/decoded in temporal chunks
    (4 video frames per latent frame after the first)."""

    def __init__(self,
                 dim=96,
                 z_dim=16,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules; the encoder outputs 2*z_dim channels (mu, log_var)
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        # NOTE(review): encode()/decode() require a ``scale`` argument and
        # encode() returns only mu, so this legacy path would raise if
        # called — confirm whether it is still needed.
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        """Encode video x [B, 3, T, H, W] chunk-by-chunk and return the
        normalized latent mean; ``scale`` is (mean, 1/std) per channel."""
        self.clear_cache()
        # cache: first chunk is 1 frame, every later chunk is 4 frames
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4

        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(x[:, :, :1, :, :],
                                   feat_cache=self._enc_feat_map,
                                   feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                                    feat_cache=self._enc_feat_map,
                                    feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        # standardize the latent with the dataset statistics
        if isinstance(scale[0], torch.Tensor):
            scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            scale = scale.to(dtype=mu.dtype, device=mu.device)
            mu = (mu - scale[0]) * scale[1]
        return mu

    def decode(self, z, scale):
        """Decode latents z [B, C, T, H, W] one latent frame at a time;
        ``scale`` is the same (mean, 1/std) pair used by encode()."""
        self.clear_cache()
        # z: [b,c,t,h,w]; undo the latent standardization first
        if isinstance(scale[0], torch.Tensor):
            scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            scale = scale.to(dtype=z.dtype, device=z.device)
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(x[:, :, i:i + 1, :, :],
                                   feat_cache=self._feat_map,
                                   feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(x[:, :, i:i + 1, :, :],
                                    feat_cache=self._feat_map,
                                    feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)  # may add tensor offload
        return out

    def reparameterize(self, mu, log_var):
        # standard VAE reparameterization trick
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        # NOTE(review): same issue as forward() — encode() now takes a
        # scale argument and returns a single tensor, so this unpacking
        # would fail if called; confirm this path is unused.
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        # one cache slot per CausalConv3d in each sub-network
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
596
+
597
+
598
+ class WanVideoVAE(nn.Module):
599
+
600
    def __init__(self, z_dim=16):
        super().__init__()

        # Per-channel latent statistics; self.scale = (mean, 1/std) is the
        # normalization pair consumed by VideoVAE_.encode/decode.
        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)
        self.scale = [self.mean, 1.0 / self.std]

        # init model: inner VAE kept frozen in eval mode
        self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
        # one latent pixel corresponds to an 8x8 video patch
        self.upsampling_factor = 8
618
+
619
+ def build_1d_mask(self, length, left_bound, right_bound, border_width):
620
+ x = torch.ones((length,))
621
+ if not left_bound:
622
+ x[:border_width] = (torch.arange(border_width) + 1) / border_width
623
+ if not right_bound:
624
+ x[-border_width:] = torch.flip(
625
+ (torch.arange(border_width) + 1) / border_width, dims=(0,))
626
+ return x
627
+
628
    def build_mask(self, data, is_bound, border_width):
        """Build a feathered 2-D blending mask matching ``data``'s H and W.

        is_bound: (top, bottom, left, right) flags; a True side touches
        the full-frame boundary and keeps weight 1 instead of ramping.
        border_width: (vertical, horizontal) ramp lengths in pixels.
        Returns a [1, 1, 1, H, W] tensor for broadcasting over the tile.
        """
        _, _, _, H, W = data.shape
        h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
        w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])

        h = repeat(h, "H -> H W", H=H, W=W)
        w = repeat(w, "W -> H W", H=H, W=W)

        # elementwise min of the two 1-D ramps gives the 2-D feathered edge
        mask = torch.stack([h, w]).min(dim=0).values
        mask = rearrange(mask, "H W -> 1 1 1 H W")
        return mask
639
+
640
+ def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
641
+ _, _, T, H, W = hidden_states.shape
642
+ size_h, size_w = tile_size
643
+ stride_h, stride_w = tile_stride
644
+
645
+ # Split tasks
646
+ tasks = []
647
+ for h in range(0, H, stride_h):
648
+ if (h-stride_h >= 0 and h-stride_h+size_h >= H):
649
+ continue
650
+ for w in range(0, W, stride_w):
651
+ if (w-stride_w >= 0 and w-stride_w+size_w >= W):
652
+ continue
653
+ h_, w_ = h + size_h, w + size_w
654
+ tasks.append((h, h_, w, w_))
655
+
656
+ data_device = "cpu"
657
+ computation_device = device
658
+
659
+ out_T = T * 4 - 3
660
+ weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W *
661
+ self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
662
+ values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W *
663
+ self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
664
+ disable_flag = True
665
+ for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding", disable=disable_flag):
666
+ hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(
667
+ computation_device)
668
+ hidden_states_batch = self.model.decode(
669
+ hidden_states_batch, self.scale).to(data_device)
670
+
671
+ mask = self.build_mask(
672
+ hidden_states_batch,
673
+ is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
674
+ border_width=((size_h - stride_h) * self.upsampling_factor,
675
+ (size_w - stride_w) * self.upsampling_factor)
676
+ ).to(dtype=hidden_states.dtype, device=data_device)
677
+
678
+ target_h = h * self.upsampling_factor
679
+ target_w = w * self.upsampling_factor
680
+ import pdb
681
+
682
+ # pdb.set_trace()
683
+ values[
684
+ :,
685
+ :,
686
+ :,
687
+ target_h:target_h + hidden_states_batch.shape[3],
688
+ target_w:target_w + hidden_states_batch.shape[4],
689
+ ] += hidden_states_batch * mask
690
+ weight[
691
+ :,
692
+ :,
693
+ :,
694
+ target_h: target_h + hidden_states_batch.shape[3],
695
+ target_w: target_w + hidden_states_batch.shape[4],
696
+ ] += mask
697
+ values = values / weight
698
+ values = values.clamp_(-1, 1)
699
+ return values
700
+
701
    def tiled_encode(self, video, device, tile_size, tile_stride):
        """Encode a video tile-by-tile to bound peak VRAM usage.

        Mirror image of tiled_decode: overlapping pixel-space tiles are
        encoded on `device`, feather-blended with build_mask in latent
        space, and accumulated on CPU; overlap weights are normalized at
        the end.

        Args:
            video: pixel tensor of shape (B, C, T, H, W);
                tile_size/tile_stride are in PIXEL units here.
            device: computation device for the VAE encoder.
            tile_size: (height, width) of each pixel tile.
            tile_stride: (height, width) step between consecutive tiles.

        Returns:
            Latent tensor of shape (1, 16, (T+3)//4, H//8, W//8).
        """
        _, _, T, H, W = video.shape
        size_h, size_w = tile_size
        stride_h, stride_w = tile_stride

        # Split tasks; skip tiles whose area the previous stride already covered.
        tasks = []
        for h in range(0, H, stride_h):
            if (h-stride_h >= 0 and h-stride_h+size_h >= H):
                continue
            for w in range(0, W, stride_w):
                if (w-stride_w >= 0 and w-stride_w+size_w >= W):
                    continue
                h_, w_ = h + size_h, w + size_w
                tasks.append((h, h_, w, w_))

        data_device = "cpu"
        computation_device = device

        # Causal temporal compression: T frames -> (T + 3) // 4 latents.
        out_T = (T + 3) // 4
        weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W //
                             self.upsampling_factor), dtype=video.dtype, device=data_device)
        values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W //
                             self.upsampling_factor), dtype=video.dtype, device=data_device)
        disable_flag = True  # progress bar suppressed
        for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding", disable=disable_flag):
            hidden_states_batch = video[:, :, :,
                                        h:h_, w:w_].to(computation_device)
            hidden_states_batch = self.model.encode(
                hidden_states_batch, self.scale).to(data_device)

            # Blend mask in latent resolution; border widths scaled down by 8.
            mask = self.build_mask(
                hidden_states_batch,
                is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
                border_width=((size_h - stride_h) // self.upsampling_factor,
                              (size_w - stride_w) // self.upsampling_factor)
            ).to(dtype=video.dtype, device=data_device)

            target_h = h // self.upsampling_factor
            target_w = w // self.upsampling_factor
            values[
                :,
                :,
                :,
                target_h:target_h + hidden_states_batch.shape[3],
                target_w:target_w + hidden_states_batch.shape[4],
            ] += hidden_states_batch * mask
            weight[
                :,
                :,
                :,
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += mask
        # Normalize the accumulated overlaps (no clamp: latents are unbounded).
        values = values / weight
        return values
757
+
758
+ def single_encode(self, video, device):
759
+ video = video.to(device)
760
+ x = self.model.encode(video, self.scale)
761
+ return x
762
+
763
+ def single_decode(self, hidden_state, device):
764
+ hidden_state = hidden_state.to(device)
765
+ video = self.model.decode(hidden_state, self.scale)
766
+ return video.clamp_(-1, 1)
767
+
768
+ def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
769
+
770
+ # videos = [video.to("cpu") for video in videos]
771
+ hidden_states = []
772
+ for video in videos:
773
+ video = video.unsqueeze(0)
774
+ if tiled:
775
+ tile_size = (tile_size[0] * 8, tile_size[1] * 8)
776
+ tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
777
+ hidden_state = self.tiled_encode(
778
+ video, device, tile_size, tile_stride)
779
+ else:
780
+ hidden_state = self.single_encode(video, device)
781
+ hidden_state = hidden_state.squeeze(0)
782
+ hidden_states.append(hidden_state)
783
+ hidden_states = torch.stack(hidden_states)
784
+ # TODO
785
+ # if tiled:
786
+ # hidden_states = self.tiled_encode(
787
+ # videos, device, tile_size, tile_stride)
788
+ # else:
789
+ # hidden_states = self.single_encode(videos, device)
790
+ return hidden_states
791
+
792
+ def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
793
+ video = []
794
+ for _hidden_states in hidden_states:
795
+ _hidden_states = _hidden_states.unsqueeze(0)
796
+ if tiled:
797
+ video.append(
798
+ self.tiled_decode(_hidden_states, device, tile_size,
799
+ tile_stride))
800
+ else:
801
+ video.append(self.single_decode(_hidden_states, device))
802
+
803
+ video = torch.cat(video, dim=0)
804
+ # TODO
805
+ # if tiled:
806
+ # video = self.tiled_decode(
807
+ # hidden_states, device, tile_size, tile_stride)
808
+ # else:
809
+ # video = self.single_decode(hidden_states, device)
810
+ return video
811
+
812
    @staticmethod
    def state_dict_converter():
        # Factory for the converter that remaps civitai-style checkpoint
        # keys onto this module's `model.` namespace.
        return WanVideoVAEStateDictConverter()
815
+
816
+
817
class WanVideoVAEStateDictConverter:
    """Remaps raw VAE checkpoint keys onto the `model.` submodule namespace."""

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        """Prefix every weight key with `model.`, unwrapping an optional
        `model_state` container first."""
        if 'model_state' in state_dict:
            state_dict = state_dict['model_state']
        return {'model.' + key: value for key, value in state_dict.items()}
diffsynth/pipelines/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .wan_video_new_determine import WanVideoPipeline
diffsynth/pipelines/wan_video_new_determine.py ADDED
@@ -0,0 +1,1730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ import time
4
+ import types
5
+ import warnings
6
+ from dataclasses import dataclass
7
+ from typing import List, Optional, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from einops import rearrange, reduce, repeat
12
+ # from modelscope import snapshot_download
13
+ from huggingface_hub import snapshot_download
14
+ from PIL import Image
15
+ from tqdm import tqdm
16
+ from typing_extensions import Literal
17
+
18
+ from ..models import ModelManager, load_state_dict
19
+ from ..models.wan_video_dit import RMSNorm, WanModel, sinusoidal_embedding_1d
20
+ from ..models.wan_video_image_encoder import WanImageEncoder
21
+ from ..models.wan_video_motion_controller import WanMotionControllerModel
22
+ # from ..model.
23
+ from ..models.wan_video_text_encoder import (T5LayerNorm, T5RelativeEmbedding,
24
+ WanTextEncoder)
25
+ from ..models.wan_video_vace import VaceWanModel
26
+ from ..models.wan_video_vae import (CausalConv3d, RMS_norm, Upsample,
27
+ WanVideoVAE)
28
+ from ..schedulers.flow_match import FlowMatchScheduler
29
+ # from ..prompters import WanPrompter
30
+ from ..vram_management import (AutoWrappedLinear, AutoWrappedModule,
31
+ WanAutoCastLayerNorm, enable_vram_management)
32
+
33
+
34
class BasePipeline(torch.nn.Module):
    """Shared pipeline scaffolding.

    Responsibilities: device/dtype bookkeeping for intermediate tensors,
    input shape validation, PIL/tensor conversion, VRAM-aware model
    on/offloading, and seeded noise generation.
    """

    def __init__(
        self,
        device="cuda",
        torch_dtype=torch.float16,
        height_division_factor=64,
        width_division_factor=64,
        time_division_factor=None,
        time_division_remainder=None,
    ):
        super().__init__()
        # The device and torch_dtype are used for the storage of
        # intermediate variables, not for the models themselves.
        self.device = device
        self.torch_dtype = torch_dtype
        # Divisibility constraints enforced by check_resize_height_width.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.vram_management_enabled = False

    def to(self, *args, **kwargs):
        """Mirror nn.Module.to while keeping self.device / self.torch_dtype
        (used for intermediate tensors) in sync with the move."""
        # NOTE(review): torch._C._nn._parse_to is a private API; it mirrors
        # what nn.Module.to accepts but may change between torch versions.
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def check_resize_height_width(self, height, width, num_frames=None):
        """Validate (height, width, num_frames) against the pipeline's
        divisibility constraints and return them unchanged.

        Raises:
            AssertionError: on any violated constraint.
        """
        assert (
            height % self.height_division_factor == 0
        ), f"height {height} is not divisible by {self.height_division_factor}."

        assert (
            width % self.width_division_factor == 0
        ), f"width {width} is not divisible by {self.width_division_factor}."
        # NOTE(review): despite the num_frames=None default, this assert
        # requires num_frames (and time_division_factor) to be set — confirm
        # callers always pass it.  `(n + f) % f` is just `n % f`.
        assert (num_frames is not None) and (
            (num_frames + self.time_division_factor) % self.time_division_factor
            == self.time_division_remainder
        ), f"num_frames {num_frames} is not divisible by {self.time_division_factor} with remainder {self.time_division_remainder}."
        return height, width, num_frames

    def preprocess_image(
        self,
        image,
        torch_dtype=None,
        device=None,
        pattern="B C H W",
        min_value=-1,
        max_value=1,
    ):
        """Convert a PIL.Image or tensor to a normalized tensor.

        Tensor inputs are mapped to [min_value, max_value] WITHOUT a /255,
        so they are assumed to already be scaled to [0, 1] — TODO confirm
        against callers.  PIL inputs are divided by 255 before remapping
        and then reshaped to `pattern`.
        """
        if isinstance(image, torch.Tensor):
            # Accept C H W or B C H W with 3 channels.
            assert (len(image.shape) == 3 and image.shape[0] == 3) or (
                len(image.shape) == 4 and image.shape[1] == 3
            ), "Image tensor must be in 3 H W or B 3 H W format."
            image = image.to(
                dtype=torch_dtype or self.torch_dtype, device=device or self.device
            )
            image = image * ((max_value - min_value)) + min_value
            if len(image.shape) == 3:
                image = image.unsqueeze(0)  # Add batch dimension
        else:
            image = torch.Tensor(np.array(image, dtype=np.float32))
            image = image.to(
                dtype=torch_dtype or self.torch_dtype, device=device or self.device
            )
            image = image * ((max_value - min_value) / 255) + min_value
            # NOTE(review): the `pattern` reshape only happens on this PIL
            # branch; tensor inputs keep their own layout.
            image = repeat(
                image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {})
            )
        return image

    def preprocess_video(
        self,
        video,
        torch_dtype=None,
        device=None,
        pattern="B C T H W",
        min_value=-1,
        max_value=1,
    ):
        """Preprocess each frame with preprocess_image and stack them along
        the T axis of `pattern`."""
        video = [
            self.preprocess_image(
                image,
                torch_dtype=torch_dtype,
                device=device,
                min_value=min_value,
                max_value=max_value,
            )
            for image in video
        ]
        # pattern.index("T") // 2 maps the pattern position to a dim index
        # (each axis label occupies two characters: letter + space).
        video = torch.stack(video, dim=pattern.index("T") // 2)
        return video

    def vae_output_to_image(
        self, vae_output, pattern="B C H W", min_value=-1, max_value=1
    ):
        """Map VAE output from [min_value, max_value] to a float32 numpy
        array in [0, 255] with layout H W C (mean-reduced over other dims).

        NOTE(review): despite the name, this returns a numpy array, not a
        PIL.Image.
        """
        if pattern != "H W C":
            vae_output = reduce(
                vae_output, f"{pattern} -> H W C", reduction="mean")

        image = (vae_output - min_value) * (255.0 / (max_value - min_value))
        image = image.clamp(0.0, 255.0)

        image = image.to(device="cpu", dtype=torch.float32)
        image = image.numpy()
        return image

    def vae_output_to_video(
        self, vae_output, pattern="B C T H W", min_value=-1, max_value=1
    ):
        """Convert a (B, 3, T, H, W) VAE output to a float numpy video in
        [0, 1] with layout (B, T, H, W, C).

        NOTE(review): min_value/max_value are currently ignored — the range
        is hard-coded to [-1, 1] via (x + 1) / 2; confirm callers never
        pass a different range.
        """
        if vae_output.ndim == 5:  # B C T H W
            assert (
                vae_output.shape[1] == 3
            ), f"vae_output shape {vae_output.shape} is not valid. Expected 5D tensor with 3 channels on the second dimension."
            vae_output = vae_output.permute(0, 2, 3, 4, 1)
        video = vae_output.to(device="cpu", dtype=torch.float32).numpy()
        video = (video + 1.0) / 2.0
        video = video.clip(0.0, 1.0)
        return video

    def load_models_to_device(self, model_names=[]):
        """Move the named child models onto self.device and everything
        else off (to CPU, or via each module's offload/onload hooks).

        No-op unless VRAM management is enabled.
        NOTE(review): the mutable default `[]` is safe only because it is
        never mutated here.
        """
        if self.vram_management_enabled:
            # Offload models not needed for the next stage.
            for name, model in self.named_children():
                if name not in model_names:
                    if (
                        hasattr(model, "vram_management_enabled")
                        and model.vram_management_enabled
                    ):
                        for module in model.modules():
                            if hasattr(module, "offload"):
                                module.offload()
                    else:
                        model.cpu()
            torch.cuda.empty_cache()
            # Onload the requested models.
            for name, model in self.named_children():
                if name in model_names:
                    if (
                        hasattr(model, "vram_management_enabled")
                        and model.vram_management_enabled
                    ):
                        for module in model.modules():
                            if hasattr(module, "onload"):
                                module.onload()
                    else:
                        model.to(self.device)

    def generate_noise(
        self,
        shape,
        seed=None,
        rand_device="cpu",
        rand_torch_dtype=torch.float32,
        device=None,
        torch_dtype=None,
    ):
        """Draw seeded Gaussian noise of `shape` on `rand_device` (so the
        draw is reproducible regardless of target device), then cast/move
        it to the requested device/dtype."""
        generator = (
            None if seed is None else torch.Generator(
                rand_device).manual_seed(seed)
        )
        # TODO multi-res noise
        noise = torch.randn(
            shape, generator=generator, device=rand_device, dtype=rand_torch_dtype
        )
        noise = noise.to(
            dtype=torch_dtype or self.torch_dtype, device=device or self.device
        )
        return noise

    def enable_cpu_offload(self):
        """Deprecated alias: just sets the VRAM-management flag."""
        warnings.warn(
            "`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`."
        )
        self.vram_management_enabled = True

    def get_vram(self):
        # Total memory of self.device in GiB; mem_get_info() returns
        # (free, total) and index [1] is the total.
        return torch.cuda.mem_get_info(self.device)[1] / (1024**3)

    def freeze_except(self, model_names):
        """Put the named children in train mode with grads enabled and
        freeze (eval + no grads) every other child, logging both."""
        for name, model in self.named_children():
            if name in model_names:
                print(f"Unfreezing model {name}.")
                print(
                    f"Model parameters: {sum(p.numel() for p in model.parameters())}")
                model.train()
                model.requires_grad_(True)
            else:
                print(f"Freezing model {name}.")
                print(
                    f"Model parameters: {sum(p.numel() for p in model.parameters())}")
                model.eval()
                model.requires_grad_(False)
272
+
273
+
274
@dataclass
class ModelConfig:
    """Location of one model artifact: either a resolved local `path`, or
    a remote `model_id` + `origin_file_pattern` fetched on demand via
    huggingface_hub into `local_model_path`.

    offload_device / offload_dtype optionally override where and how the
    model weights are stored after loading.
    """

    path: Union[str, list[str]] = None
    model_id: str = None
    origin_file_pattern: Union[str, list[str]] = None
    # NOTE(review): historic name — downloading now goes through
    # huggingface_hub, not ModelScope.
    download_resource: str = "ModelScope"
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None

    def download_if_necessary(
        self, local_model_path="./models", skip_download=False, use_usp=False
    ):
        """Populate self.path, downloading the file(s) first if needed.

        With use_usp=True only rank 0 downloads; the other ranks skip the
        download and wait at a barrier until rank 0 finishes.
        """
        if self.path is None:
            # Check model_id and origin_file_pattern
            if self.model_id is None:
                raise ValueError(
                    f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`."""
                )

            # Skip the download on every rank except 0
            if use_usp:
                import torch.distributed as dist

                skip_download = dist.get_rank() != 0

            # Decide whether the origin pattern denotes a folder
            if self.origin_file_pattern is None or self.origin_file_pattern == "":
                self.origin_file_pattern = ""
                allow_file_pattern = None
                is_folder = True
            elif isinstance(
                self.origin_file_pattern, str
            ) and self.origin_file_pattern.endswith("/"):
                allow_file_pattern = self.origin_file_pattern + "*"
                is_folder = True
            else:
                allow_file_pattern = self.origin_file_pattern
                is_folder = False

            # Download
            if not skip_download:
                # Files already on disk are excluded so they are not re-fetched.
                downloaded_files = glob.glob(
                    self.origin_file_pattern,
                    root_dir=os.path.join(local_model_path, self.model_id),
                )
                snapshot_download(
                    self.model_id,
                    repo_type="model",  # change to "dataset" for dataset repos
                    local_dir=os.path.join(local_model_path, self.model_id),
                    allow_patterns=allow_file_pattern,
                    ignore_patterns=downloaded_files,  # note: patterns, not paths
                    local_files_only=False,
                    # NOTE(review): resume_download is deprecated in recent
                    # huggingface_hub versions (resuming is the default).
                    resume_download=True,

                )

            # Let rank 1, 2, ... wait for rank 0
            if use_usp:
                import torch.distributed as dist

                dist.barrier(device_ids=[dist.get_rank()])

            # Resolve the downloaded path(s)
            if is_folder:
                self.path = os.path.join(
                    local_model_path, self.model_id, self.origin_file_pattern
                )
            else:
                self.path = glob.glob(
                    os.path.join(
                        local_model_path, self.model_id, self.origin_file_pattern
                    )
                )
            # Collapse a single-file glob result to a plain string.
            if isinstance(self.path, list) and len(self.path) == 1:
                self.path = self.path[0]
358
+
359
+
360
class WanVideoPipeline(BasePipeline):
    """Wan2.1 video pipeline specialized for deterministic (single-step)
    prediction; latent geometry follows the Wan VAE (spatial /16 checks,
    temporal factor 4 with remainder 1)."""

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None):
        # NOTE(review): tokenizer_path is currently unused (the prompter
        # is disabled); kept for interface compatibility.
        super().__init__(
            device=device,
            torch_dtype=torch_dtype,
            height_division_factor=16,
            width_division_factor=16,
            time_division_factor=4,
            time_division_remainder=1,
        )
        # Rectified-flow scheduler; shift=5 matches the Wan defaults used
        # elsewhere in this file.
        self.scheduler = FlowMatchScheduler(
            shift=5, sigma_min=0.0, extra_one_step=True)
        # Sub-models are populated later by from_pretrained().
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        self.motion_controller: WanMotionControllerModel = None
        self.vace: VaceWanModel = None
        # Models that stay on-device during the denoising loop.
        self.in_iteration_models = ("dit", "motion_controller", "vace")
        self.unit_runner = PipelineUnitRunner()

        # Preprocessing units executed before denoising: shape checking,
        # input-video latent embedding, prompt/image embedding, and
        # unified-sequence-parallel setup.
        self.units = [
            WanVideoUnit_ShapeChecker(),  # check if the shape is ok
            WanVideoUnit_InputVideoEmbedder(),
            WanVideoUnit_PromptEmbedder(),
            WanVideoUnit_ImageEmbedder(),
            WanVideoUnit_UnifiedSequenceParallel(),
        ]

        self.model_fn = model_fn_wan_video
400
+
401
    def training_predict(self, **inputs):
        """Run one deterministic training step.

        Always uses the FIRST scheduler timestep (index 0), so the model is
        trained as a single-step predictor.  The RGB latents are fed to the
        DiT directly as `latents` (no noise is added here), and the
        flow-matching target is built from the depth and RGB latents —
        presumably an interpolation defined by the scheduler; confirm in
        FlowMatchScheduler.training_target.

        Returns:
            dict with `rgb_gt`, `depth_gt` (training target), `pred`
            (model output) and the per-timestep loss `weight`.
        """
        timestep_id = torch.tensor([0])
        timestep = self.scheduler.timesteps[timestep_id].to(
            dtype=self.torch_dtype, device=self.device
        )
        # Condition the model on the RGB latents directly.
        inputs["latents"] = inputs['rgb_latents']
        training_target = self.scheduler.training_target(
            inputs["depth_latents"], inputs["rgb_latents"], timestep
        )
        noise_pred = self.model_fn(**inputs, timestep=timestep)

        return {
            'rgb_gt': inputs['rgb_latents'],
            "depth_gt": training_target,
            "pred": noise_pred,
            "weight": self.scheduler.training_weight(timestep),
        }
420
+
421
+ def enable_vram_management(
422
+ self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5
423
+ ):
424
+ self.vram_management_enabled = True
425
+ if num_persistent_param_in_dit is not None:
426
+ vram_limit = None
427
+ else:
428
+ if vram_limit is None:
429
+ vram_limit = self.get_vram()
430
+ vram_limit = vram_limit - vram_buffer
431
+ if self.text_encoder is not None:
432
+ dtype = next(iter(self.text_encoder.parameters())).dtype
433
+ enable_vram_management(
434
+ self.text_encoder,
435
+ module_map={
436
+ torch.nn.Linear: AutoWrappedLinear,
437
+ torch.nn.Embedding: AutoWrappedModule,
438
+ T5RelativeEmbedding: AutoWrappedModule,
439
+ T5LayerNorm: AutoWrappedModule,
440
+ },
441
+ module_config=dict(
442
+ offload_dtype=dtype,
443
+ offload_device="cpu",
444
+ onload_dtype=dtype,
445
+ onload_device="cpu",
446
+ computation_dtype=self.torch_dtype,
447
+ computation_device=self.device,
448
+ ),
449
+ vram_limit=vram_limit,
450
+ )
451
+ if self.dit is not None:
452
+ dtype = next(iter(self.dit.parameters())).dtype
453
+ device = "cpu" if vram_limit is not None else self.device
454
+ enable_vram_management(
455
+ self.dit,
456
+ module_map={
457
+ torch.nn.Linear: AutoWrappedLinear,
458
+ torch.nn.Conv3d: AutoWrappedModule,
459
+ torch.nn.LayerNorm: WanAutoCastLayerNorm,
460
+ RMSNorm: AutoWrappedModule,
461
+ torch.nn.Conv2d: AutoWrappedModule,
462
+ },
463
+ module_config=dict(
464
+ offload_dtype=dtype,
465
+ offload_device="cpu",
466
+ onload_dtype=dtype,
467
+ onload_device=device,
468
+ computation_dtype=self.torch_dtype,
469
+ computation_device=self.device,
470
+ ),
471
+ max_num_param=num_persistent_param_in_dit,
472
+ overflow_module_config=dict(
473
+ offload_dtype=dtype,
474
+ offload_device="cpu",
475
+ onload_dtype=dtype,
476
+ onload_device="cpu",
477
+ computation_dtype=self.torch_dtype,
478
+ computation_device=self.device,
479
+ ),
480
+ vram_limit=vram_limit,
481
+ )
482
+ if self.vae is not None:
483
+ dtype = next(iter(self.vae.parameters())).dtype
484
+ enable_vram_management(
485
+ self.vae,
486
+ module_map={
487
+ torch.nn.Linear: AutoWrappedLinear,
488
+ torch.nn.Conv2d: AutoWrappedModule,
489
+ RMS_norm: AutoWrappedModule,
490
+ CausalConv3d: AutoWrappedModule,
491
+ Upsample: AutoWrappedModule,
492
+ torch.nn.SiLU: AutoWrappedModule,
493
+ torch.nn.Dropout: AutoWrappedModule,
494
+ },
495
+ module_config=dict(
496
+ offload_dtype=dtype,
497
+ offload_device="cpu",
498
+ onload_dtype=dtype,
499
+ onload_device=self.device,
500
+ computation_dtype=self.torch_dtype,
501
+ computation_device=self.device,
502
+ ),
503
+ )
504
+ if self.image_encoder is not None:
505
+ dtype = next(iter(self.image_encoder.parameters())).dtype
506
+ enable_vram_management(
507
+ self.image_encoder,
508
+ module_map={
509
+ torch.nn.Linear: AutoWrappedLinear,
510
+ torch.nn.Conv2d: AutoWrappedModule,
511
+ torch.nn.LayerNorm: AutoWrappedModule,
512
+ },
513
+ module_config=dict(
514
+ offload_dtype=dtype,
515
+ offload_device="cpu",
516
+ onload_dtype=dtype,
517
+ onload_device="cpu",
518
+ computation_dtype=dtype,
519
+ computation_device=self.device,
520
+ ),
521
+ )
522
+ if self.motion_controller is not None:
523
+ dtype = next(iter(self.motion_controller.parameters())).dtype
524
+ enable_vram_management(
525
+ self.motion_controller,
526
+ module_map={
527
+ torch.nn.Linear: AutoWrappedLinear,
528
+ },
529
+ module_config=dict(
530
+ offload_dtype=dtype,
531
+ offload_device="cpu",
532
+ onload_dtype=dtype,
533
+ onload_device="cpu",
534
+ computation_dtype=dtype,
535
+ computation_device=self.device,
536
+ ),
537
+ )
538
+ if self.vace is not None:
539
+ device = "cpu" if vram_limit is not None else self.device
540
+ enable_vram_management(
541
+ self.vace,
542
+ module_map={
543
+ torch.nn.Linear: AutoWrappedLinear,
544
+ torch.nn.Conv3d: AutoWrappedModule,
545
+ torch.nn.LayerNorm: AutoWrappedModule,
546
+ RMSNorm: AutoWrappedModule,
547
+ },
548
+ module_config=dict(
549
+ offload_dtype=dtype,
550
+ offload_device="cpu",
551
+ onload_dtype=dtype,
552
+ onload_device=device,
553
+ computation_dtype=self.torch_dtype,
554
+ computation_device=self.device,
555
+ ),
556
+ vram_limit=vram_limit,
557
+ )
558
+
559
    def initialize_usp(self):
        """Initialize torch.distributed and xfuser model parallelism for
        Unified Sequence Parallel: a single ulysses group spanning all
        ranks (ring degree 1), with each rank pinned to its CUDA device."""
        import torch.distributed as dist
        from xfuser.core.distributed import (init_distributed_environment,
                                             initialize_model_parallel)

        dist.init_process_group(backend="nccl", init_method="env://")
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size()
        )
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=1,
            ulysses_degree=dist.get_world_size(),
        )
        # NOTE(review): assumes one process per local GPU (rank == device
        # index) — confirm for multi-node launches.
        torch.cuda.set_device(dist.get_rank())
574
+
575
    def enable_usp(self):
        """Patch the DiT for Unified Sequence Parallel execution and record
        the sequence-parallel world size."""
        from xfuser.core.distributed import get_sequence_parallel_world_size

        from ..distributed.xdit_context_parallel import (usp_attn_forward,
                                                         usp_dit_forward)

        # Rebind every self-attention forward and the top-level DiT forward
        # to their USP-aware variants, as bound instance methods.
        for block in self.dit.blocks:
            block.self_attn.forward = types.MethodType(
                usp_attn_forward, block.self_attn
            )
        self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
        self.sp_size = get_sequence_parallel_world_size()
        self.use_unified_sequence_parallel = True
588
+
589
    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = ModelConfig(
            model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"
        ),
        local_model_path: str = "./models",
        skip_download: bool = False,
        redirect_common_files: bool = True,
        use_usp=False,
    ):
        """Build a WanVideoPipeline: download/load the configured model
        files, attach the fetched sub-models, and optionally enable
        Unified Sequence Parallel.

        redirect_common_files maps files shared between Wan repos (VAE,
        T5 encoder, CLIP) to one canonical repo so they are only
        downloaded once.
        """
        # Redirect model paths of shared files to a canonical repo.
        if redirect_common_files:
            redirect_dict = {
                "models_t5_umt5-xxl-enc-bf16.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "Wan2.1_VAE.pth": "Wan-AI/Wan2.1-T2V-1.3B",
                "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth": "Wan-AI/Wan2.1-I2V-14B-480P",
            }
            for model_config in model_configs:
                if (
                    model_config.origin_file_pattern is None
                    or model_config.model_id is None
                ):
                    continue
                if (
                    model_config.origin_file_pattern in redirect_dict
                    and model_config.model_id
                    != redirect_dict[model_config.origin_file_pattern]
                ):
                    print(
                        f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection."
                    )
                    model_config.model_id = redirect_dict[
                        model_config.origin_file_pattern
                    ]

        # Initialize pipeline (USP setup must precede model loading so
        # rank-0-only downloads can synchronize).
        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        if use_usp:
            pipe.initialize_usp()

        # Download (if necessary) and load every configured model.
        model_manager = ModelManager()
        for model_config in model_configs:
            model_config.download_if_necessary(
                local_model_path, skip_download=skip_download, use_usp=use_usp
            )
            model_manager.load_model(
                model_config.path,
                device=model_config.offload_device or device,
                torch_dtype=model_config.offload_dtype or torch_dtype,
            )

        # Attach whichever sub-models were present in the checkpoints
        # (fetch_model presumably returns None for absent ones — confirm
        # in ModelManager).  The text encoder/prompter path is disabled.
        pipe.dit = model_manager.fetch_model("wan_video_dit")

        pipe.vae = model_manager.fetch_model("wan_video_vae")
        pipe.image_encoder = model_manager.fetch_model(
            "wan_video_image_encoder")
        pipe.motion_controller = model_manager.fetch_model(
            "wan_video_motion_controller"
        )
        pipe.vace = model_manager.fetch_model("wan_video_vace")

        # Fetch the tokenizer files (even though the prompter is disabled,
        # this keeps tokenizer_config.path resolved for callers).
        tokenizer_config.download_if_necessary(
            local_model_path, skip_download=skip_download
        )

        # Unified Sequence Parallel
        if use_usp:
            pipe.enable_usp()
        return pipe
667
+
668
+ # @torch.no_grad()
669
+ @torch.inference_mode()
670
+ def __call__(
671
+ self,
672
+ # Prompt
673
+ prompt: str,
674
+ negative_prompt: Optional[str] = "",
675
+ # Image-to-video
676
+ input_image: Optional[Image.Image] = None,
677
+ # First-last-frame-to-video
678
+ end_image: Optional[Image.Image] = None,
679
+ # Video-to-video
680
+ input_video: Optional[list[Image.Image]] = None,
681
+ denoising_strength: Optional[float] = 1.0,
682
+ # ControlNet
683
+ reference_image: Optional[Image.Image] = None,
684
+ extra_images: Optional[List[Image.Image]] = None,
685
+ extra_image_frame_index: Optional[List[int]] = None,
686
+ # VACE
687
+ vace_video: Optional[list[Image.Image]] = None,
688
+ vace_video_mask: Optional[Image.Image] = None,
689
+ vace_reference_image: Optional[Image.Image] = None,
690
+ vace_scale: Optional[float] = 1.0,
691
+ # Randomness
692
+ seed: Optional[int] = None,
693
+ rand_device: Optional[str] = "cpu",
694
+ # Shape
695
+ mode: Optional[str] = "regression",
696
+ batch_size: Optional[int] = 1,
697
+ height: Optional[int] = 480,
698
+ width: Optional[int] = 720,
699
+ frame_mask: Optional[torch.Tensor] = None,
700
+ num_frames=41,
701
+ # Classifier-free guidance
702
+ cfg_scale: Optional[float] = 1,
703
+ cfg_merge: Optional[bool] = False,
704
+ # Scheduler
705
+ num_inference_steps: Optional[int] = 1,
706
+ sigma_shift: Optional[float] = 5.0,
707
+ denoise_step=1,
708
+ # Speed control
709
+ motion_bucket_id: Optional[int] = None,
710
+ # VAE tiling
711
+ tiled: Optional[bool] = False,
712
+ tile_size: Optional[tuple[int, int]] = (30, 52),
713
+ tile_stride: Optional[tuple[int, int]] = (15, 26),
714
+ # Sliding window
715
+ sliding_window_size: Optional[int] = None,
716
+ sliding_window_stride: Optional[int] = None,
717
+ # Teacache
718
+ tea_cache_l1_thresh: Optional[float] = None,
719
+ tea_cache_model_id: Optional[str] = "",
720
+ # progress_bar
721
+ progress_bar_cmd=tqdm,
722
+ ):
723
+ self.scheduler.set_timesteps(
724
+ num_inference_steps=num_inference_steps,
725
+ denoising_strength=denoising_strength,
726
+ shift=sigma_shift,
727
+ denoise_step=denoise_step,
728
+ )
729
+
730
+ # Inputs
731
+ inputs_posi = {
732
+ "prompt": prompt,
733
+ "prompt_num": batch_size,
734
+ "tea_cache_l1_thresh": tea_cache_l1_thresh,
735
+ "tea_cache_model_id": tea_cache_model_id,
736
+ "num_inference_steps": num_inference_steps,
737
+ }
738
+ inputs_nega = {
739
+ "negative_prompt": negative_prompt,
740
+ "prompt_num": batch_size,
741
+ "tea_cache_l1_thresh": tea_cache_l1_thresh,
742
+ "tea_cache_model_id": tea_cache_model_id,
743
+ "num_inference_steps": num_inference_steps,
744
+ }
745
+
746
+ inputs_shared = {
747
+ "batch_size": batch_size,
748
+ "input_image": input_image,
749
+ "end_image": end_image,
750
+ "input_video": input_video,
751
+ "denoising_strength": denoising_strength,
752
+ "reference_image": reference_image,
753
+ "vace_video": vace_video,
754
+ "vace_video_mask": vace_video_mask,
755
+ "vace_reference_image": vace_reference_image,
756
+ "vace_scale": vace_scale,
757
+ "seed": seed,
758
+ "rand_device": rand_device,
759
+ 'mode': mode,
760
+ "height": height,
761
+ "width": width,
762
+ "frame_mask": frame_mask,
763
+ "num_frames": num_frames,
764
+ "cfg_scale": cfg_scale,
765
+ "cfg_merge": cfg_merge,
766
+ "sigma_shift": sigma_shift,
767
+ "motion_bucket_id": motion_bucket_id,
768
+ "tiled": tiled,
769
+ "tile_size": tile_size,
770
+ "tile_stride": tile_stride,
771
+ "sliding_window_size": sliding_window_size,
772
+ "sliding_window_stride": sliding_window_stride,
773
+ "extra_images": extra_images,
774
+ "extra_image_frame_index": extra_image_frame_index,
775
+ }
776
+ for unit in self.units:
777
+ inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
778
+ unit, self, inputs_shared, inputs_posi, inputs_nega
779
+ )
780
+
781
+ models = {name: getattr(self, name)
782
+ for name in self.in_iteration_models}
783
+
784
+ for timestep in self.scheduler.timesteps:
785
+ timestep = timestep.unsqueeze(0).to(
786
+ dtype=self.torch_dtype, device=self.device
787
+ )
788
+ # torch.cuda.synchronize()
789
+ # start_time = time.time()
790
+ noise_pred_posi = self.model_fn(
791
+ **models, **inputs_shared, **inputs_posi, timestep=timestep
792
+ )
793
+ # torch.cuda.synchronize()
794
+ # end_time = time.time()
795
+ # print(f"Model forward time: {end_time - start_time}")
796
+ noise_pred = noise_pred_posi
797
+
798
+ inputs_shared["latents"] = self.scheduler.step(
799
+ model_output=noise_pred,
800
+ sample=inputs_shared["latents"],
801
+ )
802
+
803
+ rgb, depth = None, None
804
+ if isinstance(inputs_shared['latents'], tuple):
805
+ rgb, depth = inputs_shared['latents']
806
+ else:
807
+ depth = inputs_shared['latents']
808
+
809
+ # VACE (TODO: remove it)
810
+ if vace_reference_image is not None:
811
+ inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
812
+
813
+ # torch.cuda.synchronize()
814
+ # start_time = time.time()
815
+ depth_video = self.vae.decode(
816
+ depth,
817
+ device=self.device,
818
+ tiled=tiled,
819
+ tile_size=tile_size,
820
+ tile_stride=tile_stride,
821
+ )
822
+ # torch.cuda.synchronize()
823
+ # end_time = time.time()
824
+ # print(f"VAE decoding time: {end_time - start_time}")
825
+ depth_video = self.vae_output_to_video(depth_video)
826
+ rgb_video = None
827
+ if rgb is not None:
828
+ rgb_video = self.vae.decode(
829
+ depth,
830
+ device=self.device,
831
+ tiled=tiled,
832
+ tile_size=tile_size,
833
+ tile_stride=tile_stride,
834
+ )
835
+ rgb_video = self.vae_output_to_video(rgb_video)
836
+
837
+ return {
838
+ 'depth': depth_video,
839
+ 'rgb': rgb_video
840
+ }
841
+
842
+
843
class PipelineUnit:
    """One stage of the pipeline.

    A unit declares which entries of the input dictionaries it consumes
    and which models it needs on-device, and returns a dict of new or
    updated entries.
    """

    def __init__(
        self,
        seperate_cfg: bool = False,
        take_over: bool = False,
        input_params: tuple[str] = None,
        input_params_posi: dict[str, str] = None,
        input_params_nega: dict[str, str] = None,
        onload_model_names: tuple[str] = None,
    ):
        # seperate_cfg: run `process` once per conditioning branch
        # (positive / negative) instead of once on the shared inputs.
        self.seperate_cfg = seperate_cfg
        # take_over: the unit rewrites all three input dicts itself.
        self.take_over = take_over
        # Names read from the shared inputs.
        self.input_params = input_params
        # Per-branch {process-kwarg: branch-dict-key} mappings.
        self.input_params_posi = input_params_posi
        self.input_params_nega = input_params_nega
        # Models that must be resident before `process` runs.
        self.onload_model_names = onload_model_names

    def process(
        self, pipe: WanVideoPipeline, inputs: dict, positive=True, **kwargs
    ) -> dict:
        """Transform the inputs; implemented by each concrete unit."""
        raise NotImplementedError("`process` is not implemented.")
864
+
865
+
866
class PipelineUnitRunner:
    """Dispatches a PipelineUnit over the shared/positive/negative inputs."""

    def __init__(self):
        pass

    def __call__(
        self,
        unit: PipelineUnit,
        pipe: WanVideoPipeline,
        inputs_shared: dict,
        inputs_posi: dict,
        inputs_nega: dict,
    ) -> tuple[dict, dict]:
        if unit.take_over:
            # The unit manipulates all three dictionaries itself.
            return unit.process(
                pipe,
                inputs_shared=inputs_shared,
                inputs_posi=inputs_posi,
                inputs_nega=inputs_nega,
            )
        if unit.seperate_cfg:
            # Positive branch.
            kwargs = {
                dst: inputs_posi.get(src)
                for dst, src in unit.input_params_posi.items()
            }
            if unit.input_params is not None:
                for name in unit.input_params:
                    kwargs[name] = inputs_shared.get(name)
            outputs = unit.process(pipe, **kwargs)
            inputs_posi.update(outputs)
            # Negative branch: recompute only when CFG is active; otherwise
            # reuse the positive branch's outputs verbatim.
            if inputs_shared["cfg_scale"] != 1:
                kwargs = {
                    dst: inputs_nega.get(src)
                    for dst, src in unit.input_params_nega.items()
                }
                if unit.input_params is not None:
                    for name in unit.input_params:
                        kwargs[name] = inputs_shared.get(name)
                inputs_nega.update(unit.process(pipe, **kwargs))
            else:
                inputs_nega.update(outputs)
        else:
            # Plain unit: feed it shared inputs, merge its outputs back.
            kwargs = {name: inputs_shared.get(name) for name in unit.input_params}
            inputs_shared.update(unit.process(pipe, **kwargs))
        return inputs_shared, inputs_posi, inputs_nega
917
+
918
+
919
class WanVideoUnit_ShapeChecker(PipelineUnit):
    """Snaps (height, width, num_frames) to sizes the model supports."""

    def __init__(self):
        super().__init__(input_params=("height", "width", "num_frames"))

    def process(self, pipe: WanVideoPipeline, height, width, num_frames):
        # The pipeline owns the actual constraint logic.
        checked = pipe.check_resize_height_width(height, width, num_frames)
        height, width, num_frames = checked
        return {"height": height, "width": width, "num_frames": num_frames}
934
+
935
+
936
class WanVideoUnit_NoiseInitializer(PipelineUnit):
    """Draws the initial Gaussian latent noise.

    The latent grid is 8x smaller spatially; every 4 video frames pack into
    one latent frame, plus one extra slot when a VACE reference image is
    prepended.
    """

    def __init__(self):
        super().__init__(
            input_params=(
                "batch_size",
                "height",
                "width",
                "num_frames",
                "seed",
                "rand_device",
                "vace_reference_image",
            )
        )

    def process(
        self,
        pipe: WanVideoPipeline,
        batch_size,
        height,
        width,
        num_frames,
        seed,
        rand_device,
        vace_reference_image,
    ):
        # 4x temporal compression in latent space.
        latent_frames = (num_frames - 1) // 4 + 1
        if vace_reference_image is not None:
            latent_frames += 1
        shape = (batch_size, 16, latent_frames, height // 8, width // 8)
        noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
        # The noise doubles as the initial latent state.
        return {"noise": noise, "latents": noise}
974
+
975
+
976
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):  # For training only
    """Initializes the working latents from noise and/or VAE-encoded videos.

    Behavior depends on `mode` and on `pipe.scheduler.training`:
      - inference + 'generation': latents = pure noise.
      - inference + 'regression': latents = VAE-encoded input RGB video.
      - training + 'generation': rgb_latents = noise,
        depth_latents = encoded disparity video.
      - training + 'regression': rgb_latents = encoded RGB video,
        depth_latents = encoded disparity video.
    """

    def __init__(self):
        super().__init__(
            input_params=(
                'mode',
                'seed',
                'rand_device',
                "batch_size",
                "height",
                "width",
                "num_frames",
                "input_video",
                "input_disp",
                "noise",
                "tiled",
                "tile_size",
                "tile_stride",
                "vace_reference_image",
            ),
            onload_model_names=("vae",),
        )

    def process(
        self,
        pipe,
        mode,
        seed,
        rand_device,
        batch_size,
        height,
        width,
        num_frames,
        input_video,
        input_disp,
        noise,
        tiled,
        tile_size,
        tile_stride,
        vace_reference_image,
    ):
        assert mode in ['generation',
                        'regression'], f"mode {mode} is not supported"
        # 4x temporal compression in latent space.
        length = (num_frames - 1) // 4 + 1
        # Inference path: no disparity ground truth available.
        if not pipe.scheduler.training:
            if mode == 'generation':
                # Start from pure noise.
                noise = pipe.generate_noise(
                    (batch_size, 16, length, height // 8, width // 8),
                    seed=seed,
                    rand_device=rand_device,
                )
                return {'latents': noise}
            else:
                # Regression: encode each input video and batch the latents.
                video_list = []
                for _input_video in input_video:
                    _preprocessed_video = pipe.preprocess_video(_input_video)
                    video_list.append(_preprocessed_video)
                videos_tensor = torch.cat(video_list, dim=0)
                input_rgb_latents = pipe.vae.encode(
                    videos_tensor,
                    device=pipe.device,
                    tiled=tiled,
                    tile_size=tile_size,
                    tile_stride=tile_stride,
                ).to(dtype=pipe.torch_dtype, device=pipe.device)
                return {"latents": input_rgb_latents}

        # Training path: always encode the disparity (depth) ground truth.
        # NOTE(review): `input_disp` is not among the __call__ inputs in this
        # file — presumably supplied by the training loop; verify upstream.
        disp_list = []
        for _input_disp in input_disp:
            _preprocessed_disp = pipe.preprocess_video(_input_disp)
            disp_list.append(_preprocessed_disp)
        disp_tensor = torch.cat(disp_list, dim=0)
        input_disp_latents = pipe.vae.encode(
            disp_tensor,
            device=pipe.device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        ).to(dtype=pipe.torch_dtype, device=pipe.device)

        # Training
        if mode == 'generation':
            # Noise for RGB; encoded ground truth for depth.
            noise = pipe.generate_noise(
                (batch_size, 16, length, height // 8, width // 8),
                seed=seed,
                rand_device=rand_device,
            )
            return {'rgb_latents': noise, 'depth_latents': input_disp_latents}
        else:
            # Encoded RGB plus encoded depth.
            video_list = []
            for _input_video in input_video:
                _preprocessed_video = pipe.preprocess_video(_input_video)
                video_list.append(_preprocessed_video)
            videos_tensor = torch.cat(video_list, dim=0)
            input_rgb_latents = pipe.vae.encode(
                videos_tensor,
                device=pipe.device,
                tiled=tiled,
                tile_size=tile_size,
                tile_stride=tile_stride,
            ).to(dtype=pipe.torch_dtype, device=pipe.device)
            return {
                "rgb_latents": input_rgb_latents,
                "depth_latents": input_disp_latents,
            }
1087
+
1088
+
1089
class WanVideoUnit_PromptEmbedder(PipelineUnit):
    """Supplies the text context for the DiT.

    The text encoder is disabled in this pipeline, so instead of encoding
    the prompt this unit returns an all-zero embedding with the layout the
    DiT expects: (prompt_num, 512, 4096).
    """

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={
                "prompt": "prompt",
                "positive": "positive",
                "prompt_num": "prompt_num",
            },
            input_params_nega={
                "prompt": "negative_prompt",
                "positive": "positive",
                "prompt_num": "prompt_num",
            },
            onload_model_names=("text_encoder",),
        )

    def process(self, pipe: WanVideoPipeline, prompt, positive, prompt_num) -> dict:
        # Text conditioning is bypassed: `prompt` and `positive` are ignored
        # and a zero tensor stands in for the real embedding.
        context = torch.zeros([prompt_num, 512, 4096])
        context = context.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"context": context}
1127
+
1128
+
1129
class WanVideoUnit_ImageEmbedder(PipelineUnit):
    """Builds the image-conditioning inputs for I2V-style DiT models.

    Produces:
      - "clip_feature": CLIP embedding of the conditioning image;
      - "y": VAE latent of a video whose known frames hold real images and
        whose unknown frames are zeros, concatenated with a 4-channel
        known-frame mask at latent resolution.
    """

    def __init__(self):
        super().__init__(
            input_params=(
                "input_image",
                "end_image",
                "num_frames",
                "height",
                "width",
                "tiled",
                "tile_size",
                "tile_stride",
                "extra_images",
                "extra_image_frame_index",
            ),
            onload_model_names=("image_encoder", "vae"),
        )

    def process(
        self,
        pipe: WanVideoPipeline,
        input_image,
        end_image,
        num_frames,
        height,
        width,
        tiled,
        tile_size,
        tile_stride,
        extra_images,
        extra_image_frame_index,
    ):
        # NOTE(review): `end_image` is accepted but never used in this body.
        # Only DiT variants with image input need this conditioning.
        if not pipe.dit.has_image_input:
            return {}
        if input_image is None:
            return {}
        image = pipe.preprocess_image(input_image).to(pipe.device)  # B C H W
        batch_size = image.shape[0]
        # CLIP embedding of the conditioning frame(s).
        clip_context = pipe.image_encoder.encode_image([image])
        # Known-frame mask at latent resolution (1 = frame is known).
        msk = torch.ones(
            batch_size, num_frames, height // 8, width // 8, device=pipe.device
        )

        # First frame = conditioning image; the remaining frames start as zeros.
        vae_input = torch.concat(
            [
                image.unsqueeze(2),  # B C 1 H W
                torch.zeros(batch_size, 3, num_frames - 1, height, width).to(
                    image.device
                ),
            ],
            dim=2,
        )  # B C F H W

        vae_input = vae_input.permute(0, 2, 1, 3, 4).contiguous()  # B F C H W

        if (
            extra_images is not None
            and extra_image_frame_index is not None
        ):
            # Splice additional known frames into the conditioning video...
            for _videoid, _video in enumerate(extra_images):
                # _video F C H W
                for idx, image in enumerate(_video):
                    if idx == 0:
                        # Frame 0 already holds input_image.
                        continue
                    image = pipe.preprocess_image(
                        image).to(pipe.device)  # 1 C H W
                    vae_input[_videoid, idx] = image.squeeze(0)

            # ...and restrict the known-frame mask to those frame indices.
            mask = extra_image_frame_index[:, :, None, None].to(
                pipe.device)  # B F 1 1
            mask = mask.expand(
                batch_size, mask.shape[1], height // 8, width // 8
            )  # B F H W

            msk = msk * mask
        else:
            # Only the first frame is known.
            msk[:, 1:] = 0

        # Repeat the first-frame mask 4x so the mask matches the VAE's
        # 4-frames-per-latent temporal packing, then fold groups of 4 into
        # a channel dimension.
        msk = torch.concat(
            [torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1
        )
        msk = msk.view(
            batch_size, msk.shape[1] // 4, 4, height // 8, width // 8
        )  # B F C(4) H W
        msk = msk.transpose(1, 2)
        vae_input = vae_input.permute(0, 2, 1, 3, 4).contiguous()  # B C F H W
        y = pipe.vae.encode(
            vae_input.to(dtype=pipe.torch_dtype, device=pipe.device),
            device=pipe.device,
            tiled=tiled,
            tile_size=tile_size,
            tile_stride=tile_stride,
        )
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        # Conditioning latent = 4 mask channels + VAE latent channels.
        y = torch.concat([msk, y], dim=1)  # B 16+4 F H W
        clip_context = clip_context.to(
            dtype=pipe.torch_dtype, device=pipe.device)
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"clip_feature": clip_context, "y": y}
1240
+
1241
+
1242
class WanVideoUnit_VACE(PipelineUnit):
    """Prepares the VACE conditioning context.

    Encodes the control video split into "inactive" (kept) and "reactive"
    (regenerated) parts, downsamples the pixel mask to latent resolution,
    optionally prepends an encoded reference image along time, and packs
    everything into a single "vace_context" tensor.
    """

    def __init__(self):
        super().__init__(
            input_params=(
                "vace_video",
                "vace_video_mask",
                "vace_reference_image",
                "vace_scale",
                "height",
                "width",
                "num_frames",
                "tiled",
                "tile_size",
                "tile_stride",
            ),
            onload_model_names=("vae",),
        )

    def process(
        self,
        pipe: WanVideoPipeline,
        vace_video,
        vace_video_mask,
        vace_reference_image,
        vace_scale,
        height,
        width,
        num_frames,
        tiled,
        tile_size,
        tile_stride,
    ):
        if (
            vace_video is not None
            or vace_video_mask is not None
            or vace_reference_image is not None
        ):
            # Default control video: all zeros (no guidance content).
            if vace_video is None:
                vace_video = torch.zeros(
                    (1, 3, num_frames, height, width),
                    dtype=pipe.torch_dtype,
                    device=pipe.device,
                )
            else:
                vace_video = pipe.preprocess_video(vace_video)

            # Default mask: everything is reactive (to be generated).
            if vace_video_mask is None:
                vace_video_mask = torch.ones_like(vace_video)
            else:
                vace_video_mask = pipe.preprocess_video(
                    vace_video_mask, min_value=0, max_value=1
                )

            # Split into kept vs. regenerated content and encode each half.
            inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
            reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
            inactive = pipe.vae.encode(
                inactive,
                device=pipe.device,
                tiled=tiled,
                tile_size=tile_size,
                tile_stride=tile_stride,
            ).to(dtype=pipe.torch_dtype, device=pipe.device)
            reactive = pipe.vae.encode(
                reactive,
                device=pipe.device,
                tiled=tiled,
                tile_size=tile_size,
                tile_stride=tile_stride,
            ).to(dtype=pipe.torch_dtype, device=pipe.device)
            vace_video_latents = torch.concat((inactive, reactive), dim=1)

            # Downsample the pixel mask to latent resolution: fold 8x8
            # spatial blocks into channels, then 4x temporal pooling.
            vace_mask_latents = rearrange(
                vace_video_mask[0, 0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8
            )
            vace_mask_latents = torch.nn.functional.interpolate(
                vace_mask_latents,
                size=(
                    (vace_mask_latents.shape[2] + 3) // 4,
                    vace_mask_latents.shape[3],
                    vace_mask_latents.shape[4],
                ),
                mode="nearest-exact",
            )

            if vace_reference_image is None:
                pass
            else:
                # Encode the reference image and prepend it along time with a
                # zero "reactive" half and a zero mask slot.
                vace_reference_image = pipe.preprocess_video(
                    [vace_reference_image])
                vace_reference_latents = pipe.vae.encode(
                    vace_reference_image,
                    device=pipe.device,
                    tiled=tiled,
                    tile_size=tile_size,
                    tile_stride=tile_stride,
                ).to(dtype=pipe.torch_dtype, device=pipe.device)
                vace_reference_latents = torch.concat(
                    (vace_reference_latents, torch.zeros_like(
                        vace_reference_latents)),
                    dim=1,
                )
                vace_video_latents = torch.concat(
                    (vace_reference_latents, vace_video_latents), dim=2
                )
                vace_mask_latents = torch.concat(
                    (torch.zeros_like(
                        vace_mask_latents[:, :, :1]), vace_mask_latents),
                    dim=2,
                )

            # Final context = encoded video halves + latent mask channels.
            vace_context = torch.concat(
                (vace_video_latents, vace_mask_latents), dim=1)
            return {"vace_context": vace_context, "vace_scale": vace_scale}
        else:
            # No VACE inputs: downstream code skips VACE entirely.
            return {"vace_context": None, "vace_scale": vace_scale}
1359
+
1360
+
1361
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
    """Propagates the pipeline's USP flag into the model inputs."""

    def __init__(self):
        super().__init__(input_params=())

    def process(self, pipe: WanVideoPipeline):
        # Only forward the flag when the pipeline defines it and it is True.
        if getattr(pipe, "use_unified_sequence_parallel", False):
            return {"use_unified_sequence_parallel": True}
        return {}
1370
+
1371
+
1372
class WanVideoUnit_CfgMerger(PipelineUnit):
    """Stacks positive and negative conditioning into one batched pass."""

    def __init__(self):
        super().__init__(take_over=True)
        # Tensors that must be concatenated along the batch dimension.
        self.concat_tensor_names = ["context",
                                    "clip_feature", "y", "reference_latents"]

    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        if not inputs_shared["cfg_merge"]:
            return inputs_shared, inputs_posi, inputs_nega
        for key in self.concat_tensor_names:
            posi = inputs_posi.get(key)
            nega = inputs_nega.get(key)
            shared = inputs_shared.get(key)
            if posi is not None and nega is not None:
                # Branch-specific tensors: positive first, then negative.
                inputs_shared[key] = torch.concat((posi, nega), dim=0)
            elif shared is not None:
                # Unconditioned tensors are simply duplicated.
                inputs_shared[key] = torch.concat((shared, shared), dim=0)
        # Branch dicts are consumed; everything now lives in inputs_shared.
        inputs_posi.clear()
        inputs_nega.clear()
        return inputs_shared, inputs_posi, inputs_nega
1396
+
1397
+
1398
class TeaCache:
    """Residual cache (TeaCache) that skips DiT forwards when consecutive
    timestep embeddings change little.

    `check` decides per step whether the transformer must run; when it is
    skipped, `update` re-applies the residual stored by `store` after the
    last full forward.
    """

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        self.num_inference_steps = num_inference_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        self.previous_hidden_states = None

        # Per-model polynomial coefficients (highest degree first) used to
        # rescale the raw relative-L1 distance of the timestep modulation.
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [
                -5.21862437e04,
                9.23041404e03,
                -5.28275948e02,
                1.36987616e01,
                -4.99875664e-02,
            ],
            "Wan2.1-T2V-14B": [
                -3.03318725e05,
                4.90537029e04,
                -2.65530556e03,
                5.87365115e01,
                -3.15583525e-01,
            ],
            "Wan2.1-I2V-14B-480P": [
                2.57151496e05,
                -3.54229917e04,
                1.40286849e03,
                -1.35890334e01,
                1.32517977e-01,
            ],
            "Wan2.1-I2V-14B-720P": [
                8.10705460e03,
                2.13393892e03,
                -3.72934672e02,
                1.66203073e01,
                -4.17769401e-02,
            ],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(
                f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids})."
            )
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit: WanModel, x, t_mod):
        """Return True when the cached residual can replace the forward pass."""
        modulated_inp = t_mod.clone()
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            # Always compute on the first and last steps.
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            rescale_func = np.poly1d(self.coefficients)
            rel_change = (
                (modulated_inp - self.previous_modulated_input).abs().mean()
                / self.previous_modulated_input.abs().mean()
            )
            self.accumulated_rel_l1_distance += rescale_func(
                rel_change.cpu().item())
            should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
            if should_calc:
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            # Reset for the next sampling run.
            self.step = 0
        if should_calc:
            # Remember the pre-forward state so `store` can form a residual.
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual produced by a full forward pass."""
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Re-apply the stored residual instead of running the transformer."""
        return hidden_states + self.previous_residual
1482
+
1483
+
1484
class TemporalTiler_BCTHW:
    """Runs a model over long (B, C, T, H, W) tensors in overlapping
    temporal windows and linearly blends the overlapping regions."""

    def __init__(self):
        pass

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        """1-D blending weights: linear ramps at non-boundary edges."""
        weights = torch.ones((length,))
        ramp = (torch.arange(border_width) + 1) / border_width
        if not left_bound:
            weights[:border_width] = ramp
        if not right_bound:
            weights[-border_width:] = torch.flip(ramp, dims=(0,))
        return weights

    def build_mask(self, data, is_bound, border_width):
        """Broadcastable (1, 1, T, 1, 1) temporal blending mask for `data`."""
        length = data.shape[2]
        t = self.build_1d_mask(length, is_bound[0], is_bound[1], border_width[0])
        # Equivalent to einops repeat "T -> 1 1 T 1 1" (pure reshape).
        return t.reshape(1, 1, length, 1, 1)

    def run(
        self,
        model_fn,
        sliding_window_size,
        sliding_window_stride,
        computation_device,
        computation_dtype,
        model_kwargs,
        tensor_names,
        batch_size=None,
    ):
        """Apply `model_fn` window-by-window along T and blend the results.

        `model_kwargs[name]` for each name in `tensor_names` is temporarily
        replaced by its temporal slice, then restored before returning.
        """
        tensor_names = [
            name for name in tensor_names if model_kwargs.get(name) is not None
        ]
        originals = {name: model_kwargs[name] for name in tensor_names}
        reference = originals[tensor_names[0]]
        B, C, T, H, W = reference.shape
        if batch_size is not None:
            # The model may expand the batch (e.g. merged CFG).
            B *= batch_size
        data_device, data_dtype = reference.device, reference.dtype
        value = torch.zeros(
            (B, C, T, H, W), device=data_device, dtype=data_dtype)
        weight = torch.zeros(
            (1, 1, T, 1, 1), device=data_device, dtype=data_dtype)
        for start in range(0, T, sliding_window_stride):
            prev_start = start - sliding_window_stride
            # Skip windows entirely covered by the previous one.
            if prev_start >= 0 and prev_start + sliding_window_size >= T:
                continue
            stop = min(start + sliding_window_size, T)
            for name in tensor_names:
                model_kwargs[name] = originals[name][:, :, start:stop, :].to(
                    device=computation_device, dtype=computation_dtype
                )
            chunk = model_fn(**model_kwargs).to(
                device=data_device, dtype=data_dtype
            )
            mask = self.build_mask(
                chunk,
                is_bound=(start == 0, stop == T),
                border_width=(sliding_window_size - sliding_window_stride,),
            ).to(device=data_device, dtype=data_dtype)
            value[:, :, start:stop, :, :] += chunk * mask
            weight[:, :, start:stop, :, :] += mask
        # Normalize by the accumulated blend weights.
        value /= weight
        model_kwargs.update(originals)
        return value
1562
+
1563
+
1564
def model_fn_wan_video(
    dit: WanModel,
    motion_controller: WanMotionControllerModel = None,
    vace: VaceWanModel = None,
    latents: torch.Tensor = None,
    timestep: torch.Tensor = None,
    context: torch.Tensor = None,
    clip_feature: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    reference_latents=None,
    vace_context=None,
    vace_scale=1.0,
    tea_cache: TeaCache = None,
    use_unified_sequence_parallel: bool = False,
    motion_bucket_id: Optional[torch.Tensor] = None,
    sliding_window_size: Optional[int] = None,
    sliding_window_stride: Optional[int] = None,
    cfg_merge: bool = False,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    **kwargs,
):
    """Single denoising forward pass of the Wan DiT.

    Embeds the timestep and text context, optionally appends image
    conditioning, patchifies the latents, runs the transformer blocks
    (with optional TeaCache skipping, VACE hints, gradient checkpointing
    and unified sequence parallelism), and unpatchifies the prediction.
    """
    # Sliding-window path: recurse over temporal tiles and blend results.
    if sliding_window_size is not None and sliding_window_stride is not None:
        model_kwargs = dict(
            dit=dit,
            motion_controller=motion_controller,
            vace=vace,
            latents=latents,
            timestep=timestep,
            context=context,
            clip_feature=clip_feature,
            y=y,
            reference_latents=reference_latents,
            vace_context=vace_context,
            vace_scale=vace_scale,
            tea_cache=tea_cache,
            use_unified_sequence_parallel=use_unified_sequence_parallel,
            motion_bucket_id=motion_bucket_id,
        )
        return TemporalTiler_BCTHW().run(
            model_fn_wan_video,
            sliding_window_size,
            sliding_window_stride,
            latents.device,
            latents.dtype,
            model_kwargs=model_kwargs,
            tensor_names=["latents", "y"],
            batch_size=2 if cfg_merge else 1,
        )

    if use_unified_sequence_parallel:
        # Imported lazily: xfuser is only needed in distributed runs.
        import torch.distributed as dist
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)

    # Timestep embedding and its 6-way modulation projection.
    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
    if motion_bucket_id is not None and motion_controller is not None:
        # Motion speed control shifts the modulation signal.
        t_mod = t_mod + \
            motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
    context = dit.text_embedding(context)

    # Merged cfg: replicate latents/timestep to match the context batch.
    if latents.shape[0] != context.shape[0]:
        latents = torch.concat([latents] * context.shape[0], dim=0)

    if timestep.shape[0] != context.shape[0]:
        timestep = torch.concat([timestep] * context.shape[0], dim=0)
    if dit.has_image_input:
        # Image conditioning: stack the masked VAE latent on the channel
        # axis and prepend the CLIP embedding to the text context.
        latents = torch.cat([latents, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embdding = dit.img_emb(clip_feature)
        context = torch.cat([clip_embdding, context], dim=1)

    latents, (f, h, w) = dit.patchify(latents, None)
    # NOTE(review): `_shortcut` is captured but never used below.
    _shortcut = latents
    # 3D RoPE frequencies for the (f, h, w) token grid.
    freqs = (
        torch.cat(
            [
                dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        )
        .reshape(f * h * w, 1, -1)
        .to(latents.device)
    )

    if tea_cache is not None:
        # True => reuse the cached residual and skip the transformer.
        tea_cache_update = tea_cache.check(dit, latents, t_mod)
    else:
        tea_cache_update = False

    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            # Each rank processes its own slice of the token sequence.
            latents = torch.chunk(latents, get_sequence_parallel_world_size(), dim=1)[
                get_sequence_parallel_rank()
            ]

    if tea_cache_update:
        latents = tea_cache.update(latents)
    else:
        def create_custom_forward(module):
            # Wrapper required by torch.utils.checkpoint.
            def custom_forward(*inputs, **kwargs):
                return module(*inputs, **kwargs)
            return custom_forward

        for idx, block in enumerate(dit.blocks):
            if use_gradient_checkpointing_offload:
                # Checkpoint and additionally park activations on CPU.
                with torch.autograd.graph.save_on_cpu():
                    latents = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        latents,
                        context,
                        t_mod,
                        freqs,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                latents = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    latents,
                    context,
                    t_mod,
                    freqs,
                    use_reentrant=False,
                )
            else:
                latents = block(latents, context, t_mod, freqs)

            if vace_context is not None and idx in vace.vace_layers_mapping:
                # NOTE(review): `vace_hints` is never defined in this
                # function — upstream implementations compute it from the
                # VACE model before this loop; as written, this branch will
                # raise NameError whenever vace_context is not None. Confirm.
                current_vace_hint = vace_hints[vace.vace_layers_mapping[idx]]
                if (
                    use_unified_sequence_parallel
                    and dist.is_initialized()
                    and dist.get_world_size() > 1
                ):
                    current_vace_hint = torch.chunk(
                        current_vace_hint, get_sequence_parallel_world_size(), dim=1
                    )[get_sequence_parallel_rank()]
                latents = latents + current_vace_hint * vace_scale
        if tea_cache is not None:
            # Record the residual of this full forward for future skips.
            tea_cache.store(latents)

    latents = dit.head(latents, t)

    if use_unified_sequence_parallel:
        if dist.is_initialized() and dist.get_world_size() > 1:
            # Reassemble the full token sequence across ranks.
            latents = get_sp_group().all_gather(latents, dim=1)
    # Remove reference latents
    if reference_latents is not None:
        latents = latents[:, reference_latents.shape[1]:]
        f -= 1

    latents = dit.unpatchify(latents, (f, h, w))
    return latents
diffsynth/schedulers/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .ddim import EnhancedDDIMScheduler
2
+ from .continuous_ode import ContinuousODEScheduler
3
+ from .flow_match import FlowMatchScheduler
diffsynth/schedulers/continuous_ode.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
class ContinuousODEScheduler():
    """Continuous-sigma ODE scheduler with a Karras-style (rho) noise schedule.

    Sigmas interpolate between ``sigma_max`` and ``sigma_min`` in 1/rho-power
    space; timesteps are ``log(sigma) * 0.25`` (EDM-style conditioning input).
    """

    def __init__(self, num_inference_steps=100, sigma_max=700.0, sigma_min=0.002, rho=7.0):
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.rho = rho
        self.set_timesteps(num_inference_steps)

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, **kwargs):
        """Build the sigma/timestep schedule.

        ``denoising_strength < 1`` starts the ramp part-way (lower initial sigma).
        """
        ramp = torch.linspace(1 - denoising_strength, 1, num_inference_steps)
        min_inv_rho = torch.pow(torch.tensor((self.sigma_min,)), (1 / self.rho))
        max_inv_rho = torch.pow(torch.tensor((self.sigma_max,)), (1 / self.rho))
        self.sigmas = torch.pow(max_inv_rho + ramp * (min_inv_rho - max_inv_rho), self.rho)
        self.timesteps = torch.log(self.sigmas) * 0.25

    def step(self, model_output, timestep, sample, to_final=False):
        """Take one Euler ODE step from `timestep` toward the next scheduled sigma.

        Bug fix: the original did ``sample *= ...`` which mutated the caller's
        tensor in place; the rescaling is now done out of place.
        """
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        # Undo the 1/sqrt(sigma^2+1) input scaling (see add_noise) — out of place.
        sample = sample * (sigma * sigma + 1).sqrt()
        estimated_sample = -sigma / (sigma * sigma + 1).sqrt() * model_output + 1 / (sigma * sigma + 1) * sample
        if to_final or timestep_id + 1 >= len(self.timesteps):
            prev_sample = estimated_sample
        else:
            sigma_ = self.sigmas[timestep_id + 1]
            derivative = 1 / sigma * (sample - estimated_sample)
            prev_sample = sample + derivative * (sigma_ - sigma)
            # Re-apply the input scaling convention for the next step.
            prev_sample /= (sigma_ * sigma_ + 1).sqrt()
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        # This scheduler doesn't support this function.
        pass

    def add_noise(self, original_samples, noise, timestep):
        """Noise clean samples to the sigma of `timestep`, scaled by 1/sqrt(sigma^2+1)."""
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample = (original_samples + noise * sigma) / (sigma * sigma + 1).sqrt()
        return sample

    def training_target(self, sample, noise, timestep):
        """Regression target for training at `timestep`'s sigma."""
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        target = (-(sigma * sigma + 1).sqrt() / sigma + 1 / (sigma * sigma + 1).sqrt() / sigma) * sample + 1 / (sigma * sigma + 1).sqrt() * noise
        return target

    def training_weight(self, timestep):
        """Loss weight at `timestep` (grows as sigma shrinks)."""
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        weight = (1 + sigma * sigma).sqrt() / sigma
        return weight
diffsynth/schedulers/ddim.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, math
2
+
3
+
4
class EnhancedDDIMScheduler():
    """DDIM scheduler with optional zero-terminal-SNR rescaling.

    ``alphas_cumprod`` is stored as a plain Python list indexed by integer
    timestep. Timesteps are aligned to 999...0 (see ``set_timesteps``).
    """

    def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
        # Build the beta schedule and the cumulative alpha products used below.
        self.num_train_timesteps = num_train_timesteps
        if beta_schedule == "scaled_linear":
            betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
        elif beta_schedule == "linear":
            betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented")
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        if rescale_zero_terminal_snr:
            self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
        # Keep as a list so integer timesteps index it directly.
        self.alphas_cumprod = self.alphas_cumprod.tolist()
        self.set_timesteps(10)
        self.prediction_type = prediction_type

    def rescale_zero_terminal_snr(self, alphas_cumprod):
        """Rescale alphas_cumprod so the terminal step has exactly zero SNR."""
        alphas_bar_sqrt = alphas_cumprod.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

        # Shift so the last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T

        # Scale so the first timestep is back to the old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Square to recover alphas_bar (this returns rescaled alphas_cumprod,
        # not betas).
        alphas_bar = alphas_bar_sqrt.square()  # Revert sqrt

        return alphas_bar

    def set_timesteps(self, num_inference_steps, denoising_strength=1.0, **kwargs):
        """Choose inference timesteps, evenly spaced from max_timestep down to 0."""
        # The timesteps are aligned to 999...0, which is different from other implementations,
        # but I think this implementation is more reasonable in theory.
        max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
        num_inference_steps = min(num_inference_steps, max_timestep + 1)
        if num_inference_steps == 1:
            self.timesteps = torch.Tensor([max_timestep])
        else:
            step_length = max_timestep / (num_inference_steps - 1)
            self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])

    def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
        """One deterministic DDIM update from alpha_prod_t to alpha_prod_t_prev."""
        if self.prediction_type == "epsilon":
            weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
            weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
            prev_sample = sample * weight_x + model_output * weight_e
        elif self.prediction_type == "v_prediction":
            weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
            weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
            prev_sample = sample * weight_x + model_output * weight_e
        else:
            raise NotImplementedError(f"{self.prediction_type} is not implemented")
        return prev_sample

    def step(self, model_output, timestep, sample, to_final=False):
        """Advance `sample` one scheduled step; final step uses alpha_prod_prev = 1.

        NOTE(review): `timestep.flatten()` assumes a tensor timestep even though
        the isinstance check below suggests plain numbers were also intended —
        confirm against callers.
        """
        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        if to_final or timestep_id + 1 >= len(self.timesteps):
            alpha_prod_t_prev = 1.0
        else:
            timestep_prev = int(self.timesteps[timestep_id + 1])
            alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]

        return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)

    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Recover the epsilon that maps `sample_stablized` back to `sample` at `timestep`."""
        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
        noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
        return noise_pred

    def add_noise(self, original_samples, noise, timestep):
        """Standard DDPM forward process: sqrt(a)*x0 + sqrt(1-a)*eps."""
        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def training_target(self, sample, noise, timestep):
        """Training target: epsilon itself, or the v-prediction target otherwise."""
        if self.prediction_type == "epsilon":
            return noise
        else:
            sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
            sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
            target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
            return target

    def training_weight(self, timestep):
        """Uniform loss weighting."""
        return 1.0
diffsynth/schedulers/flow_match.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
class FlowMatchScheduler:
    """Flow-matching scheduler reduced to a single configurable timestep.

    ``set_timesteps`` produces one (timestep, sigma, weight) triple derived
    from ``denoise_step``; ``step``/``training_target`` switch between
    predicting the clean sample ('x') and the flow (noise - sample).
    """

    def __init__(
        self,
        num_inference_steps=100,
        num_train_timesteps=1000,
        shift=3.0,
        sigma_max=1.0,
        sigma_min=0.003 / 1.002,
        inverse_timesteps=False,
        extra_one_step=False,
        reverse_sigmas=False,
        training_target='x',
        training_weight_type='default'
    ):
        # Schedule hyper-parameters (several are kept only for interface
        # compatibility and are unused by the single-step schedule).
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
        self.extra_one_step = extra_one_step
        self.reverse_sigmas = reverse_sigmas
        self.training_weight_type = training_weight_type

        # Schedule state; populated by set_timesteps().
        self.target = None
        self.timesteps = None
        self.sigmas = None
        self.linear_timesteps_weights = None
        self.training = False

        self.set_training_target(training_target=training_target)
        self.set_training_weight(training_weight_type=training_weight_type)

    def set_training_target(self, training_target='x'):
        """Select what the model predicts: 'x' (clean sample) or 'flow'."""
        self.target = training_target

    def set_training_weight(self, training_weight_type):
        """Validate and store the loss-weighting mode."""
        valid_types = ["default", "equal", "early", "late"]
        assert training_weight_type in valid_types, \
            f"training_weight_type must be one of {valid_types}"
        self.training_weight_type = training_weight_type

    def set_timesteps(
        self,
        num_inference_steps=100,  # kept for signature compatibility
        denoising_strength=1.0,   # kept for signature compatibility
        training=False,
        shift=None,
        denoise_step=0.5,
        **kwargs
    ):
        """Build a one-element schedule at `denoise_step` (fraction of train steps)."""
        if shift is not None:
            self.shift = shift
        self.training = training

        # Single-value schedule: timestep = T * denoise_step, sigma = denoise_step.
        timestep_value = self.num_train_timesteps * denoise_step
        sigma_value = timestep_value / self.num_train_timesteps

        self.timesteps = torch.tensor([timestep_value], dtype=torch.float32)
        self.sigmas = torch.tensor([sigma_value], dtype=torch.float32)

        if self.training:
            # Fixed empirical loss weight — presumably tuned offline; verify.
            self.linear_timesteps_weights = torch.tensor([1.795], dtype=torch.float32)
        else:
            self.linear_timesteps_weights = None

    def step(self, model_output, sample, to_final=False, **kwargs):
        """Recover the denoised sample from the model prediction."""
        if self.target == 'x':
            return model_output
        if self.target == 'flow':
            return sample - model_output

    def training_target(self, sample, noise, timestep):
        """Supervision target matching the configured prediction mode."""
        if self.target == 'x':
            return sample
        if self.target == 'flow':
            return noise - sample

    def training_weight(self, timestep):
        """Loss weight: the single scheduled weight when training, else 1.0."""
        weights = self.linear_timesteps_weights
        return weights[0] if weights is not None else 1.0
101
+
102
+
103
+ if __name__ == "__main__":
104
+ scheduler = FlowMatchScheduler()
105
+ scheduler.set_training_weight("default")
106
+ scheduler.set_timesteps(
107
+ num_inference_steps=1,
108
+ training=True,
109
+ schedule_mode="default",
110
+ denoise_step=1,
111
+ shift=5
112
+ )
113
+
114
+ for step, sigma, weight in zip(scheduler.timesteps, scheduler.sigmas, scheduler.linear_timesteps_weights):
115
+ print(
116
+ f"Step: {step.item()}, Sigma: {sigma.item()}, Weight: {weight.item()}")
diffsynth/util/alignment.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: Bingxin Ke
2
+ # Last modified: 2024-01-11
3
+
4
+ import numpy as np
5
+ import torch
6
def align_depth_least_square_video(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
    max_resolution=None,
):
    """Fit one global affine (scale, shift) aligning `pred_arr` to `gt_arr` over all frames.

    Args:
        gt_arr, pred_arr, valid_mask_arr: shape (T, H, W) or (T, 1, H, W).
        return_scale_shift: also return the fitted scale and shift.
        max_resolution: optionally downsample frames before fitting (the
            alignment is still applied at the original resolution).
    """
    original_shape = pred_arr.shape

    # Collapse any singleton channel dim: (T, 1, H, W) -> (T, H, W).
    gt = gt_arr.squeeze()
    pred = pred_arr.squeeze()
    mask = valid_mask_arr.squeeze()

    # Optional downsampling (applied per-frame identically).
    if max_resolution is not None:
        frame_h, frame_w = gt.shape[-2:]
        factor = np.min(max_resolution / np.array([frame_h, frame_w]))
        if factor < 1:
            shrink = torch.nn.Upsample(scale_factor=float(factor), mode="nearest")
            gt = shrink(torch.as_tensor(gt).unsqueeze(1)).squeeze(1).numpy()
            pred = shrink(torch.as_tensor(pred).unsqueeze(1)).squeeze(1).numpy()
            mask = (
                shrink(torch.as_tensor(mask).unsqueeze(1).float())
                .squeeze(1).bool().numpy()
            )

    assert gt.shape == pred.shape == mask.shape, f"{gt.shape}, {pred.shape}, {mask.shape}"

    # Flatten valid pixels from ALL frames into one least-squares system.
    gt_vec = gt[mask].reshape(-1, 1)      # (N, 1)
    pred_vec = pred[mask].reshape(-1, 1)  # (N, 1)

    design = np.concatenate([pred_vec, np.ones_like(pred_vec)], axis=-1)  # (N, 2)
    solution = np.linalg.lstsq(design, gt_vec, rcond=None)[0]
    scale, shift = solution

    # Apply at the ORIGINAL resolution (not the downsampled copy).
    aligned_pred = (pred_arr * scale + shift).reshape(original_shape)

    if return_scale_shift:
        return aligned_pred, scale, shift
    return aligned_pred
64
+
65
+
66
def align_depth_least_square(
    gt_arr: np.ndarray,
    pred_arr: np.ndarray,
    valid_mask_arr: np.ndarray,
    return_scale_shift=True,
    max_resolution=None,
):
    """Fit an affine (scale, shift) aligning a single predicted depth map to GT.

    Args:
        gt_arr, pred_arr, valid_mask_arr: 2D maps (extra singleton dims are squeezed).
        return_scale_shift: also return the fitted scale and shift.
        max_resolution: optionally downsample before fitting; the transform is
            applied at the original resolution.
    """
    original_shape = pred_arr.shape

    gt = gt_arr.squeeze()          # [H, W]
    pred = pred_arr.squeeze()
    mask = valid_mask_arr.squeeze()

    # Optional downsampling before the fit.
    if max_resolution is not None:
        factor = np.min(max_resolution / np.array(original_shape[-2:]))
        if factor < 1:
            shrink = torch.nn.Upsample(scale_factor=factor, mode="nearest")
            gt = shrink(torch.as_tensor(gt).unsqueeze(0)).numpy()
            pred = shrink(torch.as_tensor(pred).unsqueeze(0)).numpy()
            mask = (
                shrink(torch.as_tensor(mask).unsqueeze(0).float())
                .bool()
                .numpy()
            )

    assert (
        gt.shape == pred.shape == mask.shape
    ), f"{gt.shape}, {pred.shape}, {mask.shape}"

    gt_vec = gt[mask].reshape((-1, 1))
    pred_vec = pred[mask].reshape((-1, 1))

    # Least squares: gt ≈ scale * pred + shift.
    design = np.concatenate([pred_vec, np.ones_like(pred_vec)], axis=-1)
    solution = np.linalg.lstsq(design, gt_vec, rcond=None)[0]
    scale, shift = solution

    aligned_pred = (pred_arr * scale + shift).reshape(original_shape)

    if return_scale_shift:
        return aligned_pred, scale, shift
    return aligned_pred
114
+
115
+
116
+ # ******************** disparity space ********************
117
def depth2disparity(depth, return_mask=False):
    """Convert depth to disparity (1/depth); non-positive depths map to 0.

    Args:
        depth: torch.Tensor or np.ndarray of depths.
        return_mask: if True, also return the boolean mask of positive depths.

    Raises:
        TypeError: for unsupported input types (the original code fell through
            to a NameError on `disparity` in that case).
    """
    if isinstance(depth, torch.Tensor):
        disparity = torch.zeros_like(depth)
    elif isinstance(depth, np.ndarray):
        disparity = np.zeros_like(depth)
    else:
        raise TypeError(f"Unsupported depth type: {type(depth)!r}")
    non_negative_mask = depth > 0
    disparity[non_negative_mask] = 1.0 / depth[non_negative_mask]
    if return_mask:
        return disparity, non_negative_mask
    else:
        return disparity
128
+
129
+
130
def disparity2depth(disparity, **kwargs):
    """Convert disparity back to depth.

    Inversion (1/x) is involutive, so this simply reuses depth2disparity;
    kwargs (e.g. ``return_mask``) are forwarded unchanged.
    """
    return depth2disparity(disparity, **kwargs)
diffsynth/util/depth_transform.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: Bingxin Ke
2
+ # Last modified: 2024-02-08
3
+
4
+ import torch
5
+
6
+
7
def get_depth_normalizer(cfg_normalizer):
    """Build a depth-normalization callable from a config object.

    Args:
        cfg_normalizer: None (identity transform) or a config with a ``type``
            field; only "near_far_metric" is implemented.

    Returns:
        A callable mapping raw depth to normalized depth.

    Raises:
        NotImplementedError: for unknown normalizer types.
    """
    if cfg_normalizer is None:
        # No config: pass depth through unchanged.
        def identical(x):
            return x
        return identical

    if cfg_normalizer.type == "near_far_metric":
        return NearFarMetricNormalizer(
            norm_min=cfg_normalizer.norm_min,
            norm_max=cfg_normalizer.norm_max,
            min_max_quantile=cfg_normalizer.min_max_quantile,
            clip=cfg_normalizer.clip,
        )

    raise NotImplementedError
25
+
26
+
27
class DepthNormalizerBase:
    """Interface for depth normalizers mapping raw depth into a model range.

    Subclasses override the class attributes and the methods below. Note that
    subclasses do NOT call super().__init__() — it raises to prevent direct
    instantiation of the base class.
    """

    # Whether normalized depth is relative (invariant to global scale/shift).
    is_relative = None
    # Whether the far plane maps to norm_max.
    far_plane_at_max = None

    def __init__(
        self,
        norm_min=-1.0,
        norm_max=1.0,
    ) -> None:
        self.norm_min = norm_min
        self.norm_max = norm_max
        raise NotImplementedError  # abstract base: must be subclassed

    def __call__(self, depth, valid_mask=None, clip=None):
        """Normalize `depth` into [norm_min, norm_max]."""
        raise NotImplementedError

    def denormalize(self, depth_norm, **kwargs):
        # For metric depth: convert prediction back to metric depth
        # For relative depth: convert prediction to [0, 1]
        raise NotImplementedError
47
+
48
+
49
class NearFarMetricNormalizer(DepthNormalizerBase):
    """Affinely map depth in [0, d_max] to [norm_min, norm_max].

    Near/far planes are taken as low/high quantiles of the valid depths, so a
    few outliers cannot dominate the normalization range.
    """

    is_relative = True
    far_plane_at_max = True

    def __init__(
        self, norm_min=-1.0, norm_max=1.0, min_max_quantile=0.02, clip=True
    ) -> None:
        # Deliberately does not call super().__init__() (the base raises).
        self.norm_min = norm_min
        self.norm_max = norm_max
        self.norm_range = norm_max - norm_min
        self.min_quantile = min_max_quantile
        self.max_quantile = 1.0 - min_max_quantile
        self.clip = clip

    def __call__(self, depth_linear, valid_mask=None, clip=None):
        """Normalize `depth_linear`; `clip` overrides the instance setting."""
        do_clip = self.clip if clip is None else clip

        if valid_mask is None:
            valid_mask = torch.ones_like(depth_linear).bool()
        # Non-positive depths are never considered valid.
        valid_mask = valid_mask & (depth_linear > 0)

        # Quantile-based near/far planes over the valid pixels.
        near, far = torch.quantile(
            depth_linear[valid_mask],
            torch.tensor([self.min_quantile, self.max_quantile]),
        )

        normalized = (depth_linear - near) / (far - near) * self.norm_range + self.norm_min

        if do_clip:
            normalized = torch.clip(normalized, self.norm_min, self.norm_max)

        return normalized

    def scale_back(self, depth_norm):
        """Map normalized depth back to [0, 1]."""
        return (depth_norm - self.norm_min) / self.norm_range

    def denormalize(self, depth_norm, **kwargs):
        return self.scale_back(depth_norm=depth_norm)
diffsynth/util/metric.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: Bingxin Ke
2
+ # Last modified: 2024-02-15
3
+
4
+
5
+ import pandas as pd
6
+ import torch
7
+
8
+
9
+ # Adapted from: https://github.com/victoresque/pytorch-template/blob/master/utils/util.py
10
# Adapted from: https://github.com/victoresque/pytorch-template/blob/master/utils/util.py
class MetricTracker:
    """Running total/count/average per metric key, backed by a pandas DataFrame.

    If a ``writer`` with an ``add_scalar`` method is supplied, every update is
    also logged to it.
    """

    def __init__(self, *keys, writer=None):
        self.writer = writer
        self._data = pd.DataFrame(
            index=keys, columns=["total", "counts", "average"])
        self.reset()

    def reset(self):
        """Zero all accumulators."""
        for column in self._data.columns:
            self._data[column].values[:] = 0

    def update(self, key, value, n=1):
        """Accumulate `value` (repeated `n` times) into `key` and refresh its average."""
        if self.writer is not None:
            self.writer.add_scalar(key, value)
        self._data.loc[key, "total"] += value * n
        self._data.loc[key, "counts"] += n
        self._data.loc[key, "average"] = (
            self._data.total[key] / self._data.counts[key]
        )

    def avg(self, key):
        """Current average for `key`."""
        return self._data.average[key]

    def result(self):
        """All current averages as a plain dict."""
        return dict(self._data.average)
34
+
35
+
36
def pixel_mean(pred, gt, valid_mask):
    """Absolute difference between the (optionally masked) means of pred and gt.

    Means are taken over dims (0, 1); `valid_mask` weights pixels when given.
    """
    if valid_mask is None:
        mean_pred = torch.mean(pred, dim=(0, 1))
        mean_gt = torch.mean(gt, dim=(0, 1))
    else:
        num_valid = torch.sum(valid_mask, dim=(0, 1))
        mean_pred = torch.sum(pred * valid_mask, dim=(0, 1)) / num_valid
        mean_gt = torch.sum(gt * valid_mask, dim=(0, 1)) / num_valid
    return torch.abs(mean_pred - mean_gt)
51
+
52
+
53
def pixel_var(pred, gt, valid_mask):
    """Absolute difference between the (optionally masked) variances of pred and gt.

    Masked variance is the population variance over valid pixels; the unmasked
    path uses torch.var over dims (0, 1).
    """
    if valid_mask is None:
        var_pred = torch.var(pred, dim=(0, 1))
        var_gt = torch.var(gt, dim=(0, 1))
    else:
        num_valid = torch.sum(valid_mask, dim=(0, 1))
        mean_pred = torch.sum(pred * valid_mask, dim=(0, 1)) / num_valid
        mean_gt = torch.sum(gt * valid_mask, dim=(0, 1)) / num_valid
        var_pred = torch.sum(
            valid_mask * (pred - mean_pred) ** 2, dim=(0, 1)) / num_valid
        var_gt = torch.sum(
            valid_mask * (gt - mean_gt) ** 2, dim=(0, 1)) / num_valid
    return torch.abs(var_pred - var_gt)
74
+
75
+
76
def abs_relative_difference(output, target, valid_mask=None):
    """Mean absolute relative error |output - target| / target.

    Averaged per image over the last two dims, then over the batch; invalid
    pixels are zeroed and excluded from the count.
    """
    rel_err = torch.abs(output - target) / target
    if valid_mask is None:
        num = output.shape[-1] * output.shape[-2]
    else:
        rel_err[~valid_mask] = 0
        num = valid_mask.sum((-1, -2))
    per_image = torch.sum(rel_err, (-1, -2)) / num
    return per_image.mean()
90
+
91
+
92
def squared_relative_difference(output, target, valid_mask=None):
    """Mean squared relative error (output - target)^2 / target.

    Averaged per image over the last two dims, then over the batch.
    """
    sq_rel = torch.pow(torch.abs(output - target), 2) / target
    if valid_mask is None:
        num = output.shape[-1] * output.shape[-2]
    else:
        sq_rel[~valid_mask] = 0
        num = valid_mask.sum((-1, -2))
    per_image = torch.sum(sq_rel, (-1, -2)) / num
    return per_image.mean()
105
+
106
+
107
def rmse_linear(output, target, valid_mask=None):
    """Per-image RMSE in linear depth space, averaged over the batch."""
    residual = output - target
    if valid_mask is None:
        num = output.shape[-1] * output.shape[-2]
    else:
        residual[~valid_mask] = 0
        num = valid_mask.sum((-1, -2))
    mse = torch.sum(residual ** 2, (-1, -2)) / num
    return torch.sqrt(mse).mean()
120
+
121
+
122
def rmse_log(output, target, valid_mask=None):
    """Per-image RMSE in log-depth space, averaged over the batch."""
    log_residual = torch.log(output) - torch.log(target)
    if valid_mask is None:
        num = output.shape[-1] * output.shape[-2]
    else:
        log_residual[~valid_mask] = 0
        num = valid_mask.sum((-1, -2))
    mse = torch.sum(log_residual ** 2, (-1, -2)) / num  # [B]
    return torch.sqrt(mse).mean()
133
+
134
+
135
def log10(output, target, valid_mask=None):
    """Mean absolute error in log10 space over (optionally masked) pixels."""
    if valid_mask is None:
        err = torch.abs(torch.log10(output) - torch.log10(target))
    else:
        err = torch.abs(
            torch.log10(output[valid_mask]) - torch.log10(target[valid_mask])
        )
    return err.mean()
143
+
144
+
145
+ # adapt from: https://github.com/imran3180/depth-map-prediction/blob/master/main.py
146
def threshold_percentage(output, target, threshold_val, valid_mask=None):
    """Fraction of pixels with max(output/target, target/output) < threshold_val.

    Adapted from https://github.com/imran3180/depth-map-prediction; the
    comparison is computed on CPU regardless of the input device.

    Bug fix: when `valid_mask` is None the pixel count is a plain int, so the
    original `n.cpu()` raised AttributeError; `.cpu()` is now applied only in
    the tensor case.
    """
    d1 = output / target
    d2 = target / output
    max_d1_d2 = torch.max(d1, d2)
    zero = torch.zeros(*output.shape)
    one = torch.ones(*output.shape)
    bit_mat = torch.where(max_d1_d2.cpu() < threshold_val, one, zero)
    if valid_mask is not None:
        bit_mat[~valid_mask] = 0
        n = valid_mask.sum((-1, -2)).cpu()
    else:
        n = output.shape[-1] * output.shape[-2]
    count_mat = torch.sum(bit_mat, (-1, -2))
    threshold_mat = count_mat / n
    return threshold_mat.mean()
161
+
162
+
163
def delta1_acc(pred, gt, valid_mask):
    """Delta-1 accuracy: fraction of pixels with max(pred/gt, gt/pred) < 1.25."""
    return threshold_percentage(pred, gt, 1.25, valid_mask)
165
+
166
+
167
def delta2_acc(pred, gt, valid_mask):
    """Delta-2 accuracy: fraction of pixels with max(pred/gt, gt/pred) < 1.25^2."""
    return threshold_percentage(pred, gt, 1.25**2, valid_mask)
169
+
170
+
171
def delta3_acc(pred, gt, valid_mask):
    """Delta-3 accuracy: fraction of pixels with max(pred/gt, gt/pred) < 1.25^3."""
    return threshold_percentage(pred, gt, 1.25**3, valid_mask)
173
+
174
+
175
def i_rmse(output, target, valid_mask=None):
    """Per-image RMSE of inverse depth (1/output vs 1/target), batch-averaged."""
    residual = 1.0 / output - 1.0 / target
    if valid_mask is None:
        num = output.shape[-1] * output.shape[-2]
    else:
        residual[~valid_mask] = 0
        num = valid_mask.sum((-1, -2))
    mse = torch.sum(residual ** 2, (-1, -2)) / num  # [B]
    return torch.sqrt(mse).mean()
188
+
189
+
190
def silog_rmse(depth_pred, depth_gt, valid_mask=None):
    """Scale-invariant log RMSE (SILog), scaled by 100.

    Subtracting the squared mean log-ratio makes the metric invariant to a
    global depth scale.
    """
    log_diff = torch.log(depth_pred) - torch.log(depth_gt)
    if valid_mask is None:
        num = depth_gt.shape[-2] * depth_gt.shape[-1]
    else:
        log_diff[~valid_mask] = 0
        num = valid_mask.sum((-1, -2))

    sq = torch.pow(log_diff, 2)
    first_term = torch.sum(sq, (-1, -2)) / num
    second_term = torch.pow(torch.sum(log_diff, (-1, -2)), 2) / (num ** 2)
    return torch.sqrt(torch.mean(first_term - second_term)) * 100
204
+
205
+
206
+
207
def relative_temporal_diff(pred, gt, valid_mask=None, eps=1e-6):
    """RMSE of the frame-to-frame *relative* change between pred and gt.

    Args:
        pred, gt: [F, H, W] depth sequences.
        valid_mask: optional [F, H, W] bool mask.
        eps: stabilizer for the per-frame division.
    """
    # Relative temporal difference of consecutive frames.
    rel_pred = (pred[1:] - pred[:-1]) / (pred[:-1] + eps)  # [F-1, H, W]
    rel_gt = (gt[1:] - gt[:-1]) / (gt[:-1] + eps)
    residual = rel_pred - rel_gt

    if valid_mask is None:
        num = residual.shape[-1] * residual.shape[-2]
    else:
        # A pair counts only when BOTH frames are valid.
        pair_mask = valid_mask[1:] & valid_mask[:-1]
        residual[~pair_mask] = 0
        num = pair_mask.sum((-1, -2))  # [F-1]

    mse = torch.sum(residual ** 2, (-1, -2)) / num
    return torch.sqrt(mse).mean()
234
+
235
+
236
def boundary_metrics(pred_depth, rgb, valid_mask=None,
                     th_depth_ratio=1.05, th_rgb_grad=0.15,
                     tolerance=1, eps=1e-6):
    """Edge-based precision/recall/F1 between predicted-depth discontinuities
    and RGB (Sobel) edges.

    Args:
        pred_depth: (B, H, W) predicted depth.
        rgb: (B, 3, H, W) image (or (B, 1, H, W) grayscale) used as edge GT.
        valid_mask: optional (B, H, W) mask of evaluable pixels.
        th_depth_ratio: neighbor depth ratio above which a depth edge is declared.
        th_rgb_grad: normalized gradient magnitude above which an RGB edge is declared.
        tolerance: dilation radius (pixels) when matching edges.

    Returns:
        dict with float entries 'f1', 'precision', 'recall'.

    Bug fix: the original unconditionally did `valid_mask.unsqueeze(1)` and
    crashed with AttributeError when `valid_mask=None`, despite the later
    `is not None` guard treating the mask as optional.
    """
    import torch
    import torch.nn.functional as F

    device = pred_depth.device

    pred_depth = pred_depth.unsqueeze(1)
    if valid_mask is not None:
        valid_mask = valid_mask.unsqueeze(1)

    # Grayscale conversion for gradient computation.
    if rgb.shape[1] == 3:
        gray = 0.299 * rgb[:, 0:1] + 0.587 * rgb[:, 1:2] + 0.114 * rgb[:, 2:3]
    else:
        gray = rgb

    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                           device=device, dtype=rgb.dtype).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                           device=device, dtype=rgb.dtype).view(1, 1, 3, 3)

    gx = F.conv2d(gray, sobel_x, padding=1)
    gy = F.conv2d(gray, sobel_y, padding=1)
    mag = torch.sqrt(gx**2 + gy**2 + eps)

    # Per-image min/max normalization of the gradient magnitude.
    B = mag.shape[0]
    mag_flat = mag.view(B, -1)
    mag_min = mag_flat.min(dim=1, keepdim=True)[0].view(B, 1, 1, 1)
    mag_max = mag_flat.max(dim=1, keepdim=True)[0].view(B, 1, 1, 1)
    mag_norm = (mag - mag_min) / (mag_max - mag_min + eps)

    edges_gt = (mag_norm > th_rgb_grad).float()

    d = pred_depth.clamp(min=eps)

    def get_edge_with_nms(ratio_map, dim):
        # Keep only local maxima (1D non-max suppression along `dim`) of the
        # depth-ratio map that exceed the ratio threshold.
        is_candidate = ratio_map > th_depth_ratio

        if dim == 3:
            k_size, pad = (1, 3), (0, 1)
        else:
            k_size, pad = (3, 1), (1, 0)

        local_max = F.max_pool2d(
            ratio_map, kernel_size=k_size, stride=1, padding=pad)
        is_peak = (ratio_map == local_max)

        return is_candidate & is_peak

    d_pad = F.pad(d, (1, 1, 1, 1), mode='replicate')  # [B, 1, H+2, W+2]
    d_center = d

    # Right: d(x+1, y) / d(x, y)
    ratio_right = d_pad[:, :, 1:-1, 2:] / d_center
    mask_right = get_edge_with_nms(ratio_right, dim=3)

    # Left: d(x-1, y) / d(x, y)
    ratio_left = d_pad[:, :, 1:-1, :-2] / d_center
    mask_left = get_edge_with_nms(ratio_left, dim=3)

    # Bottom: d(x, y+1) / d(x, y)
    ratio_bottom = d_pad[:, :, 2:, 1:-1] / d_center
    mask_bottom = get_edge_with_nms(ratio_bottom, dim=2)

    # Top: d(x, y-1) / d(x, y)
    ratio_top = d_pad[:, :, :-2, 1:-1] / d_center
    mask_top = get_edge_with_nms(ratio_top, dim=2)

    edges_pred = (mask_right | mask_left | mask_bottom | mask_top).float()

    if valid_mask is not None:
        edges_gt = edges_gt * valid_mask
        edges_pred = edges_pred * valid_mask

    # Dilate both edge maps so a match within `tolerance` pixels counts.
    if tolerance > 0:
        kernel_size = 2 * tolerance + 1
        edges_gt_dilated = F.max_pool2d(
            edges_gt, kernel_size=kernel_size, stride=1, padding=tolerance)
        edges_pred_dilated = F.max_pool2d(
            edges_pred, kernel_size=kernel_size, stride=1, padding=tolerance)
    else:
        edges_gt_dilated = edges_gt
        edges_pred_dilated = edges_pred

    # True Positives
    tp_prec = (edges_pred * edges_gt_dilated).sum()
    tp_rec = (edges_gt * edges_pred_dilated).sum()

    # Totals
    n_pred = edges_pred.sum()
    n_gt = edges_gt.sum()

    precision = tp_prec / (n_pred + eps)
    recall = tp_rec / (n_gt + eps)
    f1_score = 2 * precision * recall / (precision + recall + eps)

    return {
        "f1": f1_score.item(),
        "precision": precision.item(),
        "recall": recall.item(),
    }
diffsynth/util/normal_utils.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.distributed as dist
6
+
7
+
8
+ def get_padding(orig_H, orig_W):
9
+ """ returns how the input of shape (orig_H, orig_W) should be padded
10
+ this ensures that both H and W are divisible by 32
11
+ """
12
+ if orig_W % 32 == 0:
13
+ l = 0
14
+ r = 0
15
+ else:
16
+ new_W = 32 * ((orig_W // 32) + 1)
17
+ l = (new_W - orig_W) // 2
18
+ r = (new_W - orig_W) - l
19
+
20
+ if orig_H % 32 == 0:
21
+ t = 0
22
+ b = 0
23
+ else:
24
+ new_H = 32 * ((orig_H // 32) + 1)
25
+ t = (new_H - orig_H) // 2
26
+ b = (new_H - orig_H) - t
27
+ return l, r, t, b
28
+
29
def pad_input(img, intrins, lrtb=(0, 0, 0, 0)):
    """Pad an ImageNet-normalized RGB batch and shift the camera principal point.

    Args:
        img: (B, 3, H, W) tensor normalized with ImageNet mean/std.
        intrins: (B, 3, 3) intrinsics or None; principal point is updated
            IN PLACE when padding is applied.
        lrtb: (left, right, top, bottom) padding in pixels.
    """
    l, r, t, b = lrtb
    if l + r + t + b == 0:
        return img, intrins

    # Per-channel pad values = ImageNet-normalized black (pixel value 0).
    pad_values = (
        (0 - 0.485) / 0.229,
        (0 - 0.456) / 0.224,
        (0 - 0.406) / 0.225,
    )
    channels = [
        F.pad(img[:, c:c + 1, :, :], (l, r, t, b), mode="constant", value=v)
        for c, v in enumerate(pad_values)
    ]
    img = torch.cat(channels, dim=1)

    if intrins is not None:
        intrins[:, 0, 2] += l
        intrins[:, 1, 2] += t
    return img, intrins
50
+
51
def compute_normal_error(pred_norm, gt_norm):
    """Per-pixel angular error in degrees between two normal maps.

    NOTE: pred_norm and gt_norm should be torch tensors of shape (B, 3, ...);
    the result has a singleton channel: (B, 1, ...).
    """
    cos_sim = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
    # Clamp guards acos against tiny numerical overshoot past +-1.
    cos_sim = torch.clamp(cos_sim, min=-1.0, max=1.0)
    angle_deg = torch.acos(cos_sim) * 180.0 / np.pi
    return angle_deg.unsqueeze(1)  # (B, 1, ...)
60
+
61
def compute_normal_metrics(total_normal_errors):
    """Summarize surface-normal benchmark metrics.

    NOTE: `total_normal_errors` should be a 1D torch tensor of per-pixel
    angular errors in degrees; returns mean/median/rmse plus the percentage
    of pixels under the standard angle thresholds.
    """
    errors = total_normal_errors.detach().cpu().numpy()
    num_pixels = errors.shape[0]

    def pct_below(threshold):
        # Percentage of pixels with error strictly below `threshold` degrees.
        return 100.0 * (np.sum(errors < threshold) / num_pixels)

    return {
        'mean': np.average(errors),
        'median': np.median(errors),
        'rmse': np.sqrt(np.sum(errors * errors) / num_pixels),
        'a1': pct_below(5),
        'a2': pct_below(7.5),
        'a3': pct_below(11.25),
        'a4': pct_below(22.5),
        'a5': pct_below(30),
    }
diffsynth/util/seed_all.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ import numpy as np
22
+ import random
23
+ import torch
24
+
25
+
26
def seed_all(seed: int = 0):
    """
    Set random seeds of all components (Python, NumPy, PyTorch CPU and CUDA).
    """
    # cuda.manual_seed_all is a safe no-op when no GPU is present.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
diffsynth/vram_management/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .layers import *
2
+ from .gradient_checkpointing import *
diffsynth/vram_management/gradient_checkpointing.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def create_custom_forward(module):
    """Wrap `module` in a plain callable, as expected by torch checkpointing."""
    return lambda *args, **kwargs: module(*args, **kwargs)
8
+
9
+
10
def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    """Run `model(*args, **kwargs)`, optionally under gradient checkpointing.

    Modes (first match wins):
    - `use_gradient_checkpointing_offload`: checkpoint AND stash the saved
      activations in CPU memory (lowest VRAM, slowest).
    - `use_gradient_checkpointing`: plain activation checkpointing.
    - neither: a direct forward call.

    Returns whatever `model` returns.
    """
    def run_checkpointed():
        # Non-reentrant checkpointing supports keyword arguments and does not
        # require all inputs to have requires_grad.
        return torch.utils.checkpoint.checkpoint(
            lambda *a, **k: model(*a, **k),
            *args,
            **kwargs,
            use_reentrant=False,
        )

    if use_gradient_checkpointing_offload:
        # Checkpoint and also keep saved tensors on the CPU.
        with torch.autograd.graph.save_on_cpu():
            return run_checkpointed()
    if use_gradient_checkpointing:
        return run_checkpointed()
    return model(*args, **kwargs)
diffsynth/vram_management/layers.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, copy
2
+ from ..models.utils import init_weights_on_device
3
+
4
+
5
def cast_to(weight, dtype, device):
    """Materialize a fresh copy of `weight` with the requested dtype/device."""
    buffer = torch.empty_like(weight, dtype=dtype, device=device)
    buffer.copy_(weight)
    return buffer
9
+
10
+
11
class AutoTorchModule(torch.nn.Module):
    """Base for modules that migrate between offload/onload/compute placements.

    Subclasses are expected to set the offload/onload/computation dtype and
    device attributes plus `vram_limit` and `state`
    (0 = offloaded, 1 = onloaded, 2 = kept on the computation device).
    """

    def __init__(self):
        super().__init__()

    def check_free_vram(self):
        # True while current GPU usage (in GiB) is below the configured limit.
        free_bytes, total_bytes = torch.cuda.mem_get_info(self.computation_device)
        used_gib = (total_bytes - free_bytes) / (1024 ** 3)
        return used_gib < self.vram_limit

    def offload(self):
        if self.state == 0:
            return
        self.to(dtype=self.offload_dtype, device=self.offload_device)
        self.state = 0

    def onload(self):
        if self.state == 1:
            return
        self.to(dtype=self.onload_dtype, device=self.onload_device)
        self.state = 1

    def keep(self):
        if self.state == 2:
            return
        self.to(dtype=self.computation_dtype, device=self.computation_device)
        self.state = 2
34
+
35
+
36
class AutoWrappedModule(AutoTorchModule):
    """Wraps an arbitrary submodule and decides, per forward call, where it runs.

    The wrapped module starts on the offload placement. At call time it is
    used in place when already usable, pinned to the computation device when
    VRAM headroom allows, or run via a throwaway copy otherwise.
    """

    def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
        super().__init__()
        self.module = module.to(dtype=offload_dtype, device=offload_device)
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.vram_limit = vram_limit
        self.state = 0  # start in the offloaded state

    def forward(self, *args, **kwargs):
        target = self.module
        if self.state != 2:
            usable_in_place = (
                self.onload_dtype == self.computation_dtype
                and self.onload_device == self.computation_device
            )
            if not usable_in_place:
                if self.vram_limit is not None and self.check_free_vram():
                    # Enough headroom: pin the module on the compute device.
                    self.keep()
                else:
                    # Tight on memory: run a throwaway copy for this call only.
                    target = copy.deepcopy(self.module).to(
                        dtype=self.computation_dtype, device=self.computation_device
                    )
        return target(*args, **kwargs)
61
+
62
+
63
class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule):
    # A LayerNorm that participates in VRAM management: it aliases the source
    # module's parameters and inherits AutoTorchModule's offload/onload/keep
    # state machine (state: 0 = offloaded, 1 = onloaded, 2 = kept on the
    # computation device).
    def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs):
        """Mirror `module`'s LayerNorm configuration without allocating storage.

        The meta-device context makes super().__init__'s own parameter
        allocation free; the real weight/bias tensors are then shared
        (aliased) from `module`.
        """
        with init_weights_on_device(device=torch.device("meta")):
            super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
        # Alias, not copy: moving this wrapper moves the source parameters too.
        self.weight = module.weight
        self.bias = module.bias
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.vram_limit = vram_limit
        self.state = 0  # start in the offloaded state

    def forward(self, x, *args, **kwargs):
        """LayerNorm `x`, fetching parameters to the computation placement as needed."""
        if self.state == 2:
            # Parameters already pinned on the computation device.
            weight, bias = self.weight, self.bias
        else:
            if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
                # Onload placement doubles as the computation placement.
                weight, bias = self.weight, self.bias
            elif self.vram_limit is not None and self.check_free_vram():
                # Enough VRAM headroom: promote permanently (state -> 2).
                self.keep()
                weight, bias = self.weight, self.bias
            else:
                # Tight on memory: cast throwaway copies for this call only.
                weight = None if self.weight is None else cast_to(self.weight, self.computation_dtype, self.computation_device)
                bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
        with torch.amp.autocast(device_type=x.device.type):
            # Normalize in float32 for stability, then cast back to x's dtype.
            x = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps).type_as(x)
        return x
93
+
94
+
95
class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
    # VRAM-managed Linear with optional LoRA adapters. Parameters are aliased
    # from the source module; the LoRA weight lists and merger are populated
    # externally (presumably by the LoRA loading code — not visible here).
    def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, name="", **kwargs):
        """Mirror `module`'s shape without allocating, then alias its parameters."""
        with init_weights_on_device(device=torch.device("meta")):
            super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
        # Alias, not copy, the source parameters.
        self.weight = module.weight
        self.bias = module.bias
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.vram_limit = vram_limit
        self.state = 0  # 0 = offloaded, 1 = onloaded, 2 = kept on compute device
        self.name = name  # qualified layer name (used to target specific layers)
        # Per-adapter low-rank factors; each pair contributes x @ A.T @ B.T.
        self.lora_A_weights = []
        self.lora_B_weights = []
        # Optional module that fuses the base output with stacked LoRA outputs.
        self.lora_merger = None

    def forward(self, x, *args, **kwargs):
        """Linear transform of `x`, plus any attached LoRA contributions."""
        if self.state == 2:
            # Parameters already pinned on the computation device.
            weight, bias = self.weight, self.bias
        else:
            if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
                # Onload placement doubles as the computation placement.
                weight, bias = self.weight, self.bias
            elif self.vram_limit is not None and self.check_free_vram():
                # Enough VRAM headroom: promote permanently (state -> 2).
                self.keep()
                weight, bias = self.weight, self.bias
            else:
                # Tight on memory: cast throwaway copies for this call only.
                weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
                bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
        out = torch.nn.functional.linear(x, weight, bias)

        if len(self.lora_A_weights) == 0:
            # No LoRA
            return out
        elif self.lora_merger is None:
            # Native LoRA inference
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
                out = out + x @ lora_A.T @ lora_B.T
        else:
            # LoRA fusion
            lora_output = []
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
                lora_output.append(x @ lora_A.T @ lora_B.T)
            lora_output = torch.stack(lora_output)
            out = self.lora_merger(out, lora_output)
        return out
143
+
144
+
145
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""):
    """Walk `model`, replacing children whose type appears in `module_map`
    with the mapped wrapper.

    Once the running parameter count would exceed `max_num_param`, later
    matches are built with `overflow_module_config` instead of
    `module_config`. Returns the updated running parameter count.
    """
    for child_name, child in model.named_children():
        qualified_name = child_name if not name_prefix else name_prefix + "." + child_name
        # First mapping entry whose source type matches this child, if any.
        wrapper_cls = next(
            (target for source, target in module_map.items() if isinstance(child, source)),
            None,
        )
        if wrapper_cls is None:
            # Not a managed type itself; keep descending into it.
            total_num_param = enable_vram_management_recursively(child, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit, name_prefix=qualified_name)
            continue
        param_count = sum(p.numel() for p in child.parameters())
        over_budget = max_num_param is not None and total_num_param + param_count > max_num_param
        chosen_config = overflow_module_config if over_budget else module_config
        wrapped = wrapper_cls(child, **chosen_config, vram_limit=vram_limit, name=qualified_name)
        setattr(model, child_name, wrapped)
        total_num_param += param_count
    return total_num_param
162
+
163
+
164
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, vram_limit=None):
    """Entry point: wrap eligible submodules of `model` and flag it as managed."""
    enable_vram_management_recursively(
        model,
        module_map,
        module_config,
        max_num_param=max_num_param,
        overflow_module_config=overflow_module_config,
        total_num_param=0,
        vram_limit=vram_limit,
    )
    model.vram_management_enabled = True
167
+
examples/__init__.py ADDED
File without changes
examples/dataset/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .hypersim_dataset import HypersimDataset
2
+ from .video_dataset.kitti_vid_dataset import KITTI_VID_Dataset
3
+ from .video_dataset.nyuv2_dataset import NYUv2Dataset
4
+ from .video_dataset.scannet_dataset import Scannet_VID_Dataset
5
+ from .video_dataset.tartanair_vid_dataset import TartanAir_VID_Dataset
6
+ from .video_dataset.vkitti_vid_dataset import VKITTI_VID_Dataset
7
+ from .vkitti_dataset import VKITTIDataset
8
+
9
+ __all__ = [
10
+ "HypersimDataset",
11
+ "KITTI_VID_Dataset",
12
+ "VKITTI_VID_Dataset",
13
+ "TartanAir_VID_Dataset",
14
+ "NYUv2Dataset",
15
+ "VKITTIDataset",
16
+ 'Scannet_VID_Dataset'
17
+ ]