lsmpp commited on
Commit
e31d39e
·
verified ·
1 Parent(s): 4960ef6

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. diffusers/.github/ISSUE_TEMPLATE/config.yml +4 -0
  2. diffusers/.github/ISSUE_TEMPLATE/feedback.md +12 -0
  3. diffusers/.github/ISSUE_TEMPLATE/translate.md +29 -0
  4. diffusers/benchmarks/README.md +69 -0
  5. diffusers/benchmarks/__init__.py +0 -0
  6. diffusers/benchmarks/benchmarking_flux.py +98 -0
  7. diffusers/benchmarks/benchmarking_ltx.py +80 -0
  8. diffusers/benchmarks/benchmarking_sdxl.py +82 -0
  9. diffusers/benchmarks/benchmarking_utils.py +244 -0
  10. diffusers/benchmarks/benchmarking_wan.py +74 -0
  11. diffusers/benchmarks/populate_into_db.py +166 -0
  12. diffusers/benchmarks/push_results.py +76 -0
  13. diffusers/benchmarks/requirements.txt +6 -0
  14. diffusers/benchmarks/run_all.py +84 -0
  15. diffusers/examples/README.md +70 -0
  16. diffusers/examples/conftest.py +45 -0
  17. diffusers/examples/test_examples_utils.py +61 -0
  18. diffusers/scripts/__init__.py +0 -0
  19. diffusers/scripts/change_naming_configs_and_checkpoints.py +113 -0
  20. diffusers/scripts/convert_amused.py +523 -0
  21. diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py +62 -0
  22. diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py +83 -0
  23. diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py +184 -0
  24. diffusers/scripts/convert_aura_flow_to_diffusers.py +131 -0
  25. diffusers/scripts/convert_blipdiffusion_to_diffusers.py +344 -0
  26. diffusers/scripts/convert_cogview3_to_diffusers.py +242 -0
  27. diffusers/scripts/convert_cogview4_to_diffusers.py +254 -0
  28. diffusers/scripts/convert_cogview4_to_diffusers_megatron.py +384 -0
  29. diffusers/scripts/convert_consistency_to_diffusers.py +315 -0
  30. diffusers/scripts/convert_cosmos_to_diffusers.py +506 -0
  31. diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py +431 -0
  32. diffusers/scripts/convert_diffusers_to_original_sdxl.py +350 -0
  33. diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py +353 -0
  34. diffusers/scripts/convert_dit_to_diffusers.py +162 -0
  35. diffusers/scripts/convert_flux_to_diffusers.py +308 -0
  36. diffusers/scripts/convert_gligen_to_diffusers.py +581 -0
  37. diffusers/scripts/convert_hunyuan_video_to_diffusers.py +353 -0
  38. diffusers/scripts/convert_hunyuandit_to_diffusers.py +266 -0
  39. diffusers/scripts/convert_k_upscaler_to_diffusers.py +297 -0
  40. diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py +1159 -0
  41. diffusers/scripts/convert_kandinsky3_unet.py +98 -0
  42. diffusers/scripts/convert_kandinsky_to_diffusers.py +1411 -0
  43. diffusers/scripts/convert_ldm_original_checkpoint_to_diffusers.py +359 -0
  44. diffusers/scripts/convert_ltx_to_diffusers.py +516 -0
  45. diffusers/scripts/convert_lumina_to_diffusers.py +142 -0
  46. diffusers/scripts/convert_mochi_to_diffusers.py +463 -0
  47. diffusers/scripts/convert_models_diffuser_to_diffusers.py +100 -0
  48. diffusers/scripts/convert_ms_text_to_video_to_diffusers.py +428 -0
  49. diffusers/scripts/convert_music_spectrogram_to_diffusers.py +203 -0
  50. diffusers/scripts/convert_original_audioldm_to_diffusers.py +1042 -0
diffusers/.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ contact_links:
2
+ - name: Questions / Discussions
3
+ url: https://github.com/huggingface/diffusers/discussions
4
+ about: General usage questions and community discussions
diffusers/.github/ISSUE_TEMPLATE/feedback.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "💬 Feedback about API Design"
3
+ about: Give feedback about the current API design
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **What API design would you like to have changed or added to the library? Why?**
11
+
12
+ **What use case would this enable or better enable? Can you give us a code example?**
diffusers/.github/ISSUE_TEMPLATE/translate.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🌐 Translating a New Language?
3
+ about: Start a new translation effort in your language
4
+ title: '[<languageCode>] Translating docs to <languageName>'
5
+ labels: WIP
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ <!--
11
+ Note: Please search to see if an issue already exists for the language you are trying to translate.
12
+ -->
13
+
14
+ Hi!
15
+
16
+ Let's bring the documentation to all the <languageName>-speaking community 🌐.
17
+
18
+ Who would want to translate? Please follow the 🤗 [TRANSLATING guide](https://github.com/huggingface/diffusers/blob/main/docs/TRANSLATING.md). Here is a list of the files ready for translation. Let us know in this issue if you'd like to translate any, and we'll add your name to the list.
19
+
20
+ Some notes:
21
+
22
+ * Please translate using an informal tone (imagine you are talking with a friend about Diffusers 🤗).
23
+ * Please translate in a gender-neutral way.
24
+ * Add your translations to the folder called `<languageCode>` inside the [source folder](https://github.com/huggingface/diffusers/tree/main/docs/source).
25
+ * Register your translation in `<languageCode>/_toctree.yml`; please follow the order of the [English version](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml).
26
+ * Once you're finished, open a pull request and tag this issue by including #issue-number in the description, where issue-number is the number of this issue. Please ping @stevhliu for review.
27
+ * 🙋 If you'd like others to help you with the translation, you can also post in the 🤗 [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63).
28
+
29
+ Thank you so much for your help! 🤗
diffusers/benchmarks/README.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Diffusers Benchmarks
2
+
3
+ Welcome to Diffusers Benchmarks. These benchmarks are use to obtain latency and memory information of the most popular models across different scenarios such as:
4
+
5
+ * Base case i.e., when using `torch.bfloat16` and `torch.nn.functional.scaled_dot_product_attention`.
6
+ * Base + `torch.compile()`
7
+ * NF4 quantization
8
+ * Layerwise upcasting
9
+
10
+ Instead of full diffusion pipelines, only the forward pass of the respective model classes (such as `FluxTransformer2DModel`) is tested with the real checkpoints (such as `"black-forest-labs/FLUX.1-dev"`).
11
+
12
+ The entrypoint to running all the currently available benchmarks is in `run_all.py`. However, one can run the individual benchmarks, too, e.g., `python benchmarking_flux.py`. It should produce a CSV file containing various information about the benchmarks run.
13
+
14
+ The benchmarks are run on a weekly basis and the CI is defined in [benchmark.yml](../.github/workflows/benchmark.yml).
15
+
16
+ ## Running the benchmarks manually
17
+
18
+ First set up `torch` and install `diffusers` from the root of the directory:
19
+
20
+ ```py
21
+ pip install -e ".[quality,test]"
22
+ ```
23
+
24
+ Then make sure the other dependencies are installed:
25
+
26
+ ```sh
27
+ cd benchmarks/
28
+ pip install -r requirements.txt
29
+ ```
30
+
31
+ We need to be authenticated to access some of the checkpoints used during benchmarking:
32
+
33
+ ```sh
34
+ huggingface-cli login
35
+ ```
36
+
37
+ We use an L40 GPU with 128GB RAM to run the benchmark CI. As such, the benchmarks are configured to run on NVIDIA GPUs. So, make sure you have access to a similar machine (or modify the benchmarking scripts accordingly).
38
+
39
+ Then you can either launch the entire benchmarking suite by running:
40
+
41
+ ```sh
42
+ python run_all.py
43
+ ```
44
+
45
+ Or, you can run the individual benchmarks.
46
+
47
+ ## Customizing the benchmarks
48
+
49
+ We define "scenarios" to cover the most common ways in which these models are used. You can
50
+ define a new scenario, modifying an existing benchmark file:
51
+
52
+ ```py
53
+ BenchmarkScenario(
54
+ name=f"{CKPT_ID}-bnb-8bit",
55
+ model_cls=FluxTransformer2DModel,
56
+ model_init_kwargs={
57
+ "pretrained_model_name_or_path": CKPT_ID,
58
+ "torch_dtype": torch.bfloat16,
59
+ "subfolder": "transformer",
60
+ "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
61
+ },
62
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
63
+ model_init_fn=model_init_fn,
64
+ )
65
+ ```
66
+
67
+ You can also configure a new model-level benchmark and add it to the existing suite. To do so, just defining a valid benchmarking file like `benchmarking_flux.py` should be enough.
68
+
69
+ Happy benchmarking 🧨
diffusers/benchmarks/__init__.py ADDED
File without changes
diffusers/benchmarks/benchmarking_flux.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
5
+
6
+ from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
7
+ from diffusers.utils.testing_utils import torch_device
8
+
9
+
10
+ CKPT_ID = "black-forest-labs/FLUX.1-dev"
11
+ RESULT_FILENAME = "flux.csv"
12
+
13
+
14
+ def get_input_dict(**device_dtype_kwargs):
15
+ # resolution: 1024x1024
16
+ # maximum sequence length 512
17
+ hidden_states = torch.randn(1, 4096, 64, **device_dtype_kwargs)
18
+ encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
19
+ pooled_prompt_embeds = torch.randn(1, 768, **device_dtype_kwargs)
20
+ image_ids = torch.ones(512, 3, **device_dtype_kwargs)
21
+ text_ids = torch.ones(4096, 3, **device_dtype_kwargs)
22
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
23
+ guidance = torch.tensor([1.0], **device_dtype_kwargs)
24
+
25
+ return {
26
+ "hidden_states": hidden_states,
27
+ "encoder_hidden_states": encoder_hidden_states,
28
+ "img_ids": image_ids,
29
+ "txt_ids": text_ids,
30
+ "pooled_projections": pooled_prompt_embeds,
31
+ "timestep": timestep,
32
+ "guidance": guidance,
33
+ }
34
+
35
+
36
+ if __name__ == "__main__":
37
+ scenarios = [
38
+ BenchmarkScenario(
39
+ name=f"{CKPT_ID}-bf16",
40
+ model_cls=FluxTransformer2DModel,
41
+ model_init_kwargs={
42
+ "pretrained_model_name_or_path": CKPT_ID,
43
+ "torch_dtype": torch.bfloat16,
44
+ "subfolder": "transformer",
45
+ },
46
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
47
+ model_init_fn=model_init_fn,
48
+ compile_kwargs={"fullgraph": True},
49
+ ),
50
+ BenchmarkScenario(
51
+ name=f"{CKPT_ID}-bnb-nf4",
52
+ model_cls=FluxTransformer2DModel,
53
+ model_init_kwargs={
54
+ "pretrained_model_name_or_path": CKPT_ID,
55
+ "torch_dtype": torch.bfloat16,
56
+ "subfolder": "transformer",
57
+ "quantization_config": BitsAndBytesConfig(
58
+ load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4"
59
+ ),
60
+ },
61
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
62
+ model_init_fn=model_init_fn,
63
+ ),
64
+ BenchmarkScenario(
65
+ name=f"{CKPT_ID}-layerwise-upcasting",
66
+ model_cls=FluxTransformer2DModel,
67
+ model_init_kwargs={
68
+ "pretrained_model_name_or_path": CKPT_ID,
69
+ "torch_dtype": torch.bfloat16,
70
+ "subfolder": "transformer",
71
+ },
72
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
73
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
74
+ ),
75
+ BenchmarkScenario(
76
+ name=f"{CKPT_ID}-group-offload-leaf",
77
+ model_cls=FluxTransformer2DModel,
78
+ model_init_kwargs={
79
+ "pretrained_model_name_or_path": CKPT_ID,
80
+ "torch_dtype": torch.bfloat16,
81
+ "subfolder": "transformer",
82
+ },
83
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
84
+ model_init_fn=partial(
85
+ model_init_fn,
86
+ group_offload_kwargs={
87
+ "onload_device": torch_device,
88
+ "offload_device": torch.device("cpu"),
89
+ "offload_type": "leaf_level",
90
+ "use_stream": True,
91
+ "non_blocking": True,
92
+ },
93
+ ),
94
+ ),
95
+ ]
96
+
97
+ runner = BenchmarkMixin()
98
+ runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diffusers/benchmarks/benchmarking_ltx.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
5
+
6
+ from diffusers import LTXVideoTransformer3DModel
7
+ from diffusers.utils.testing_utils import torch_device
8
+
9
+
10
+ CKPT_ID = "Lightricks/LTX-Video-0.9.7-dev"
11
+ RESULT_FILENAME = "ltx.csv"
12
+
13
+
14
+ def get_input_dict(**device_dtype_kwargs):
15
+ # 512x704 (161 frames)
16
+ # `max_sequence_length`: 256
17
+ hidden_states = torch.randn(1, 7392, 128, **device_dtype_kwargs)
18
+ encoder_hidden_states = torch.randn(1, 256, 4096, **device_dtype_kwargs)
19
+ encoder_attention_mask = torch.ones(1, 256, **device_dtype_kwargs)
20
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
21
+ video_coords = torch.randn(1, 3, 7392, **device_dtype_kwargs)
22
+
23
+ return {
24
+ "hidden_states": hidden_states,
25
+ "encoder_hidden_states": encoder_hidden_states,
26
+ "encoder_attention_mask": encoder_attention_mask,
27
+ "timestep": timestep,
28
+ "video_coords": video_coords,
29
+ }
30
+
31
+
32
+ if __name__ == "__main__":
33
+ scenarios = [
34
+ BenchmarkScenario(
35
+ name=f"{CKPT_ID}-bf16",
36
+ model_cls=LTXVideoTransformer3DModel,
37
+ model_init_kwargs={
38
+ "pretrained_model_name_or_path": CKPT_ID,
39
+ "torch_dtype": torch.bfloat16,
40
+ "subfolder": "transformer",
41
+ },
42
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
43
+ model_init_fn=model_init_fn,
44
+ compile_kwargs={"fullgraph": True},
45
+ ),
46
+ BenchmarkScenario(
47
+ name=f"{CKPT_ID}-layerwise-upcasting",
48
+ model_cls=LTXVideoTransformer3DModel,
49
+ model_init_kwargs={
50
+ "pretrained_model_name_or_path": CKPT_ID,
51
+ "torch_dtype": torch.bfloat16,
52
+ "subfolder": "transformer",
53
+ },
54
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
55
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
56
+ ),
57
+ BenchmarkScenario(
58
+ name=f"{CKPT_ID}-group-offload-leaf",
59
+ model_cls=LTXVideoTransformer3DModel,
60
+ model_init_kwargs={
61
+ "pretrained_model_name_or_path": CKPT_ID,
62
+ "torch_dtype": torch.bfloat16,
63
+ "subfolder": "transformer",
64
+ },
65
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
66
+ model_init_fn=partial(
67
+ model_init_fn,
68
+ group_offload_kwargs={
69
+ "onload_device": torch_device,
70
+ "offload_device": torch.device("cpu"),
71
+ "offload_type": "leaf_level",
72
+ "use_stream": True,
73
+ "non_blocking": True,
74
+ },
75
+ ),
76
+ ),
77
+ ]
78
+
79
+ runner = BenchmarkMixin()
80
+ runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diffusers/benchmarks/benchmarking_sdxl.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
5
+
6
+ from diffusers import UNet2DConditionModel
7
+ from diffusers.utils.testing_utils import torch_device
8
+
9
+
10
+ CKPT_ID = "stabilityai/stable-diffusion-xl-base-1.0"
11
+ RESULT_FILENAME = "sdxl.csv"
12
+
13
+
14
+ def get_input_dict(**device_dtype_kwargs):
15
+ # height: 1024
16
+ # width: 1024
17
+ # max_sequence_length: 77
18
+ hidden_states = torch.randn(1, 4, 128, 128, **device_dtype_kwargs)
19
+ encoder_hidden_states = torch.randn(1, 77, 2048, **device_dtype_kwargs)
20
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
21
+ added_cond_kwargs = {
22
+ "text_embeds": torch.randn(1, 1280, **device_dtype_kwargs),
23
+ "time_ids": torch.ones(1, 6, **device_dtype_kwargs),
24
+ }
25
+
26
+ return {
27
+ "sample": hidden_states,
28
+ "encoder_hidden_states": encoder_hidden_states,
29
+ "timestep": timestep,
30
+ "added_cond_kwargs": added_cond_kwargs,
31
+ }
32
+
33
+
34
+ if __name__ == "__main__":
35
+ scenarios = [
36
+ BenchmarkScenario(
37
+ name=f"{CKPT_ID}-bf16",
38
+ model_cls=UNet2DConditionModel,
39
+ model_init_kwargs={
40
+ "pretrained_model_name_or_path": CKPT_ID,
41
+ "torch_dtype": torch.bfloat16,
42
+ "subfolder": "unet",
43
+ },
44
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
45
+ model_init_fn=model_init_fn,
46
+ compile_kwargs={"fullgraph": True},
47
+ ),
48
+ BenchmarkScenario(
49
+ name=f"{CKPT_ID}-layerwise-upcasting",
50
+ model_cls=UNet2DConditionModel,
51
+ model_init_kwargs={
52
+ "pretrained_model_name_or_path": CKPT_ID,
53
+ "torch_dtype": torch.bfloat16,
54
+ "subfolder": "unet",
55
+ },
56
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
57
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
58
+ ),
59
+ BenchmarkScenario(
60
+ name=f"{CKPT_ID}-group-offload-leaf",
61
+ model_cls=UNet2DConditionModel,
62
+ model_init_kwargs={
63
+ "pretrained_model_name_or_path": CKPT_ID,
64
+ "torch_dtype": torch.bfloat16,
65
+ "subfolder": "unet",
66
+ },
67
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
68
+ model_init_fn=partial(
69
+ model_init_fn,
70
+ group_offload_kwargs={
71
+ "onload_device": torch_device,
72
+ "offload_device": torch.device("cpu"),
73
+ "offload_type": "leaf_level",
74
+ "use_stream": True,
75
+ "non_blocking": True,
76
+ },
77
+ ),
78
+ ),
79
+ ]
80
+
81
+ runner = BenchmarkMixin()
82
+ runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diffusers/benchmarks/benchmarking_utils.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import inspect
3
+ import logging
4
+ import os
5
+ import queue
6
+ import threading
7
+ from contextlib import nullcontext
8
+ from dataclasses import dataclass
9
+ from typing import Any, Callable, Dict, Optional, Union
10
+
11
+ import pandas as pd
12
+ import torch
13
+ import torch.utils.benchmark as benchmark
14
+
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils.testing_utils import require_torch_gpu, torch_device
17
+
18
+
19
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
20
+ logger = logging.getLogger(__name__)
21
+
22
+ NUM_WARMUP_ROUNDS = 5
23
+
24
+
25
+ def benchmark_fn(f, *args, **kwargs):
26
+ t0 = benchmark.Timer(
27
+ stmt="f(*args, **kwargs)",
28
+ globals={"args": args, "kwargs": kwargs, "f": f},
29
+ num_threads=1,
30
+ )
31
+ return float(f"{(t0.blocked_autorange().mean):.3f}")
32
+
33
+
34
+ def flush():
35
+ gc.collect()
36
+ torch.cuda.empty_cache()
37
+ torch.cuda.reset_max_memory_allocated()
38
+ torch.cuda.reset_peak_memory_stats()
39
+
40
+
41
+ # Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
42
+ def calculate_flops(model, input_dict):
43
+ try:
44
+ from torchprofile import profile_macs
45
+ except ModuleNotFoundError:
46
+ raise
47
+
48
+ # This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs.
49
+ sig = inspect.signature(model.forward)
50
+ param_names = [
51
+ p.name
52
+ for p in sig.parameters.values()
53
+ if p.kind
54
+ in (
55
+ inspect.Parameter.POSITIONAL_ONLY,
56
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
57
+ )
58
+ and p.name != "self"
59
+ ]
60
+ bound = sig.bind_partial(**input_dict)
61
+ bound.apply_defaults()
62
+ args = tuple(bound.arguments[name] for name in param_names)
63
+
64
+ model.eval()
65
+ with torch.no_grad():
66
+ macs = profile_macs(model, args)
67
+ flops = 2 * macs # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)
68
+ return flops
69
+
70
+
71
+ def calculate_params(model):
72
+ return sum(p.numel() for p in model.parameters())
73
+
74
+
75
+ # Users can define their own in case this doesn't suffice. For most cases,
76
+ # it should be sufficient.
77
+ def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):
78
+ model = model_cls.from_pretrained(**init_kwargs).eval()
79
+ if group_offload_kwargs and isinstance(group_offload_kwargs, dict):
80
+ model.enable_group_offload(**group_offload_kwargs)
81
+ else:
82
+ model.to(torch_device)
83
+ if layerwise_upcasting:
84
+ model.enable_layerwise_casting(
85
+ storage_dtype=torch.float8_e4m3fn, compute_dtype=init_kwargs.get("torch_dtype", torch.bfloat16)
86
+ )
87
+ return model
88
+
89
+
90
+ @dataclass
91
+ class BenchmarkScenario:
92
+ name: str
93
+ model_cls: ModelMixin
94
+ model_init_kwargs: Dict[str, Any]
95
+ model_init_fn: Callable
96
+ get_model_input_dict: Callable
97
+ compile_kwargs: Optional[Dict[str, Any]] = None
98
+
99
+
100
+ @require_torch_gpu
101
+ class BenchmarkMixin:
102
+ def pre_benchmark(self):
103
+ flush()
104
+ torch.compiler.reset()
105
+
106
+ def post_benchmark(self, model):
107
+ model.cpu()
108
+ flush()
109
+ torch.compiler.reset()
110
+
111
+ @torch.no_grad()
112
+ def run_benchmark(self, scenario: BenchmarkScenario):
113
+ # 0) Basic stats
114
+ logger.info(f"Running scenario: {scenario.name}.")
115
+ try:
116
+ model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
117
+ num_params = round(calculate_params(model) / 1e9, 2)
118
+ try:
119
+ flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e9, 2)
120
+ except Exception as e:
121
+ logger.info(f"Problem in calculating FLOPs:\n{e}")
122
+ flops = None
123
+ model.cpu()
124
+ del model
125
+ except Exception as e:
126
+ logger.info(f"Error while initializing the model and calculating FLOPs:\n{e}")
127
+ return {}
128
+ self.pre_benchmark()
129
+
130
+ # 1) plain stats
131
+ results = {}
132
+ plain = None
133
+ try:
134
+ plain = self._run_phase(
135
+ model_cls=scenario.model_cls,
136
+ init_fn=scenario.model_init_fn,
137
+ init_kwargs=scenario.model_init_kwargs,
138
+ get_input_fn=scenario.get_model_input_dict,
139
+ compile_kwargs=None,
140
+ )
141
+ except Exception as e:
142
+ logger.info(f"Benchmark could not be run with the following error:\n{e}")
143
+ return results
144
+
145
+ # 2) compiled stats (if any)
146
+ compiled = {"time": None, "memory": None}
147
+ if scenario.compile_kwargs:
148
+ try:
149
+ compiled = self._run_phase(
150
+ model_cls=scenario.model_cls,
151
+ init_fn=scenario.model_init_fn,
152
+ init_kwargs=scenario.model_init_kwargs,
153
+ get_input_fn=scenario.get_model_input_dict,
154
+ compile_kwargs=scenario.compile_kwargs,
155
+ )
156
+ except Exception as e:
157
+ logger.info(f"Compilation benchmark could not be run with the following error\n: {e}")
158
+ if plain is None:
159
+ return results
160
+
161
+ # 3) merge
162
+ result = {
163
+ "scenario": scenario.name,
164
+ "model_cls": scenario.model_cls.__name__,
165
+ "num_params_B": num_params,
166
+ "flops_G": flops,
167
+ "time_plain_s": plain["time"],
168
+ "mem_plain_GB": plain["memory"],
169
+ "time_compile_s": compiled["time"],
170
+ "mem_compile_GB": compiled["memory"],
171
+ }
172
+ if scenario.compile_kwargs:
173
+ result["fullgraph"] = scenario.compile_kwargs.get("fullgraph", False)
174
+ result["mode"] = scenario.compile_kwargs.get("mode", "default")
175
+ else:
176
+ result["fullgraph"], result["mode"] = None, None
177
+ return result
178
+
179
+ def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str):
180
+ if not isinstance(scenarios, list):
181
+ scenarios = [scenarios]
182
+ record_queue = queue.Queue()
183
+ stop_signal = object()
184
+
185
+ def _writer_thread():
186
+ while True:
187
+ item = record_queue.get()
188
+ if item is stop_signal:
189
+ break
190
+ df_row = pd.DataFrame([item])
191
+ write_header = not os.path.exists(filename)
192
+ df_row.to_csv(filename, mode="a", header=write_header, index=False)
193
+ record_queue.task_done()
194
+
195
+ record_queue.task_done()
196
+
197
+ writer = threading.Thread(target=_writer_thread, daemon=True)
198
+ writer.start()
199
+
200
+ for s in scenarios:
201
+ try:
202
+ record = self.run_benchmark(s)
203
+ if record:
204
+ record_queue.put(record)
205
+ else:
206
+ logger.info(f"Record empty from scenario: {s.name}.")
207
+ except Exception as e:
208
+ logger.info(f"Running scenario ({s.name}) led to error:\n{e}")
209
+ record_queue.put(stop_signal)
210
+ logger.info(f"Results serialized to {filename=}.")
211
+
212
+ def _run_phase(
213
+ self,
214
+ *,
215
+ model_cls: ModelMixin,
216
+ init_fn: Callable,
217
+ init_kwargs: Dict[str, Any],
218
+ get_input_fn: Callable,
219
+ compile_kwargs: Optional[Dict[str, Any]],
220
+ ) -> Dict[str, float]:
221
+ # setup
222
+ self.pre_benchmark()
223
+
224
+ # init & (optional) compile
225
+ model = init_fn(model_cls, **init_kwargs)
226
+ if compile_kwargs:
227
+ model.compile(**compile_kwargs)
228
+
229
+ # build inputs
230
+ inp = get_input_fn()
231
+
232
+ # measure
233
+ run_ctx = torch._inductor.utils.fresh_inductor_cache() if compile_kwargs else nullcontext()
234
+ with run_ctx:
235
+ for _ in range(NUM_WARMUP_ROUNDS):
236
+ _ = model(**inp)
237
+ time_s = benchmark_fn(lambda m, d: m(**d), model, inp)
238
+ mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
239
+ mem_gb = round(mem_gb, 2)
240
+
241
+ # teardown
242
+ self.post_benchmark(model)
243
+ del model
244
+ return {"time": time_s, "memory": mem_gb}
diffusers/benchmarks/benchmarking_wan.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
5
+
6
+ from diffusers import WanTransformer3DModel
7
+ from diffusers.utils.testing_utils import torch_device
8
+
9
+
10
+ CKPT_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
11
+ RESULT_FILENAME = "wan.csv"
12
+
13
+
14
+ def get_input_dict(**device_dtype_kwargs):
15
+ # height: 480
16
+ # width: 832
17
+ # num_frames: 81
18
+ # max_sequence_length: 512
19
+ hidden_states = torch.randn(1, 16, 21, 60, 104, **device_dtype_kwargs)
20
+ encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
21
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
22
+
23
+ return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep}
24
+
25
+
26
+ if __name__ == "__main__":
27
+ scenarios = [
28
+ BenchmarkScenario(
29
+ name=f"{CKPT_ID}-bf16",
30
+ model_cls=WanTransformer3DModel,
31
+ model_init_kwargs={
32
+ "pretrained_model_name_or_path": CKPT_ID,
33
+ "torch_dtype": torch.bfloat16,
34
+ "subfolder": "transformer",
35
+ },
36
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
37
+ model_init_fn=model_init_fn,
38
+ compile_kwargs={"fullgraph": True},
39
+ ),
40
+ BenchmarkScenario(
41
+ name=f"{CKPT_ID}-layerwise-upcasting",
42
+ model_cls=WanTransformer3DModel,
43
+ model_init_kwargs={
44
+ "pretrained_model_name_or_path": CKPT_ID,
45
+ "torch_dtype": torch.bfloat16,
46
+ "subfolder": "transformer",
47
+ },
48
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
49
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
50
+ ),
51
+ BenchmarkScenario(
52
+ name=f"{CKPT_ID}-group-offload-leaf",
53
+ model_cls=WanTransformer3DModel,
54
+ model_init_kwargs={
55
+ "pretrained_model_name_or_path": CKPT_ID,
56
+ "torch_dtype": torch.bfloat16,
57
+ "subfolder": "transformer",
58
+ },
59
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
60
+ model_init_fn=partial(
61
+ model_init_fn,
62
+ group_offload_kwargs={
63
+ "onload_device": torch_device,
64
+ "offload_device": torch.device("cpu"),
65
+ "offload_type": "leaf_level",
66
+ "use_stream": True,
67
+ "non_blocking": True,
68
+ },
69
+ ),
70
+ ),
71
+ ]
72
+
73
+ runner = BenchmarkMixin()
74
+ runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diffusers/benchmarks/populate_into_db.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ import gpustat
6
+ import pandas as pd
7
+ import psycopg2
8
+ import psycopg2.extras
9
+ from psycopg2.extensions import register_adapter
10
+ from psycopg2.extras import Json
11
+
12
+
13
+ register_adapter(dict, Json)
14
+
15
+ FINAL_CSV_FILENAME = "collated_results.csv"
16
+ # https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27
17
+ BENCHMARKS_TABLE_NAME = "benchmarks"
18
+ MEASUREMENTS_TABLE_NAME = "model_measurements"
19
+
20
+
21
+ def _init_benchmark(conn, branch, commit_id, commit_msg):
22
+ gpu_stats = gpustat.GPUStatCollection.new_query()
23
+ metadata = {"gpu_name": gpu_stats[0]["name"]}
24
+ repository = "huggingface/diffusers"
25
+ with conn.cursor() as cur:
26
+ cur.execute(
27
+ f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
28
+ (repository, branch, commit_id, commit_msg, metadata),
29
+ )
30
+ benchmark_id = cur.fetchone()[0]
31
+ print(f"Initialised benchmark #{benchmark_id}")
32
+ return benchmark_id
33
+
34
+
35
+ def parse_args():
36
+ parser = argparse.ArgumentParser()
37
+ parser.add_argument(
38
+ "branch",
39
+ type=str,
40
+ help="The branch name on which the benchmarking is performed.",
41
+ )
42
+
43
+ parser.add_argument(
44
+ "commit_id",
45
+ type=str,
46
+ help="The commit hash on which the benchmarking is performed.",
47
+ )
48
+
49
+ parser.add_argument(
50
+ "commit_msg",
51
+ type=str,
52
+ help="The commit message associated with the commit, truncated to 70 characters.",
53
+ )
54
+ args = parser.parse_args()
55
+ return args
56
+
57
+
58
+ if __name__ == "__main__":
59
+ args = parse_args()
60
+ try:
61
+ conn = psycopg2.connect(
62
+ host=os.getenv("PGHOST"),
63
+ database=os.getenv("PGDATABASE"),
64
+ user=os.getenv("PGUSER"),
65
+ password=os.getenv("PGPASSWORD"),
66
+ )
67
+ print("DB connection established successfully.")
68
+ except Exception as e:
69
+ print(f"Problem during DB init: {e}")
70
+ sys.exit(1)
71
+
72
+ try:
73
+ benchmark_id = _init_benchmark(
74
+ conn=conn,
75
+ branch=args.branch,
76
+ commit_id=args.commit_id,
77
+ commit_msg=args.commit_msg,
78
+ )
79
+ except Exception as e:
80
+ print(f"Problem during initializing benchmark: {e}")
81
+ sys.exit(1)
82
+
83
+ cur = conn.cursor()
84
+
85
+ df = pd.read_csv(FINAL_CSV_FILENAME)
86
+
87
+ # Helper to cast values (or None) given a dtype
88
+ def _cast_value(val, dtype: str):
89
+ if pd.isna(val):
90
+ return None
91
+
92
+ if dtype == "text":
93
+ return str(val).strip()
94
+
95
+ if dtype == "float":
96
+ try:
97
+ return float(val)
98
+ except ValueError:
99
+ return None
100
+
101
+ if dtype == "bool":
102
+ s = str(val).strip().lower()
103
+ if s in ("true", "t", "yes", "1"):
104
+ return True
105
+ if s in ("false", "f", "no", "0"):
106
+ return False
107
+ if val in (1, 1.0):
108
+ return True
109
+ if val in (0, 0.0):
110
+ return False
111
+ return None
112
+
113
+ return val
114
+
115
+ try:
116
+ rows_to_insert = []
117
+ for _, row in df.iterrows():
118
+ scenario = _cast_value(row.get("scenario"), "text")
119
+ model_cls = _cast_value(row.get("model_cls"), "text")
120
+ num_params_B = _cast_value(row.get("num_params_B"), "float")
121
+ flops_G = _cast_value(row.get("flops_G"), "float")
122
+ time_plain_s = _cast_value(row.get("time_plain_s"), "float")
123
+ mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float")
124
+ time_compile_s = _cast_value(row.get("time_compile_s"), "float")
125
+ mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float")
126
+ fullgraph = _cast_value(row.get("fullgraph"), "bool")
127
+ mode = _cast_value(row.get("mode"), "text")
128
+
129
+ # If "github_sha" column exists in the CSV, cast it; else default to None
130
+ if "github_sha" in df.columns:
131
+ github_sha = _cast_value(row.get("github_sha"), "text")
132
+ else:
133
+ github_sha = None
134
+
135
+ measurements = {
136
+ "scenario": scenario,
137
+ "model_cls": model_cls,
138
+ "num_params_B": num_params_B,
139
+ "flops_G": flops_G,
140
+ "time_plain_s": time_plain_s,
141
+ "mem_plain_GB": mem_plain_GB,
142
+ "time_compile_s": time_compile_s,
143
+ "mem_compile_GB": mem_compile_GB,
144
+ "fullgraph": fullgraph,
145
+ "mode": mode,
146
+ "github_sha": github_sha,
147
+ }
148
+ rows_to_insert.append((benchmark_id, measurements))
149
+
150
+ # Batch-insert all rows
151
+ insert_sql = f"""
152
+ INSERT INTO {MEASUREMENTS_TABLE_NAME} (
153
+ benchmark_id,
154
+ measurements
155
+ )
156
+ VALUES (%s, %s);
157
+ """
158
+
159
+ psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert)
160
+ conn.commit()
161
+
162
+ cur.close()
163
+ conn.close()
164
+ except Exception as e:
165
+ print(f"Exception: {e}")
166
+ sys.exit(1)
diffusers/benchmarks/push_results.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import pandas as pd
4
+ from huggingface_hub import hf_hub_download, upload_file
5
+ from huggingface_hub.utils import EntryNotFoundError
6
+
7
+
8
+ REPO_ID = "diffusers/benchmarks"
9
+
10
+
11
+ def has_previous_benchmark() -> str:
12
+ from run_all import FINAL_CSV_FILENAME
13
+
14
+ csv_path = None
15
+ try:
16
+ csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILENAME)
17
+ except EntryNotFoundError:
18
+ csv_path = None
19
+ return csv_path
20
+
21
+
22
+ def filter_float(value):
23
+ if isinstance(value, str):
24
+ return float(value.split()[0])
25
+ return value
26
+
27
+
28
+ def push_to_hf_dataset():
29
+ from run_all import FINAL_CSV_FILENAME, GITHUB_SHA
30
+
31
+ csv_path = has_previous_benchmark()
32
+ if csv_path is not None:
33
+ current_results = pd.read_csv(FINAL_CSV_FILENAME)
34
+ previous_results = pd.read_csv(csv_path)
35
+
36
+ numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns
37
+
38
+ for column in numeric_columns:
39
+ # get previous values as floats, aligned to current index
40
+ prev_vals = previous_results[column].map(filter_float).reindex(current_results.index)
41
+
42
+ # get current values as floats
43
+ curr_vals = current_results[column].astype(float)
44
+
45
+ # stringify the current values
46
+ curr_str = curr_vals.map(str)
47
+
48
+ # build an appendage only when prev exists and differs
49
+ append_str = prev_vals.where(prev_vals.notnull() & (prev_vals != curr_vals), other=pd.NA).map(
50
+ lambda x: f" ({x})" if pd.notnull(x) else ""
51
+ )
52
+
53
+ # combine
54
+ current_results[column] = curr_str + append_str
55
+ os.remove(FINAL_CSV_FILENAME)
56
+ current_results.to_csv(FINAL_CSV_FILENAME, index=False)
57
+
58
+ commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"
59
+ upload_file(
60
+ repo_id=REPO_ID,
61
+ path_in_repo=FINAL_CSV_FILENAME,
62
+ path_or_fileobj=FINAL_CSV_FILENAME,
63
+ repo_type="dataset",
64
+ commit_message=commit_message,
65
+ )
66
+ upload_file(
67
+ repo_id="diffusers/benchmark-analyzer",
68
+ path_in_repo=FINAL_CSV_FILENAME,
69
+ path_or_fileobj=FINAL_CSV_FILENAME,
70
+ repo_type="space",
71
+ commit_message=commit_message,
72
+ )
73
+
74
+
75
+ if __name__ == "__main__":
76
+ push_to_hf_dataset()
diffusers/benchmarks/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ pandas
2
+ psutil
3
+ gpustat
4
+ torchprofile
5
+ bitsandbytes
6
+ psycopg2==2.9.9
diffusers/benchmarks/run_all.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import logging
3
+ import os
4
+ import subprocess
5
+
6
+ import pandas as pd
7
+
8
+
9
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
10
+ logger = logging.getLogger(__name__)
11
+
12
+ PATTERN = "benchmarking_*.py"
13
+ FINAL_CSV_FILENAME = "collated_results.csv"
14
+ GITHUB_SHA = os.getenv("GITHUB_SHA", None)
15
+
16
+
17
+ class SubprocessCallException(Exception):
18
+ pass
19
+
20
+
21
+ def run_command(command: list[str], return_stdout=False):
22
+ try:
23
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT)
24
+ if return_stdout and hasattr(output, "decode"):
25
+ return output.decode("utf-8")
26
+ except subprocess.CalledProcessError as e:
27
+ raise SubprocessCallException(f"Command `{' '.join(command)}` failed with:\n{e.output.decode()}") from e
28
+
29
+
30
+ def merge_csvs(final_csv: str = "collated_results.csv"):
31
+ all_csvs = glob.glob("*.csv")
32
+ all_csvs = [f for f in all_csvs if f != final_csv]
33
+ if not all_csvs:
34
+ logger.info("No result CSVs found to merge.")
35
+ return
36
+
37
+ df_list = []
38
+ for f in all_csvs:
39
+ try:
40
+ d = pd.read_csv(f)
41
+ except pd.errors.EmptyDataError:
42
+ # If a file existed but was zero‐bytes or corrupted, skip it
43
+ continue
44
+ df_list.append(d)
45
+
46
+ if not df_list:
47
+ logger.info("All result CSVs were empty or invalid; nothing to merge.")
48
+ return
49
+
50
+ final_df = pd.concat(df_list, ignore_index=True)
51
+ if GITHUB_SHA is not None:
52
+ final_df["github_sha"] = GITHUB_SHA
53
+ final_df.to_csv(final_csv, index=False)
54
+ logger.info(f"Merged {len(all_csvs)} partial CSVs → {final_csv}.")
55
+
56
+
57
+ def run_scripts():
58
+ python_files = sorted(glob.glob(PATTERN))
59
+ python_files = [f for f in python_files if f != "benchmarking_utils.py"]
60
+
61
+ for file in python_files:
62
+ script_name = file.split(".py")[0].split("_")[-1] # example: benchmarking_foo.py -> foo
63
+ logger.info(f"\n****** Running file: {file} ******")
64
+
65
+ partial_csv = f"{script_name}.csv"
66
+ if os.path.exists(partial_csv):
67
+ logger.info(f"Found {partial_csv}. Removing for safer numbers and duplication.")
68
+ os.remove(partial_csv)
69
+
70
+ command = ["python", file]
71
+ try:
72
+ run_command(command)
73
+ logger.info(f"→ {file} finished normally.")
74
+ except SubprocessCallException as e:
75
+ logger.info(f"Error running {file}:\n{e}")
76
+ finally:
77
+ logger.info(f"→ Merging partial CSVs after {file} …")
78
+ merge_csvs(final_csv=FINAL_CSV_FILENAME)
79
+
80
+ logger.info(f"\nAll scripts attempted. Final collated CSV: {FINAL_CSV_FILENAME}")
81
+
82
+
83
+ if __name__ == "__main__":
84
+ run_scripts()
diffusers/examples/README.md ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!---
2
+ Copyright 2025 The HuggingFace Team. All rights reserved.
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ -->
15
+
16
+ # 🧨 Diffusers Examples
17
+
18
+ Diffusers examples are a collection of scripts to demonstrate how to effectively use the `diffusers` library
19
+ for a variety of use cases involving training or fine-tuning.
20
+
21
+ **Note**: If you are looking for **official** examples on how to use `diffusers` for inference, please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
22
+
23
+ Our examples aspire to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.
24
+ More specifically, this means:
25
+
26
+ - **Self-contained**: An example script shall only depend on "pip-install-able" Python packages that can be found in a `requirements.txt` file. Example scripts shall **not** depend on any local files. This means that one can simply download an example script, *e.g.* [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py), install the required dependencies, *e.g.* [requirements.txt](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/requirements.txt) and execute the example script.
27
+ - **Easy-to-tweak**: While we strive to present as many use cases as possible, the example scripts are just that - examples. It is expected that they won't work out-of-the box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data and the training loop to allow you to tweak and edit them as required.
28
+ - **Beginner-friendly**: We do not aim for providing state-of-the-art training scripts for the newest models, but rather examples that can be used as a way to better understand diffusion models and how to use them with the `diffusers` library. We often purposefully leave out certain state-of-the-art methods if we consider them too complex for beginners.
29
+ - **One-purpose-only**: Examples should show one task and one task only. Even if a task is from a modeling point of view very similar, *e.g.* image super-resolution and image modification tend to use the same model and training method, we want examples to showcase only one task to keep them as readable and easy-to-understand as possible.
30
+
31
+ We provide **official** examples that cover the most popular tasks of diffusion models.
32
+ *Official* examples are **actively** maintained by the `diffusers` maintainers and we try to rigorously follow our example philosophy as defined above.
33
+ If you feel like another important example should exist, we are more than happy to welcome a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) or directly a [Pull Request](https://github.com/huggingface/diffusers/compare) from you!
34
+
35
+ Training examples show how to pretrain or fine-tune diffusion models for a variety of tasks. Currently we support:
36
+
37
+ | Task | 🤗 Accelerate | 🤗 Datasets | Colab
38
+ |---|---|:---:|:---:|
39
+ | [**Unconditional Image Generation**](./unconditional_image_generation) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
40
+ | [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
41
+ | [**Textual Inversion**](./textual_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
42
+ | [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
43
+ | [**ControlNet**](./controlnet) | ✅ | ✅ | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
44
+ | [**InstructPix2Pix**](./instruct_pix2pix) | ✅ | ✅ | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/InstructPix2Pix_using_diffusers.ipynb)
45
+ | [**Reinforcement Learning for Control**](./reinforcement_learning) | - | - | [Notebook1](https://github.com/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_for_control.ipynb), [Notebook2](https://github.com/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb)
46
+
47
+ ## Community
48
+
49
+ In addition, we provide **community** examples, which are examples added and maintained by our community.
50
+ Community examples can consist of both *training* examples or *inference* pipelines.
51
+ For such examples, we are more lenient regarding the philosophy defined above and also cannot guarantee to provide maintenance for every issue.
52
+ Examples that are useful for the community, but are either not yet deemed popular or not yet following our above philosophy should go into the [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) folder. The community folder therefore includes training examples and inference pipelines.
53
+ **Note**: Community examples can be a [great first contribution](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) to show to the community how you like to use `diffusers` 🪄.
54
+
55
+ ## Research Projects
56
+
57
+ We also provide **research_projects** examples that are maintained by the community as defined in the respective research project folders. These examples are useful and offer the extended capabilities which are complementary to the official examples. You may refer to [research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) for details.
58
+
59
+ ## Important note
60
+
61
+ To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
62
+ ```bash
63
+ git clone https://github.com/huggingface/diffusers
64
+ cd diffusers
65
+ pip install .
66
+ ```
67
+ Then cd in the example folder of your choice and run
68
+ ```bash
69
+ pip install -r requirements.txt
70
+ ```
diffusers/examples/conftest.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # tests directory-specific settings - this file is run automatically
16
+ # by pytest before any tests are run
17
+
18
+ import sys
19
+ import warnings
20
+ from os.path import abspath, dirname, join
21
+
22
+
23
+ # allow having multiple repository checkouts and not needing to remember to rerun
24
+ # 'pip install -e .[dev]' when switching between checkouts and running tests.
25
+ git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
26
+ sys.path.insert(1, git_repo_path)
27
+
28
+
29
+ # silence FutureWarning warnings in tests since often we can't act on them until
30
+ # they become normal warnings - i.e. the tests still need to test the current functionality
31
+ warnings.simplefilter(action="ignore", category=FutureWarning)
32
+
33
+
34
+ def pytest_addoption(parser):
35
+ from diffusers.utils.testing_utils import pytest_addoption_shared
36
+
37
+ pytest_addoption_shared(parser)
38
+
39
+
40
+ def pytest_terminal_summary(terminalreporter):
41
+ from diffusers.utils.testing_utils import pytest_terminal_summary_main
42
+
43
+ make_reports = terminalreporter.config.getoption("--make-reports")
44
+ if make_reports:
45
+ pytest_terminal_summary_main(terminalreporter, id=make_reports)
diffusers/examples/test_examples_utils.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import shutil
18
+ import subprocess
19
+ import tempfile
20
+ import unittest
21
+ from typing import List
22
+
23
+ from accelerate.utils import write_basic_config
24
+
25
+
26
+ # These utils relate to ensuring the right error message is received when running scripts
27
+ class SubprocessCallException(Exception):
28
+ pass
29
+
30
+
31
+ def run_command(command: List[str], return_stdout=False):
32
+ """
33
+ Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
34
+ if an error occurred while running `command`
35
+ """
36
+ try:
37
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT)
38
+ if return_stdout:
39
+ if hasattr(output, "decode"):
40
+ output = output.decode("utf-8")
41
+ return output
42
+ except subprocess.CalledProcessError as e:
43
+ raise SubprocessCallException(
44
+ f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
45
+ ) from e
46
+
47
+
48
+ class ExamplesTestsAccelerate(unittest.TestCase):
49
+ @classmethod
50
+ def setUpClass(cls):
51
+ super().setUpClass()
52
+ cls._tmpdir = tempfile.mkdtemp()
53
+ cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
54
+
55
+ write_basic_config(save_location=cls.configPath)
56
+ cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
57
+
58
+ @classmethod
59
+ def tearDownClass(cls):
60
+ super().tearDownClass()
61
+ shutil.rmtree(cls._tmpdir)
diffusers/scripts/__init__.py ADDED
File without changes
diffusers/scripts/change_naming_configs_and_checkpoints.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Conversion script for the LDM checkpoints."""
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+
21
+ import torch
22
+ from transformers.file_utils import has_file
23
+
24
+ from diffusers import UNet2DConditionModel, UNet2DModel
25
+
26
+
27
+ do_only_config = False
28
+ do_only_weights = True
29
+ do_only_renaming = False
30
+
31
+
32
+ if __name__ == "__main__":
33
+ parser = argparse.ArgumentParser()
34
+
35
+ parser.add_argument(
36
+ "--repo_path",
37
+ default=None,
38
+ type=str,
39
+ required=True,
40
+ help="The config json file corresponding to the architecture.",
41
+ )
42
+
43
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
44
+
45
+ args = parser.parse_args()
46
+
47
+ config_parameters_to_change = {
48
+ "image_size": "sample_size",
49
+ "num_res_blocks": "layers_per_block",
50
+ "block_channels": "block_out_channels",
51
+ "down_blocks": "down_block_types",
52
+ "up_blocks": "up_block_types",
53
+ "downscale_freq_shift": "freq_shift",
54
+ "resnet_num_groups": "norm_num_groups",
55
+ "resnet_act_fn": "act_fn",
56
+ "resnet_eps": "norm_eps",
57
+ "num_head_channels": "attention_head_dim",
58
+ }
59
+
60
+ key_parameters_to_change = {
61
+ "time_steps": "time_proj",
62
+ "mid": "mid_block",
63
+ "downsample_blocks": "down_blocks",
64
+ "upsample_blocks": "up_blocks",
65
+ }
66
+
67
+ subfolder = "" if has_file(args.repo_path, "config.json") else "unet"
68
+
69
+ with open(os.path.join(args.repo_path, subfolder, "config.json"), "r", encoding="utf-8") as reader:
70
+ text = reader.read()
71
+ config = json.loads(text)
72
+
73
+ if do_only_config:
74
+ for key in config_parameters_to_change.keys():
75
+ config.pop(key, None)
76
+
77
+ if has_file(args.repo_path, "config.json"):
78
+ model = UNet2DModel(**config)
79
+ else:
80
+ class_name = UNet2DConditionModel if "ldm-text2im-large-256" in args.repo_path else UNet2DModel
81
+ model = class_name(**config)
82
+
83
+ if do_only_config:
84
+ model.save_config(os.path.join(args.repo_path, subfolder))
85
+
86
+ config = dict(model.config)
87
+
88
+ if do_only_renaming:
89
+ for key, value in config_parameters_to_change.items():
90
+ if key in config:
91
+ config[value] = config[key]
92
+ del config[key]
93
+
94
+ config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]]
95
+ config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]]
96
+
97
+ if do_only_weights:
98
+ state_dict = torch.load(os.path.join(args.repo_path, subfolder, "diffusion_pytorch_model.bin"))
99
+
100
+ new_state_dict = {}
101
+ for param_key, param_value in state_dict.items():
102
+ if param_key.endswith(".op.bias") or param_key.endswith(".op.weight"):
103
+ continue
104
+ has_changed = False
105
+ for key, new_key in key_parameters_to_change.items():
106
+ if not has_changed and param_key.split(".")[0] == key:
107
+ new_state_dict[".".join([new_key] + param_key.split(".")[1:])] = param_value
108
+ has_changed = True
109
+ if not has_changed:
110
+ new_state_dict[param_key] = param_value
111
+
112
+ model.load_state_dict(new_state_dict)
113
+ model.save_pretrained(os.path.join(args.repo_path, subfolder))
diffusers/scripts/convert_amused.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import os
3
+ from argparse import ArgumentParser
4
+
5
+ import numpy as np
6
+ import torch
7
+ from muse import MaskGiTUViT, VQGANModel
8
+ from muse import PipelineMuse as OldPipelineMuse
9
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
10
+
11
+ from diffusers import VQModel
12
+ from diffusers.models.attention_processor import AttnProcessor
13
+ from diffusers.models.unets.uvit_2d import UVit2DModel
14
+ from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
15
+ from diffusers.schedulers import AmusedScheduler
16
+
17
+
18
+ torch.backends.cuda.enable_flash_sdp(False)
19
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
20
+ torch.backends.cuda.enable_math_sdp(True)
21
+
22
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
23
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
24
+ torch.use_deterministic_algorithms(True)
25
+
26
+ # Enable CUDNN deterministic mode
27
+ torch.backends.cudnn.deterministic = True
28
+ torch.backends.cudnn.benchmark = False
29
+ torch.backends.cuda.matmul.allow_tf32 = False
30
+
31
+ device = "cuda"
32
+
33
+
34
+ def main():
35
+ args = ArgumentParser()
36
+ args.add_argument("--model_256", action="store_true")
37
+ args.add_argument("--write_to", type=str, required=False, default=None)
38
+ args.add_argument("--transformer_path", type=str, required=False, default=None)
39
+ args = args.parse_args()
40
+
41
+ transformer_path = args.transformer_path
42
+ subfolder = "transformer"
43
+
44
+ if transformer_path is None:
45
+ if args.model_256:
46
+ transformer_path = "openMUSE/muse-256"
47
+ else:
48
+ transformer_path = (
49
+ "../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/"
50
+ )
51
+ subfolder = None
52
+
53
+ old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder)
54
+
55
+ old_transformer.to(device)
56
+
57
+ old_vae = VQGANModel.from_pretrained("openMUSE/muse-512", subfolder="vae")
58
+ old_vae.to(device)
59
+
60
+ vqvae = make_vqvae(old_vae)
61
+
62
+ tokenizer = CLIPTokenizer.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
63
+
64
+ text_encoder = CLIPTextModelWithProjection.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
65
+ text_encoder.to(device)
66
+
67
+ transformer = make_transformer(old_transformer, args.model_256)
68
+
69
+ scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id)
70
+
71
+ new_pipe = AmusedPipeline(
72
+ vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler
73
+ )
74
+
75
+ old_pipe = OldPipelineMuse(
76
+ vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer
77
+ )
78
+ old_pipe.to(device)
79
+
80
+ if args.model_256:
81
+ transformer_seq_len = 256
82
+ orig_size = (256, 256)
83
+ else:
84
+ transformer_seq_len = 1024
85
+ orig_size = (512, 512)
86
+
87
+ old_out = old_pipe(
88
+ "dog",
89
+ generator=torch.Generator(device).manual_seed(0),
90
+ transformer_seq_len=transformer_seq_len,
91
+ orig_size=orig_size,
92
+ timesteps=12,
93
+ )[0]
94
+
95
+ new_out = new_pipe("dog", generator=torch.Generator(device).manual_seed(0)).images[0]
96
+
97
+ old_out = np.array(old_out)
98
+ new_out = np.array(new_out)
99
+
100
+ diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64))
101
+
102
+ # assert diff diff.sum() == 0
103
+ print("skipping pipeline full equivalence check")
104
+
105
+ print(f"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}")
106
+
107
+ if args.model_256:
108
+ assert diff.max() <= 3
109
+ assert diff.sum() / diff.size < 0.7
110
+ else:
111
+ assert diff.max() <= 1
112
+ assert diff.sum() / diff.size < 0.4
113
+
114
+ if args.write_to is not None:
115
+ new_pipe.save_pretrained(args.write_to)
116
+
117
+
118
+ def make_transformer(old_transformer, model_256):
119
+ args = dict(old_transformer.config)
120
+ force_down_up_sample = args["force_down_up_sample"]
121
+
122
+ signature = inspect.signature(UVit2DModel.__init__)
123
+
124
+ args_ = {
125
+ "downsample": force_down_up_sample,
126
+ "upsample": force_down_up_sample,
127
+ "block_out_channels": args["block_out_channels"][0],
128
+ "sample_size": 16 if model_256 else 32,
129
+ }
130
+
131
+ for s in list(signature.parameters.keys()):
132
+ if s in ["self", "downsample", "upsample", "sample_size", "block_out_channels"]:
133
+ continue
134
+
135
+ args_[s] = args[s]
136
+
137
+ new_transformer = UVit2DModel(**args_)
138
+ new_transformer.to(device)
139
+
140
+ new_transformer.set_attn_processor(AttnProcessor())
141
+
142
+ state_dict = old_transformer.state_dict()
143
+
144
+ state_dict["cond_embed.linear_1.weight"] = state_dict.pop("cond_embed.0.weight")
145
+ state_dict["cond_embed.linear_2.weight"] = state_dict.pop("cond_embed.2.weight")
146
+
147
+ for i in range(22):
148
+ state_dict[f"transformer_layers.{i}.norm1.norm.weight"] = state_dict.pop(
149
+ f"transformer_layers.{i}.attn_layer_norm.weight"
150
+ )
151
+ state_dict[f"transformer_layers.{i}.norm1.linear.weight"] = state_dict.pop(
152
+ f"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight"
153
+ )
154
+
155
+ state_dict[f"transformer_layers.{i}.attn1.to_q.weight"] = state_dict.pop(
156
+ f"transformer_layers.{i}.attention.query.weight"
157
+ )
158
+ state_dict[f"transformer_layers.{i}.attn1.to_k.weight"] = state_dict.pop(
159
+ f"transformer_layers.{i}.attention.key.weight"
160
+ )
161
+ state_dict[f"transformer_layers.{i}.attn1.to_v.weight"] = state_dict.pop(
162
+ f"transformer_layers.{i}.attention.value.weight"
163
+ )
164
+ state_dict[f"transformer_layers.{i}.attn1.to_out.0.weight"] = state_dict.pop(
165
+ f"transformer_layers.{i}.attention.out.weight"
166
+ )
167
+
168
+ state_dict[f"transformer_layers.{i}.norm2.norm.weight"] = state_dict.pop(
169
+ f"transformer_layers.{i}.crossattn_layer_norm.weight"
170
+ )
171
+ state_dict[f"transformer_layers.{i}.norm2.linear.weight"] = state_dict.pop(
172
+ f"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight"
173
+ )
174
+
175
+ state_dict[f"transformer_layers.{i}.attn2.to_q.weight"] = state_dict.pop(
176
+ f"transformer_layers.{i}.crossattention.query.weight"
177
+ )
178
+ state_dict[f"transformer_layers.{i}.attn2.to_k.weight"] = state_dict.pop(
179
+ f"transformer_layers.{i}.crossattention.key.weight"
180
+ )
181
+ state_dict[f"transformer_layers.{i}.attn2.to_v.weight"] = state_dict.pop(
182
+ f"transformer_layers.{i}.crossattention.value.weight"
183
+ )
184
+ state_dict[f"transformer_layers.{i}.attn2.to_out.0.weight"] = state_dict.pop(
185
+ f"transformer_layers.{i}.crossattention.out.weight"
186
+ )
187
+
188
+ state_dict[f"transformer_layers.{i}.norm3.norm.weight"] = state_dict.pop(
189
+ f"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight"
190
+ )
191
+ state_dict[f"transformer_layers.{i}.norm3.linear.weight"] = state_dict.pop(
192
+ f"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight"
193
+ )
194
+
195
+ wi_0_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_0.weight")
196
+ wi_1_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_1.weight")
197
+ proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0)
198
+ state_dict[f"transformer_layers.{i}.ff.net.0.proj.weight"] = proj_weight
199
+
200
+ state_dict[f"transformer_layers.{i}.ff.net.2.weight"] = state_dict.pop(f"transformer_layers.{i}.ffn.wo.weight")
201
+
202
+ if force_down_up_sample:
203
+ state_dict["down_block.downsample.norm.weight"] = state_dict.pop("down_blocks.0.downsample.0.norm.weight")
204
+ state_dict["down_block.downsample.conv.weight"] = state_dict.pop("down_blocks.0.downsample.1.weight")
205
+
206
+ state_dict["up_block.upsample.norm.weight"] = state_dict.pop("up_blocks.0.upsample.0.norm.weight")
207
+ state_dict["up_block.upsample.conv.weight"] = state_dict.pop("up_blocks.0.upsample.1.weight")
208
+
209
+ state_dict["mlm_layer.layer_norm.weight"] = state_dict.pop("mlm_layer.layer_norm.norm.weight")
210
+
211
+ for i in range(3):
212
+ state_dict[f"down_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
213
+ f"down_blocks.0.res_blocks.{i}.norm.norm.weight"
214
+ )
215
+ state_dict[f"down_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
216
+ f"down_blocks.0.res_blocks.{i}.channelwise.0.weight"
217
+ )
218
+ state_dict[f"down_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
219
+ f"down_blocks.0.res_blocks.{i}.channelwise.2.gamma"
220
+ )
221
+ state_dict[f"down_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
222
+ f"down_blocks.0.res_blocks.{i}.channelwise.2.beta"
223
+ )
224
+ state_dict[f"down_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
225
+ f"down_blocks.0.res_blocks.{i}.channelwise.4.weight"
226
+ )
227
+ state_dict[f"down_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
228
+ f"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
229
+ )
230
+
231
+ state_dict[f"down_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
232
+ f"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
233
+ )
234
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
235
+ f"down_blocks.0.attention_blocks.{i}.attention.query.weight"
236
+ )
237
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
238
+ f"down_blocks.0.attention_blocks.{i}.attention.key.weight"
239
+ )
240
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
241
+ f"down_blocks.0.attention_blocks.{i}.attention.value.weight"
242
+ )
243
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
244
+ f"down_blocks.0.attention_blocks.{i}.attention.out.weight"
245
+ )
246
+
247
+ state_dict[f"down_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
248
+ f"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
249
+ )
250
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
251
+ f"down_blocks.0.attention_blocks.{i}.crossattention.query.weight"
252
+ )
253
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
254
+ f"down_blocks.0.attention_blocks.{i}.crossattention.key.weight"
255
+ )
256
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
257
+ f"down_blocks.0.attention_blocks.{i}.crossattention.value.weight"
258
+ )
259
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
260
+ f"down_blocks.0.attention_blocks.{i}.crossattention.out.weight"
261
+ )
262
+
263
+ state_dict[f"up_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
264
+ f"up_blocks.0.res_blocks.{i}.norm.norm.weight"
265
+ )
266
+ state_dict[f"up_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
267
+ f"up_blocks.0.res_blocks.{i}.channelwise.0.weight"
268
+ )
269
+ state_dict[f"up_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
270
+ f"up_blocks.0.res_blocks.{i}.channelwise.2.gamma"
271
+ )
272
+ state_dict[f"up_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
273
+ f"up_blocks.0.res_blocks.{i}.channelwise.2.beta"
274
+ )
275
+ state_dict[f"up_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
276
+ f"up_blocks.0.res_blocks.{i}.channelwise.4.weight"
277
+ )
278
+ state_dict[f"up_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
279
+ f"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
280
+ )
281
+
282
+ state_dict[f"up_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
283
+ f"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
284
+ )
285
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
286
+ f"up_blocks.0.attention_blocks.{i}.attention.query.weight"
287
+ )
288
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
289
+ f"up_blocks.0.attention_blocks.{i}.attention.key.weight"
290
+ )
291
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
292
+ f"up_blocks.0.attention_blocks.{i}.attention.value.weight"
293
+ )
294
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
295
+ f"up_blocks.0.attention_blocks.{i}.attention.out.weight"
296
+ )
297
+
298
+ state_dict[f"up_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
299
+ f"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
300
+ )
301
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
302
+ f"up_blocks.0.attention_blocks.{i}.crossattention.query.weight"
303
+ )
304
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
305
+ f"up_blocks.0.attention_blocks.{i}.crossattention.key.weight"
306
+ )
307
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
308
+ f"up_blocks.0.attention_blocks.{i}.crossattention.value.weight"
309
+ )
310
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
311
+ f"up_blocks.0.attention_blocks.{i}.crossattention.out.weight"
312
+ )
313
+
314
+ for key in list(state_dict.keys()):
315
+ if key.startswith("up_blocks.0"):
316
+ key_ = "up_block." + ".".join(key.split(".")[2:])
317
+ state_dict[key_] = state_dict.pop(key)
318
+
319
+ if key.startswith("down_blocks.0"):
320
+ key_ = "down_block." + ".".join(key.split(".")[2:])
321
+ state_dict[key_] = state_dict.pop(key)
322
+
323
+ new_transformer.load_state_dict(state_dict)
324
+
325
+ input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device)
326
+ encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device)
327
+ cond_embeds = torch.randn((1, 768), device=old_transformer.device)
328
+ micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device)
329
+
330
+ old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds)
331
+ old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2)
332
+
333
+ new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds)
334
+
335
+ # NOTE: these differences are solely due to using the geglu block that has a single linear layer of
336
+ # double output dimension instead of two different linear layers
337
+ max_diff = (old_out - new_out).abs().max()
338
+ total_diff = (old_out - new_out).abs().sum()
339
+ print(f"Transformer max_diff: {max_diff} total_diff: {total_diff}")
340
+ assert max_diff < 0.01
341
+ assert total_diff < 1500
342
+
343
+ return new_transformer
344
+
345
+
346
+ def make_vqvae(old_vae):
347
+ new_vae = VQModel(
348
+ act_fn="silu",
349
+ block_out_channels=[128, 256, 256, 512, 768],
350
+ down_block_types=[
351
+ "DownEncoderBlock2D",
352
+ "DownEncoderBlock2D",
353
+ "DownEncoderBlock2D",
354
+ "DownEncoderBlock2D",
355
+ "DownEncoderBlock2D",
356
+ ],
357
+ in_channels=3,
358
+ latent_channels=64,
359
+ layers_per_block=2,
360
+ norm_num_groups=32,
361
+ num_vq_embeddings=8192,
362
+ out_channels=3,
363
+ sample_size=32,
364
+ up_block_types=[
365
+ "UpDecoderBlock2D",
366
+ "UpDecoderBlock2D",
367
+ "UpDecoderBlock2D",
368
+ "UpDecoderBlock2D",
369
+ "UpDecoderBlock2D",
370
+ ],
371
+ mid_block_add_attention=False,
372
+ lookup_from_codebook=True,
373
+ )
374
+ new_vae.to(device)
375
+
376
+ # fmt: off
377
+
378
+ new_state_dict = {}
379
+
380
+ old_state_dict = old_vae.state_dict()
381
+
382
+ new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight")
383
+ new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias")
384
+
385
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0")
386
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1")
387
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2")
388
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3")
389
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4")
390
+
391
+ new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight")
392
+ new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias")
393
+ new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight")
394
+ new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias")
395
+ new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight")
396
+ new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias")
397
+ new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight")
398
+ new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias")
399
+ new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight")
400
+ new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias")
401
+ new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight")
402
+ new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias")
403
+ new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight")
404
+ new_state_dict["encoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("encoder.mid.block_2.norm2.bias")
405
+ new_state_dict["encoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("encoder.mid.block_2.conv2.weight")
406
+ new_state_dict["encoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("encoder.mid.block_2.conv2.bias")
407
+ new_state_dict["encoder.conv_norm_out.weight"] = old_state_dict.pop("encoder.norm_out.weight")
408
+ new_state_dict["encoder.conv_norm_out.bias"] = old_state_dict.pop("encoder.norm_out.bias")
409
+ new_state_dict["encoder.conv_out.weight"] = old_state_dict.pop("encoder.conv_out.weight")
410
+ new_state_dict["encoder.conv_out.bias"] = old_state_dict.pop("encoder.conv_out.bias")
411
+ new_state_dict["quant_conv.weight"] = old_state_dict.pop("quant_conv.weight")
412
+ new_state_dict["quant_conv.bias"] = old_state_dict.pop("quant_conv.bias")
413
+ new_state_dict["quantize.embedding.weight"] = old_state_dict.pop("quantize.embedding.weight")
414
+ new_state_dict["post_quant_conv.weight"] = old_state_dict.pop("post_quant_conv.weight")
415
+ new_state_dict["post_quant_conv.bias"] = old_state_dict.pop("post_quant_conv.bias")
416
+ new_state_dict["decoder.conv_in.weight"] = old_state_dict.pop("decoder.conv_in.weight")
417
+ new_state_dict["decoder.conv_in.bias"] = old_state_dict.pop("decoder.conv_in.bias")
418
+ new_state_dict["decoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("decoder.mid.block_1.norm1.weight")
419
+ new_state_dict["decoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("decoder.mid.block_1.norm1.bias")
420
+ new_state_dict["decoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("decoder.mid.block_1.conv1.weight")
421
+ new_state_dict["decoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("decoder.mid.block_1.conv1.bias")
422
+ new_state_dict["decoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("decoder.mid.block_1.norm2.weight")
423
+ new_state_dict["decoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("decoder.mid.block_1.norm2.bias")
424
+ new_state_dict["decoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("decoder.mid.block_1.conv2.weight")
425
+ new_state_dict["decoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("decoder.mid.block_1.conv2.bias")
426
+ new_state_dict["decoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("decoder.mid.block_2.norm1.weight")
427
+ new_state_dict["decoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("decoder.mid.block_2.norm1.bias")
428
+ new_state_dict["decoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("decoder.mid.block_2.conv1.weight")
429
+ new_state_dict["decoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("decoder.mid.block_2.conv1.bias")
430
+ new_state_dict["decoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("decoder.mid.block_2.norm2.weight")
431
+ new_state_dict["decoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("decoder.mid.block_2.norm2.bias")
432
+ new_state_dict["decoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("decoder.mid.block_2.conv2.weight")
433
+ new_state_dict["decoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("decoder.mid.block_2.conv2.bias")
434
+
435
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.0", new_state_dict, "decoder.up_blocks.4")
436
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.1", new_state_dict, "decoder.up_blocks.3")
437
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.2", new_state_dict, "decoder.up_blocks.2")
438
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.3", new_state_dict, "decoder.up_blocks.1")
439
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.4", new_state_dict, "decoder.up_blocks.0")
440
+
441
+ new_state_dict["decoder.conv_norm_out.weight"] = old_state_dict.pop("decoder.norm_out.weight")
442
+ new_state_dict["decoder.conv_norm_out.bias"] = old_state_dict.pop("decoder.norm_out.bias")
443
+ new_state_dict["decoder.conv_out.weight"] = old_state_dict.pop("decoder.conv_out.weight")
444
+ new_state_dict["decoder.conv_out.bias"] = old_state_dict.pop("decoder.conv_out.bias")
445
+
446
+ # fmt: on
447
+
448
+ assert len(old_state_dict.keys()) == 0
449
+
450
+ new_vae.load_state_dict(new_state_dict)
451
+
452
+ input = torch.randn((1, 3, 512, 512), device=device)
453
+ input = input.clamp(-1, 1)
454
+
455
+ old_encoder_output = old_vae.quant_conv(old_vae.encoder(input))
456
+ new_encoder_output = new_vae.quant_conv(new_vae.encoder(input))
457
+ assert (old_encoder_output == new_encoder_output).all()
458
+
459
+ old_decoder_output = old_vae.decoder(old_vae.post_quant_conv(old_encoder_output))
460
+ new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output))
461
+
462
+ # assert (old_decoder_output == new_decoder_output).all()
463
+ print("kipping vae decoder equivalence check")
464
+ print(f"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}")
465
+
466
+ old_output = old_vae(input)[0]
467
+ new_output = new_vae(input)[0]
468
+
469
+ # assert (old_output == new_output).all()
470
+ print("skipping full vae equivalence check")
471
+ print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
472
+
473
+ return new_vae
474
+
475
+
476
+ def convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to):
477
+ # fmt: off
478
+
479
+ new_state_dict[f"{prefix_to}.resnets.0.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.weight")
480
+ new_state_dict[f"{prefix_to}.resnets.0.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.bias")
481
+ new_state_dict[f"{prefix_to}.resnets.0.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.weight")
482
+ new_state_dict[f"{prefix_to}.resnets.0.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.bias")
483
+ new_state_dict[f"{prefix_to}.resnets.0.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.weight")
484
+ new_state_dict[f"{prefix_to}.resnets.0.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.bias")
485
+ new_state_dict[f"{prefix_to}.resnets.0.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.weight")
486
+ new_state_dict[f"{prefix_to}.resnets.0.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.bias")
487
+
488
+ if f"{prefix_from}.block.0.nin_shortcut.weight" in old_state_dict:
489
+ new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.weight")
490
+ new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.bias")
491
+
492
+ new_state_dict[f"{prefix_to}.resnets.1.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.weight")
493
+ new_state_dict[f"{prefix_to}.resnets.1.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.bias")
494
+ new_state_dict[f"{prefix_to}.resnets.1.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.weight")
495
+ new_state_dict[f"{prefix_to}.resnets.1.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.bias")
496
+ new_state_dict[f"{prefix_to}.resnets.1.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.weight")
497
+ new_state_dict[f"{prefix_to}.resnets.1.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.bias")
498
+ new_state_dict[f"{prefix_to}.resnets.1.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.weight")
499
+ new_state_dict[f"{prefix_to}.resnets.1.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.bias")
500
+
501
+ if f"{prefix_from}.downsample.conv.weight" in old_state_dict:
502
+ new_state_dict[f"{prefix_to}.downsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.weight")
503
+ new_state_dict[f"{prefix_to}.downsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.bias")
504
+
505
+ if f"{prefix_from}.upsample.conv.weight" in old_state_dict:
506
+ new_state_dict[f"{prefix_to}.upsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.weight")
507
+ new_state_dict[f"{prefix_to}.upsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.bias")
508
+
509
+ if f"{prefix_from}.block.2.norm1.weight" in old_state_dict:
510
+ new_state_dict[f"{prefix_to}.resnets.2.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.weight")
511
+ new_state_dict[f"{prefix_to}.resnets.2.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.bias")
512
+ new_state_dict[f"{prefix_to}.resnets.2.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.weight")
513
+ new_state_dict[f"{prefix_to}.resnets.2.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.bias")
514
+ new_state_dict[f"{prefix_to}.resnets.2.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.weight")
515
+ new_state_dict[f"{prefix_to}.resnets.2.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.bias")
516
+ new_state_dict[f"{prefix_to}.resnets.2.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.weight")
517
+ new_state_dict[f"{prefix_to}.resnets.2.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.bias")
518
+
519
+ # fmt: on
520
+
521
+
522
+ if __name__ == "__main__":
523
+ main()
diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from safetensors.torch import load_file
5
+
6
+ from diffusers import MotionAdapter
7
+
8
+
9
+ def convert_motion_module(original_state_dict):
10
+ converted_state_dict = {}
11
+ for k, v in original_state_dict.items():
12
+ if "pos_encoder" in k:
13
+ continue
14
+
15
+ else:
16
+ converted_state_dict[
17
+ k.replace(".norms.0", ".norm1")
18
+ .replace(".norms.1", ".norm2")
19
+ .replace(".ff_norm", ".norm3")
20
+ .replace(".attention_blocks.0", ".attn1")
21
+ .replace(".attention_blocks.1", ".attn2")
22
+ .replace(".temporal_transformer", "")
23
+ ] = v
24
+
25
+ return converted_state_dict
26
+
27
+
28
+ def get_args():
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument("--ckpt_path", type=str, required=True)
31
+ parser.add_argument("--output_path", type=str, required=True)
32
+ parser.add_argument("--use_motion_mid_block", action="store_true")
33
+ parser.add_argument("--motion_max_seq_length", type=int, default=32)
34
+ parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int)
35
+ parser.add_argument("--save_fp16", action="store_true")
36
+
37
+ return parser.parse_args()
38
+
39
+
40
+ if __name__ == "__main__":
41
+ args = get_args()
42
+
43
+ if args.ckpt_path.endswith(".safetensors"):
44
+ state_dict = load_file(args.ckpt_path)
45
+ else:
46
+ state_dict = torch.load(args.ckpt_path, map_location="cpu")
47
+
48
+ if "state_dict" in state_dict.keys():
49
+ state_dict = state_dict["state_dict"]
50
+
51
+ conv_state_dict = convert_motion_module(state_dict)
52
+ adapter = MotionAdapter(
53
+ block_out_channels=args.block_out_channels,
54
+ use_motion_mid_block=args.use_motion_mid_block,
55
+ motion_max_seq_length=args.motion_max_seq_length,
56
+ )
57
+ # skip loading position embeddings
58
+ adapter.load_state_dict(conv_state_dict, strict=False)
59
+ adapter.save_pretrained(args.output_path)
60
+
61
+ if args.save_fp16:
62
+ adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant="fp16")
diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from typing import Dict
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from diffusers import SparseControlNetModel
8
+
9
+
10
+ KEYS_RENAME_MAPPING = {
11
+ ".attention_blocks.0": ".attn1",
12
+ ".attention_blocks.1": ".attn2",
13
+ ".attn1.pos_encoder": ".pos_embed",
14
+ ".ff_norm": ".norm3",
15
+ ".norms.0": ".norm1",
16
+ ".norms.1": ".norm2",
17
+ ".temporal_transformer": "",
18
+ }
19
+
20
+
21
+ def convert(original_state_dict: Dict[str, nn.Module]) -> Dict[str, nn.Module]:
22
+ converted_state_dict = {}
23
+
24
+ for key in list(original_state_dict.keys()):
25
+ renamed_key = key
26
+ for new_name, old_name in KEYS_RENAME_MAPPING.items():
27
+ renamed_key = renamed_key.replace(new_name, old_name)
28
+ converted_state_dict[renamed_key] = original_state_dict.pop(key)
29
+
30
+ return converted_state_dict
31
+
32
+
33
+ def get_args():
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
36
+ parser.add_argument("--output_path", type=str, required=True, help="Path to output directory")
37
+ parser.add_argument(
38
+ "--max_motion_seq_length",
39
+ type=int,
40
+ default=32,
41
+ help="Max motion sequence length supported by the motion adapter",
42
+ )
43
+ parser.add_argument(
44
+ "--conditioning_channels", type=int, default=4, help="Number of channels in conditioning input to controlnet"
45
+ )
46
+ parser.add_argument(
47
+ "--use_simplified_condition_embedding",
48
+ action="store_true",
49
+ default=False,
50
+ help="Whether or not to use simplified condition embedding. When `conditioning_channels==4` i.e. latent inputs, set this to `True`. When `conditioning_channels==3` i.e. image inputs, set this to `False`",
51
+ )
52
+ parser.add_argument(
53
+ "--save_fp16",
54
+ action="store_true",
55
+ default=False,
56
+ help="Whether or not to save model in fp16 precision along with fp32",
57
+ )
58
+ parser.add_argument(
59
+ "--push_to_hub", action="store_true", default=False, help="Whether or not to push saved model to the HF hub"
60
+ )
61
+ return parser.parse_args()
62
+
63
+
64
+ if __name__ == "__main__":
65
+ args = get_args()
66
+
67
+ state_dict = torch.load(args.ckpt_path, map_location="cpu")
68
+ if "state_dict" in state_dict.keys():
69
+ state_dict: dict = state_dict["state_dict"]
70
+
71
+ controlnet = SparseControlNetModel(
72
+ conditioning_channels=args.conditioning_channels,
73
+ motion_max_seq_length=args.max_motion_seq_length,
74
+ use_simplified_condition_embedding=args.use_simplified_condition_embedding,
75
+ )
76
+
77
+ state_dict = convert(state_dict)
78
+ controlnet.load_state_dict(state_dict, strict=True)
79
+
80
+ controlnet.save_pretrained(args.output_path, push_to_hub=args.push_to_hub)
81
+ if args.save_fp16:
82
+ controlnet = controlnet.to(dtype=torch.float16)
83
+ controlnet.save_pretrained(args.output_path, variant="fp16", push_to_hub=args.push_to_hub)
diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Any, Dict, Literal
5
+
6
+ import torch
7
+
8
+ from diffusers import AsymmetricAutoencoderKL
9
+
10
+
11
+ ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = {
12
+ "in_channels": 3,
13
+ "out_channels": 3,
14
+ "down_block_types": [
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D",
17
+ "DownEncoderBlock2D",
18
+ "DownEncoderBlock2D",
19
+ ],
20
+ "down_block_out_channels": [128, 256, 512, 512],
21
+ "layers_per_down_block": 2,
22
+ "up_block_types": [
23
+ "UpDecoderBlock2D",
24
+ "UpDecoderBlock2D",
25
+ "UpDecoderBlock2D",
26
+ "UpDecoderBlock2D",
27
+ ],
28
+ "up_block_out_channels": [192, 384, 768, 768],
29
+ "layers_per_up_block": 3,
30
+ "act_fn": "silu",
31
+ "latent_channels": 4,
32
+ "norm_num_groups": 32,
33
+ "sample_size": 256,
34
+ "scaling_factor": 0.18215,
35
+ }
36
+
37
+ ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = {
38
+ "in_channels": 3,
39
+ "out_channels": 3,
40
+ "down_block_types": [
41
+ "DownEncoderBlock2D",
42
+ "DownEncoderBlock2D",
43
+ "DownEncoderBlock2D",
44
+ "DownEncoderBlock2D",
45
+ ],
46
+ "down_block_out_channels": [128, 256, 512, 512],
47
+ "layers_per_down_block": 2,
48
+ "up_block_types": [
49
+ "UpDecoderBlock2D",
50
+ "UpDecoderBlock2D",
51
+ "UpDecoderBlock2D",
52
+ "UpDecoderBlock2D",
53
+ ],
54
+ "up_block_out_channels": [256, 512, 1024, 1024],
55
+ "layers_per_up_block": 5,
56
+ "act_fn": "silu",
57
+ "latent_channels": 4,
58
+ "norm_num_groups": 32,
59
+ "sample_size": 256,
60
+ "scaling_factor": 0.18215,
61
+ }
62
+
63
+
64
+ def convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
65
+ converted_state_dict = {}
66
+ for k, v in original_state_dict.items():
67
+ if k.startswith("encoder."):
68
+ converted_state_dict[
69
+ k.replace("encoder.down.", "encoder.down_blocks.")
70
+ .replace("encoder.mid.", "encoder.mid_block.")
71
+ .replace("encoder.norm_out.", "encoder.conv_norm_out.")
72
+ .replace(".downsample.", ".downsamplers.0.")
73
+ .replace(".nin_shortcut.", ".conv_shortcut.")
74
+ .replace(".block.", ".resnets.")
75
+ .replace(".block_1.", ".resnets.0.")
76
+ .replace(".block_2.", ".resnets.1.")
77
+ .replace(".attn_1.k.", ".attentions.0.to_k.")
78
+ .replace(".attn_1.q.", ".attentions.0.to_q.")
79
+ .replace(".attn_1.v.", ".attentions.0.to_v.")
80
+ .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
81
+ .replace(".attn_1.norm.", ".attentions.0.group_norm.")
82
+ ] = v
83
+ elif k.startswith("decoder.") and "up_layers" not in k:
84
+ converted_state_dict[
85
+ k.replace("decoder.encoder.", "decoder.condition_encoder.")
86
+ .replace(".norm_out.", ".conv_norm_out.")
87
+ .replace(".up.0.", ".up_blocks.3.")
88
+ .replace(".up.1.", ".up_blocks.2.")
89
+ .replace(".up.2.", ".up_blocks.1.")
90
+ .replace(".up.3.", ".up_blocks.0.")
91
+ .replace(".block.", ".resnets.")
92
+ .replace("mid", "mid_block")
93
+ .replace(".0.upsample.", ".0.upsamplers.0.")
94
+ .replace(".1.upsample.", ".1.upsamplers.0.")
95
+ .replace(".2.upsample.", ".2.upsamplers.0.")
96
+ .replace(".nin_shortcut.", ".conv_shortcut.")
97
+ .replace(".block_1.", ".resnets.0.")
98
+ .replace(".block_2.", ".resnets.1.")
99
+ .replace(".attn_1.k.", ".attentions.0.to_k.")
100
+ .replace(".attn_1.q.", ".attentions.0.to_q.")
101
+ .replace(".attn_1.v.", ".attentions.0.to_v.")
102
+ .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
103
+ .replace(".attn_1.norm.", ".attentions.0.group_norm.")
104
+ ] = v
105
+ elif k.startswith("quant_conv."):
106
+ converted_state_dict[k] = v
107
+ elif k.startswith("post_quant_conv."):
108
+ converted_state_dict[k] = v
109
+ else:
110
+ print(f" skipping key `{k}`")
111
+ # fix weights shape
112
+ for k, v in converted_state_dict.items():
113
+ if (
114
+ (k.startswith("encoder.mid_block.attentions.0") or k.startswith("decoder.mid_block.attentions.0"))
115
+ and k.endswith("weight")
116
+ and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
117
+ ):
118
+ converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0]
119
+
120
+ return converted_state_dict
121
+
122
+
123
+ def get_asymmetric_autoencoder_kl_from_original_checkpoint(
124
+ scale: Literal["1.5", "2"], original_checkpoint_path: str, map_location: torch.device
125
+ ) -> AsymmetricAutoencoderKL:
126
+ print("Loading original state_dict")
127
+ original_state_dict = torch.load(original_checkpoint_path, map_location=map_location)
128
+ original_state_dict = original_state_dict["state_dict"]
129
+ print("Converting state_dict")
130
+ converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict)
131
+ kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == "1.5" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG
132
+ print("Initializing AsymmetricAutoencoderKL model")
133
+ asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs)
134
+ print("Loading weight from converted state_dict")
135
+ asymmetric_autoencoder_kl.load_state_dict(converted_state_dict)
136
+ asymmetric_autoencoder_kl.eval()
137
+ print("AsymmetricAutoencoderKL successfully initialized")
138
+ return asymmetric_autoencoder_kl
139
+
140
+
141
+ if __name__ == "__main__":
142
+ start = time.time()
143
+ parser = argparse.ArgumentParser()
144
+ parser.add_argument(
145
+ "--scale",
146
+ default=None,
147
+ type=str,
148
+ required=True,
149
+ help="Asymmetric VQGAN scale: `1.5` or `2`",
150
+ )
151
+ parser.add_argument(
152
+ "--original_checkpoint_path",
153
+ default=None,
154
+ type=str,
155
+ required=True,
156
+ help="Path to the original Asymmetric VQGAN checkpoint",
157
+ )
158
+ parser.add_argument(
159
+ "--output_path",
160
+ default=None,
161
+ type=str,
162
+ required=True,
163
+ help="Path to save pretrained AsymmetricAutoencoderKL model",
164
+ )
165
+ parser.add_argument(
166
+ "--map_location",
167
+ default="cpu",
168
+ type=str,
169
+ required=False,
170
+ help="The device passed to `map_location` when loading the checkpoint",
171
+ )
172
+ args = parser.parse_args()
173
+
174
+ assert args.scale in ["1.5", "2"], f"{args.scale} should be `1.5` of `2`"
175
+ assert Path(args.original_checkpoint_path).is_file()
176
+
177
+ asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint(
178
+ scale=args.scale,
179
+ original_checkpoint_path=args.original_checkpoint_path,
180
+ map_location=torch.device(args.map_location),
181
+ )
182
+ print("Saving pretrained AsymmetricAutoencoderKL")
183
+ asymmetric_autoencoder_kl.save_pretrained(args.output_path)
184
+ print(f"Done in {time.time() - start:.2f} seconds")
diffusers/scripts/convert_aura_flow_to_diffusers.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel
7
+
8
+
9
+ def load_original_state_dict(args):
10
+ model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin")
11
+ state_dict = torch.load(model_pt, map_location="cpu")
12
+ return state_dict
13
+
14
+
15
+ def calculate_layers(state_dict_keys, key_prefix):
16
+ dit_layers = set()
17
+ for k in state_dict_keys:
18
+ if key_prefix in k:
19
+ dit_layers.add(int(k.split(".")[2]))
20
+ print(f"{key_prefix}: {len(dit_layers)}")
21
+ return len(dit_layers)
22
+
23
+
24
+ # similar to SD3 but only for the last norm layer
25
+ def swap_scale_shift(weight, dim):
26
+ shift, scale = weight.chunk(2, dim=0)
27
+ new_weight = torch.cat([scale, shift], dim=0)
28
+ return new_weight
29
+
30
+
31
+ def convert_transformer(state_dict):
32
+ converted_state_dict = {}
33
+ state_dict_keys = list(state_dict.keys())
34
+
35
+ converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens")
36
+ converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding")
37
+ converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight")
38
+ converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias")
39
+
40
+ converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight")
41
+ converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias")
42
+ converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight")
43
+ converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias")
44
+
45
+ converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight")
46
+
47
+ mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
48
+ single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
49
+
50
+ # MMDiT blocks 🎸.
51
+ for i in range(mmdit_layers):
52
+ # feed-forward
53
+ path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
54
+ weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
55
+ for orig_k, diffuser_k in path_mapping.items():
56
+ for k, v in weight_mapping.items():
57
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop(
58
+ f"model.double_layers.{i}.{orig_k}.{k}.weight"
59
+ )
60
+
61
+ # norms
62
+ path_mapping = {"modX": "norm1", "modC": "norm1_context"}
63
+ for orig_k, diffuser_k in path_mapping.items():
64
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop(
65
+ f"model.double_layers.{i}.{orig_k}.1.weight"
66
+ )
67
+
68
+ # attns
69
+ x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
70
+ context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
71
+ for attn_mapping in [x_attn_mapping, context_attn_mapping]:
72
+ for k, v in attn_mapping.items():
73
+ converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
74
+ f"model.double_layers.{i}.attn.{k}.weight"
75
+ )
76
+
77
+ # Single-DiT blocks.
78
+ for i in range(single_dit_layers):
79
+ # feed-forward
80
+ mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
81
+ for k, v in mapping.items():
82
+ converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop(
83
+ f"model.single_layers.{i}.mlp.{k}.weight"
84
+ )
85
+
86
+ # norms
87
+ converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop(
88
+ f"model.single_layers.{i}.modCX.1.weight"
89
+ )
90
+
91
+ # attns
92
+ x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
93
+ for k, v in x_attn_mapping.items():
94
+ converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
95
+ f"model.single_layers.{i}.attn.{k}.weight"
96
+ )
97
+
98
+ # Final blocks.
99
+ converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight")
100
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None)
101
+
102
+ return converted_state_dict
103
+
104
+
105
+ @torch.no_grad()
106
+ def populate_state_dict(args):
107
+ original_state_dict = load_original_state_dict(args)
108
+ state_dict_keys = list(original_state_dict.keys())
109
+ mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
110
+ single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
111
+
112
+ converted_state_dict = convert_transformer(original_state_dict)
113
+ model_diffusers = AuraFlowTransformer2DModel(
114
+ num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers
115
+ )
116
+ model_diffusers.load_state_dict(converted_state_dict, strict=True)
117
+
118
+ return model_diffusers
119
+
120
+
121
+ if __name__ == "__main__":
122
+ parser = argparse.ArgumentParser()
123
+ parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str)
124
+ parser.add_argument("--dump_path", default="aura-flow", type=str)
125
+ parser.add_argument("--hub_id", default=None, type=str)
126
+ args = parser.parse_args()
127
+
128
+ model_diffusers = populate_state_dict(args)
129
+ model_diffusers.save_pretrained(args.dump_path)
130
+ if args.hub_id is not None:
131
+ model_diffusers.push_to_hub(args.hub_id)
diffusers/scripts/convert_blipdiffusion_to_diffusers.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
3
+ """
4
+
5
+ import argparse
6
+ import os
7
+ import tempfile
8
+
9
+ import torch
10
+ from lavis.models import load_model_and_preprocess
11
+ from transformers import CLIPTokenizer
12
+ from transformers.models.blip_2.configuration_blip_2 import Blip2Config
13
+
14
+ from diffusers import (
15
+ AutoencoderKL,
16
+ PNDMScheduler,
17
+ UNet2DConditionModel,
18
+ )
19
+ from diffusers.pipelines import BlipDiffusionPipeline
20
+ from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
21
+ from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
22
+ from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
23
+
24
+
25
+ BLIP2_CONFIG = {
26
+ "vision_config": {
27
+ "hidden_size": 1024,
28
+ "num_hidden_layers": 23,
29
+ "num_attention_heads": 16,
30
+ "image_size": 224,
31
+ "patch_size": 14,
32
+ "intermediate_size": 4096,
33
+ "hidden_act": "quick_gelu",
34
+ },
35
+ "qformer_config": {
36
+ "cross_attention_frequency": 1,
37
+ "encoder_hidden_size": 1024,
38
+ "vocab_size": 30523,
39
+ },
40
+ "num_query_tokens": 16,
41
+ }
42
+ blip2config = Blip2Config(**BLIP2_CONFIG)
43
+
44
+
45
+ def qformer_model_from_original_config():
46
+ qformer = Blip2QFormerModel(blip2config)
47
+ return qformer
48
+
49
+
50
+ def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
51
+ embeddings = {}
52
+ embeddings.update(
53
+ {
54
+ f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
55
+ f"{original_embeddings_prefix}.word_embeddings.weight"
56
+ ]
57
+ }
58
+ )
59
+ embeddings.update(
60
+ {
61
+ f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
62
+ f"{original_embeddings_prefix}.position_embeddings.weight"
63
+ ]
64
+ }
65
+ )
66
+ embeddings.update(
67
+ {f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
68
+ )
69
+ embeddings.update(
70
+ {f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
71
+ )
72
+ return embeddings
73
+
74
+
75
+ def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
76
+ proj_layer = {}
77
+ proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
78
+ proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
79
+ proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
80
+ proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
81
+ proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
82
+ proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
83
+ return proj_layer
84
+
85
+
86
+ def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
87
+ attention = {}
88
+ attention.update(
89
+ {
90
+ f"{diffuser_attention_prefix}.attention.query.weight": model[
91
+ f"{original_attention_prefix}.self.query.weight"
92
+ ]
93
+ }
94
+ )
95
+ attention.update(
96
+ {f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
97
+ )
98
+ attention.update(
99
+ {f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
100
+ )
101
+ attention.update(
102
+ {f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
103
+ )
104
+ attention.update(
105
+ {
106
+ f"{diffuser_attention_prefix}.attention.value.weight": model[
107
+ f"{original_attention_prefix}.self.value.weight"
108
+ ]
109
+ }
110
+ )
111
+ attention.update(
112
+ {f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
113
+ )
114
+ attention.update(
115
+ {f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
116
+ )
117
+ attention.update(
118
+ {f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
119
+ )
120
+ attention.update(
121
+ {
122
+ f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
123
+ f"{original_attention_prefix}.output.LayerNorm.weight"
124
+ ]
125
+ }
126
+ )
127
+ attention.update(
128
+ {
129
+ f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
130
+ f"{original_attention_prefix}.output.LayerNorm.bias"
131
+ ]
132
+ }
133
+ )
134
+ return attention
135
+
136
+
137
+ def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
138
+ output_layers = {}
139
+ output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
140
+ output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
141
+ output_layers.update(
142
+ {f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
143
+ )
144
+ output_layers.update(
145
+ {f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
146
+ )
147
+ return output_layers
148
+
149
+
150
+ def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
151
+ encoder = {}
152
+ for i in range(blip2config.qformer_config.num_hidden_layers):
153
+ encoder.update(
154
+ attention_from_original_checkpoint(
155
+ model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
156
+ )
157
+ )
158
+ encoder.update(
159
+ attention_from_original_checkpoint(
160
+ model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
161
+ )
162
+ )
163
+
164
+ encoder.update(
165
+ {
166
+ f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
167
+ f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
168
+ ]
169
+ }
170
+ )
171
+ encoder.update(
172
+ {
173
+ f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
174
+ f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
175
+ ]
176
+ }
177
+ )
178
+ encoder.update(
179
+ {
180
+ f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
181
+ f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
182
+ ]
183
+ }
184
+ )
185
+ encoder.update(
186
+ {
187
+ f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
188
+ f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
189
+ ]
190
+ }
191
+ )
192
+
193
+ encoder.update(
194
+ output_layers_from_original_checkpoint(
195
+ model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
196
+ )
197
+ )
198
+ encoder.update(
199
+ output_layers_from_original_checkpoint(
200
+ model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
201
+ )
202
+ )
203
+ return encoder
204
+
205
+
206
+ def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
207
+ visual_encoder_layer = {}
208
+
209
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
210
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
211
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
212
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
213
+ visual_encoder_layer.update(
214
+ {f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
215
+ )
216
+ visual_encoder_layer.update(
217
+ {f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
218
+ )
219
+ visual_encoder_layer.update(
220
+ {f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
221
+ )
222
+ visual_encoder_layer.update(
223
+ {f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
224
+ )
225
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
226
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
227
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
228
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
229
+
230
+ return visual_encoder_layer
231
+
232
+
233
+ def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
234
+ visual_encoder = {}
235
+
236
+ visual_encoder.update(
237
+ {
238
+ f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
239
+ .unsqueeze(0)
240
+ .unsqueeze(0)
241
+ }
242
+ )
243
+ visual_encoder.update(
244
+ {
245
+ f"{diffuser_prefix}.embeddings.position_embedding": model[
246
+ f"{original_prefix}.positional_embedding"
247
+ ].unsqueeze(0)
248
+ }
249
+ )
250
+ visual_encoder.update(
251
+ {f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
252
+ )
253
+ visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
254
+ visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
255
+
256
+ for i in range(blip2config.vision_config.num_hidden_layers):
257
+ visual_encoder.update(
258
+ visual_encoder_layer_from_original_checkpoint(
259
+ model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
260
+ )
261
+ )
262
+
263
+ visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
264
+ visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
265
+
266
+ return visual_encoder
267
+
268
+
269
+ def qformer_original_checkpoint_to_diffusers_checkpoint(model):
270
+ qformer_checkpoint = {}
271
+ qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
272
+ qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
273
+ qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
274
+ qformer_checkpoint.update(
275
+ encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
276
+ )
277
+ qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
278
+ return qformer_checkpoint
279
+
280
+
281
+ def get_qformer(model):
282
+ print("loading qformer")
283
+
284
+ qformer = qformer_model_from_original_config()
285
+ qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
286
+
287
+ load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
288
+
289
+ print("done loading qformer")
290
+ return qformer
291
+
292
+
293
+ def load_checkpoint_to_model(checkpoint, model):
294
+ with tempfile.NamedTemporaryFile(delete=False) as file:
295
+ torch.save(checkpoint, file.name)
296
+ del checkpoint
297
+ model.load_state_dict(torch.load(file.name), strict=False)
298
+
299
+ os.remove(file.name)
300
+
301
+
302
+ def save_blip_diffusion_model(model, args):
303
+ qformer = get_qformer(model)
304
+ qformer.eval()
305
+
306
+ text_encoder = ContextCLIPTextModel.from_pretrained(
307
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="text_encoder"
308
+ )
309
+ vae = AutoencoderKL.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae")
310
+ unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
311
+ vae.eval()
312
+ text_encoder.eval()
313
+ scheduler = PNDMScheduler(
314
+ beta_start=0.00085,
315
+ beta_end=0.012,
316
+ beta_schedule="scaled_linear",
317
+ set_alpha_to_one=False,
318
+ skip_prk_steps=True,
319
+ )
320
+ tokenizer = CLIPTokenizer.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="tokenizer")
321
+ image_processor = BlipImageProcessor()
322
+ blip_diffusion = BlipDiffusionPipeline(
323
+ tokenizer=tokenizer,
324
+ text_encoder=text_encoder,
325
+ vae=vae,
326
+ unet=unet,
327
+ scheduler=scheduler,
328
+ qformer=qformer,
329
+ image_processor=image_processor,
330
+ )
331
+ blip_diffusion.save_pretrained(args.checkpoint_path)
332
+
333
+
334
+ def main(args):
335
+ model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
336
+ save_blip_diffusion_model(model.state_dict(), args)
337
+
338
+
339
+ if __name__ == "__main__":
340
+ parser = argparse.ArgumentParser()
341
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
342
+ args = parser.parse_args()
343
+
344
+ main(args)
diffusers/scripts/convert_cogview3_to_diffusers.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert a CogView3 checkpoint to the Diffusers format.
3
+
4
+ This script converts a CogView3 checkpoint to the Diffusers format, which can then be used
5
+ with the Diffusers library.
6
+
7
+ Example usage:
8
+ python scripts/convert_cogview3_to_diffusers.py \
9
+ --transformer_checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
10
+ --vae_checkpoint_path 'your path/3plus_ae/imagekl_ch16.pt' \
11
+ --output_path "/raid/yiyi/cogview3_diffusers" \
12
+ --dtype "bf16"
13
+
14
+ Arguments:
15
+ --transformer_checkpoint_path: Path to Transformer state dict.
16
+ --vae_checkpoint_path: Path to VAE state dict.
17
+ --output_path: The path to save the converted model.
18
+ --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
19
+ --text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
20
+ --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
21
+
22
+ Default is "bf16" because CogView3 uses bfloat16 for Training.
23
+
24
+ Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
25
+ """
26
+
27
+ import argparse
28
+ from contextlib import nullcontext
29
+
30
+ import torch
31
+ from accelerate import init_empty_weights
32
+ from transformers import T5EncoderModel, T5Tokenizer
33
+
34
+ from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
35
+ from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
36
+ from diffusers.utils.import_utils import is_accelerate_available
37
+
38
+
39
+ CTX = init_empty_weights if is_accelerate_available() else nullcontext
40
+
41
+ TOKENIZER_MAX_LENGTH = 224
42
+
43
+ parser = argparse.ArgumentParser()
44
+ parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
45
+ parser.add_argument("--vae_checkpoint_path", default=None, type=str)
46
+ parser.add_argument("--output_path", required=True, type=str)
47
+ parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
48
+ parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
49
+ parser.add_argument("--dtype", type=str, default="bf16")
50
+
51
+ args = parser.parse_args()
52
+
53
+
54
+ # this is specific to `AdaLayerNormContinuous`:
55
+ # diffusers implementation split the linear projection into the scale, shift while CogView3 split it tino shift, scale
56
+ def swap_scale_shift(weight, dim):
57
+ shift, scale = weight.chunk(2, dim=0)
58
+ new_weight = torch.cat([scale, shift], dim=0)
59
+ return new_weight
60
+
61
+
62
+ def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
63
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")
64
+ original_state_dict = original_state_dict["module"]
65
+ original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
66
+
67
+ new_state_dict = {}
68
+
69
+ # Convert patch_embed
70
+ new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
71
+ new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
72
+ new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
73
+ new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
74
+
75
+ # Convert time_condition_embed
76
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
77
+ "time_embed.0.weight"
78
+ )
79
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
80
+ "time_embed.0.bias"
81
+ )
82
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
83
+ "time_embed.2.weight"
84
+ )
85
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
86
+ "time_embed.2.bias"
87
+ )
88
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
89
+ "label_emb.0.0.weight"
90
+ )
91
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
92
+ "label_emb.0.0.bias"
93
+ )
94
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
95
+ "label_emb.0.2.weight"
96
+ )
97
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
98
+ "label_emb.0.2.bias"
99
+ )
100
+
101
+ # Convert transformer blocks
102
+ for i in range(30):
103
+ block_prefix = f"transformer_blocks.{i}."
104
+ old_prefix = f"transformer.layers.{i}."
105
+ adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
106
+
107
+ new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
108
+ new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
109
+
110
+ qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
111
+ qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
112
+ q, k, v = qkv_weight.chunk(3, dim=0)
113
+ q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
114
+
115
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
116
+ new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
117
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
118
+ new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
119
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
120
+ new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
121
+
122
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
123
+ old_prefix + "attention.dense.weight"
124
+ )
125
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
126
+ old_prefix + "attention.dense.bias"
127
+ )
128
+
129
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
130
+ old_prefix + "mlp.dense_h_to_4h.weight"
131
+ )
132
+ new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
133
+ old_prefix + "mlp.dense_h_to_4h.bias"
134
+ )
135
+ new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
136
+ old_prefix + "mlp.dense_4h_to_h.weight"
137
+ )
138
+ new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
139
+
140
+ # Convert final norm and projection
141
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
142
+ original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
143
+ )
144
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
145
+ original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
146
+ )
147
+ new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
148
+ new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
149
+
150
+ return new_state_dict
151
+
152
+
153
+ def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
154
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
155
+ return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
156
+
157
+
158
+ def main(args):
159
+ if args.dtype == "fp16":
160
+ dtype = torch.float16
161
+ elif args.dtype == "bf16":
162
+ dtype = torch.bfloat16
163
+ elif args.dtype == "fp32":
164
+ dtype = torch.float32
165
+ else:
166
+ raise ValueError(f"Unsupported dtype: {args.dtype}")
167
+
168
+ transformer = None
169
+ vae = None
170
+
171
+ if args.transformer_checkpoint_path is not None:
172
+ converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(
173
+ args.transformer_checkpoint_path
174
+ )
175
+ transformer = CogView3PlusTransformer2DModel()
176
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
177
+ if dtype is not None:
178
+ # Original checkpoint data type will be preserved
179
+ transformer = transformer.to(dtype=dtype)
180
+
181
+ if args.vae_checkpoint_path is not None:
182
+ vae_config = {
183
+ "in_channels": 3,
184
+ "out_channels": 3,
185
+ "down_block_types": ("DownEncoderBlock2D",) * 4,
186
+ "up_block_types": ("UpDecoderBlock2D",) * 4,
187
+ "block_out_channels": (128, 512, 1024, 1024),
188
+ "layers_per_block": 3,
189
+ "act_fn": "silu",
190
+ "latent_channels": 16,
191
+ "norm_num_groups": 32,
192
+ "sample_size": 1024,
193
+ "scaling_factor": 1.0,
194
+ "force_upcast": True,
195
+ "use_quant_conv": False,
196
+ "use_post_quant_conv": False,
197
+ "mid_block_add_attention": False,
198
+ }
199
+ converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
200
+ vae = AutoencoderKL(**vae_config)
201
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
202
+ if dtype is not None:
203
+ vae = vae.to(dtype=dtype)
204
+
205
+ text_encoder_id = "google/t5-v1_1-xxl"
206
+ tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
207
+ text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
208
+
209
+ # Apparently, the conversion does not work anymore without this :shrug:
210
+ for param in text_encoder.parameters():
211
+ param.data = param.data.contiguous()
212
+
213
+ scheduler = CogVideoXDDIMScheduler.from_config(
214
+ {
215
+ "snr_shift_scale": 4.0,
216
+ "beta_end": 0.012,
217
+ "beta_schedule": "scaled_linear",
218
+ "beta_start": 0.00085,
219
+ "clip_sample": False,
220
+ "num_train_timesteps": 1000,
221
+ "prediction_type": "v_prediction",
222
+ "rescale_betas_zero_snr": True,
223
+ "set_alpha_to_one": True,
224
+ "timestep_spacing": "trailing",
225
+ }
226
+ )
227
+
228
+ pipe = CogView3PlusPipeline(
229
+ tokenizer=tokenizer,
230
+ text_encoder=text_encoder,
231
+ vae=vae,
232
+ transformer=transformer,
233
+ scheduler=scheduler,
234
+ )
235
+
236
+ # This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can
237
+ # save some memory used for model loading.
238
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
239
+
240
+
241
+ if __name__ == "__main__":
242
+ main(args)
diffusers/scripts/convert_cogview4_to_diffusers.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert a CogView4 checkpoint from SAT(https://github.com/THUDM/SwissArmyTransformer) to the Diffusers format.
3
+ (deprecated Since 2025-02-07 and will remove it in later CogView4 version)
4
+
5
+ This script converts a CogView4 checkpoint to the Diffusers format, which can then be used
6
+ with the Diffusers library.
7
+
8
+ Example usage:
9
+ python scripts/convert_cogview4_to_diffusers.py \
10
+ --transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \
11
+ --vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
12
+ --output_path "THUDM/CogView4-6B" \
13
+ --dtype "bf16"
14
+
15
+ Arguments:
16
+ --transformer_checkpoint_path: Path to Transformer state dict.
17
+ --vae_checkpoint_path: Path to VAE state dict.
18
+ --output_path: The path to save the converted model.
19
+ --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
20
+ --text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
21
+ --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
22
+
23
+ Default is "bf16" because CogView4 uses bfloat16 for Training.
24
+
25
+ Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
26
+ """
27
+
28
+ import argparse
29
+ from contextlib import nullcontext
30
+
31
+ import torch
32
+ from accelerate import init_empty_weights
33
+ from transformers import GlmForCausalLM, PreTrainedTokenizerFast
34
+
35
+ from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
36
+ from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
37
+ from diffusers.utils.import_utils import is_accelerate_available
38
+
39
+
40
+ CTX = init_empty_weights if is_accelerate_available() else nullcontext
41
+
42
+ parser = argparse.ArgumentParser()
43
+ parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
44
+ parser.add_argument("--vae_checkpoint_path", default=None, type=str)
45
+ parser.add_argument("--output_path", required=True, type=str)
46
+ parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
47
+ parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
48
+ parser.add_argument("--dtype", type=str, default="bf16")
49
+
50
+ args = parser.parse_args()
51
+
52
+
53
+ # this is specific to `AdaLayerNormContinuous`:
54
+ # diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
55
+ def swap_scale_shift(weight, dim):
56
+ """
57
+ Swap the scale and shift components in the weight tensor.
58
+
59
+ Args:
60
+ weight (torch.Tensor): The original weight tensor.
61
+ dim (int): The dimension along which to split.
62
+
63
+ Returns:
64
+ torch.Tensor: The modified weight tensor with scale and shift swapped.
65
+ """
66
+ shift, scale = weight.chunk(2, dim=dim)
67
+ new_weight = torch.cat([scale, shift], dim=dim)
68
+ return new_weight
69
+
70
+
71
+ def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
72
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")
73
+ original_state_dict = original_state_dict["module"]
74
+ original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
75
+
76
+ new_state_dict = {}
77
+
78
+ # Convert patch_embed
79
+ new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
80
+ new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
81
+ new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
82
+ new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
83
+
84
+ # Convert time_condition_embed
85
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
86
+ "time_embed.0.weight"
87
+ )
88
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
89
+ "time_embed.0.bias"
90
+ )
91
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
92
+ "time_embed.2.weight"
93
+ )
94
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
95
+ "time_embed.2.bias"
96
+ )
97
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
98
+ "label_emb.0.0.weight"
99
+ )
100
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
101
+ "label_emb.0.0.bias"
102
+ )
103
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
104
+ "label_emb.0.2.weight"
105
+ )
106
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
107
+ "label_emb.0.2.bias"
108
+ )
109
+
110
+ # Convert transformer blocks, for cogview4 is 28 blocks
111
+ for i in range(28):
112
+ block_prefix = f"transformer_blocks.{i}."
113
+ old_prefix = f"transformer.layers.{i}."
114
+ adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
115
+ new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
116
+ new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
117
+
118
+ qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
119
+ qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
120
+ q, k, v = qkv_weight.chunk(3, dim=0)
121
+ q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
122
+
123
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
124
+ new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
125
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
126
+ new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
127
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
128
+ new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
129
+
130
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
131
+ old_prefix + "attention.dense.weight"
132
+ )
133
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
134
+ old_prefix + "attention.dense.bias"
135
+ )
136
+
137
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
138
+ old_prefix + "mlp.dense_h_to_4h.weight"
139
+ )
140
+ new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
141
+ old_prefix + "mlp.dense_h_to_4h.bias"
142
+ )
143
+ new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
144
+ old_prefix + "mlp.dense_4h_to_h.weight"
145
+ )
146
+ new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
147
+
148
+ # Convert final norm and projection
149
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
150
+ original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
151
+ )
152
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
153
+ original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
154
+ )
155
+ new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
156
+ new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
157
+
158
+ return new_state_dict
159
+
160
+
161
+ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
162
+ original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
163
+ return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
164
+
165
+
166
+ def main(args):
167
+ if args.dtype == "fp16":
168
+ dtype = torch.float16
169
+ elif args.dtype == "bf16":
170
+ dtype = torch.bfloat16
171
+ elif args.dtype == "fp32":
172
+ dtype = torch.float32
173
+ else:
174
+ raise ValueError(f"Unsupported dtype: {args.dtype}")
175
+
176
+ transformer = None
177
+ vae = None
178
+
179
+ if args.transformer_checkpoint_path is not None:
180
+ converted_transformer_state_dict = convert_cogview4_transformer_checkpoint_to_diffusers(
181
+ args.transformer_checkpoint_path
182
+ )
183
+ transformer = CogView4Transformer2DModel(
184
+ patch_size=2,
185
+ in_channels=16,
186
+ num_layers=28,
187
+ attention_head_dim=128,
188
+ num_attention_heads=32,
189
+ out_channels=16,
190
+ text_embed_dim=4096,
191
+ time_embed_dim=512,
192
+ condition_dim=256,
193
+ pos_embed_max_size=128,
194
+ )
195
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
196
+ if dtype is not None:
197
+ # Original checkpoint data type will be preserved
198
+ transformer = transformer.to(dtype=dtype)
199
+
200
+ if args.vae_checkpoint_path is not None:
201
+ vae_config = {
202
+ "in_channels": 3,
203
+ "out_channels": 3,
204
+ "down_block_types": ("DownEncoderBlock2D",) * 4,
205
+ "up_block_types": ("UpDecoderBlock2D",) * 4,
206
+ "block_out_channels": (128, 512, 1024, 1024),
207
+ "layers_per_block": 3,
208
+ "act_fn": "silu",
209
+ "latent_channels": 16,
210
+ "norm_num_groups": 32,
211
+ "sample_size": 1024,
212
+ "scaling_factor": 1.0,
213
+ "shift_factor": 0.0,
214
+ "force_upcast": True,
215
+ "use_quant_conv": False,
216
+ "use_post_quant_conv": False,
217
+ "mid_block_add_attention": False,
218
+ }
219
+ converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
220
+ vae = AutoencoderKL(**vae_config)
221
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
222
+ if dtype is not None:
223
+ vae = vae.to(dtype=dtype)
224
+
225
+ text_encoder_id = "THUDM/glm-4-9b-hf"
226
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
227
+ text_encoder = GlmForCausalLM.from_pretrained(
228
+ text_encoder_id,
229
+ cache_dir=args.text_encoder_cache_dir,
230
+ torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
231
+ )
232
+
233
+ for param in text_encoder.parameters():
234
+ param.data = param.data.contiguous()
235
+
236
+ scheduler = FlowMatchEulerDiscreteScheduler(
237
+ base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
238
+ )
239
+
240
+ pipe = CogView4Pipeline(
241
+ tokenizer=tokenizer,
242
+ text_encoder=text_encoder,
243
+ vae=vae,
244
+ transformer=transformer,
245
+ scheduler=scheduler,
246
+ )
247
+
248
+ # This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can
249
+ # save some memory used for model loading.
250
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
251
+
252
+
253
+ if __name__ == "__main__":
254
+ main(args)
diffusers/scripts/convert_cogview4_to_diffusers_megatron.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Convert a CogView4 checkpoint from Megatron to the Diffusers format.
3
+
4
+ Example usage:
5
+ python scripts/convert_cogview4_to_diffusers.py \
6
+ --transformer_checkpoint_path 'your path/cogview4_6b/mp_rank_00/model_optim_rng.pt' \
7
+ --vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
8
+ --output_path "THUDM/CogView4-6B" \
9
+ --dtype "bf16"
10
+
11
+ Arguments:
12
+ --transformer_checkpoint_path: Path to Transformer state dict.
13
+ --vae_checkpoint_path: Path to VAE state dict.
14
+ --output_path: The path to save the converted model.
15
+ --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
16
+ --text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used.
17
+ --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
18
+
19
+ Default is "bf16" because CogView4 uses bfloat16 for training.
20
+
21
+ Note: You must provide either --transformer_checkpoint_path or --vae_checkpoint_path.
22
+ """
23
+
24
+ import argparse
25
+
26
+ import torch
27
+ from tqdm import tqdm
28
+ from transformers import GlmModel, PreTrainedTokenizerFast
29
+
30
+ from diffusers import (
31
+ AutoencoderKL,
32
+ CogView4ControlPipeline,
33
+ CogView4Pipeline,
34
+ CogView4Transformer2DModel,
35
+ FlowMatchEulerDiscreteScheduler,
36
+ )
37
+ from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
38
+
39
+
40
+ parser = argparse.ArgumentParser()
41
+ parser.add_argument(
42
+ "--transformer_checkpoint_path",
43
+ default=None,
44
+ type=str,
45
+ help="Path to Megatron (not SAT) Transformer checkpoint, e.g., 'model_optim_rng.pt'.",
46
+ )
47
+ parser.add_argument(
48
+ "--vae_checkpoint_path",
49
+ default=None,
50
+ type=str,
51
+ help="(Optional) Path to VAE checkpoint, e.g., 'imagekl_ch16.pt'.",
52
+ )
53
+ parser.add_argument(
54
+ "--output_path",
55
+ required=True,
56
+ type=str,
57
+ help="Directory to save the final Diffusers format pipeline.",
58
+ )
59
+ parser.add_argument(
60
+ "--push_to_hub",
61
+ action="store_true",
62
+ default=False,
63
+ help="Whether to push the converted model to the HuggingFace Hub.",
64
+ )
65
+ parser.add_argument(
66
+ "--text_encoder_cache_dir",
67
+ type=str,
68
+ default=None,
69
+ help="Specify the cache directory for the text encoder.",
70
+ )
71
+ parser.add_argument(
72
+ "--dtype",
73
+ type=str,
74
+ default="bf16",
75
+ choices=["fp16", "bf16", "fp32"],
76
+ help="Data type to save the model in.",
77
+ )
78
+
79
+ parser.add_argument(
80
+ "--num_layers",
81
+ type=int,
82
+ default=28,
83
+ help="Number of Transformer layers (e.g., 28, 48...).",
84
+ )
85
+ parser.add_argument(
86
+ "--num_heads",
87
+ type=int,
88
+ default=32,
89
+ help="Number of attention heads.",
90
+ )
91
+ parser.add_argument(
92
+ "--hidden_size",
93
+ type=int,
94
+ default=4096,
95
+ help="Transformer hidden dimension size.",
96
+ )
97
+ parser.add_argument(
98
+ "--attention_head_dim",
99
+ type=int,
100
+ default=128,
101
+ help="Dimension of each attention head.",
102
+ )
103
+ parser.add_argument(
104
+ "--time_embed_dim",
105
+ type=int,
106
+ default=512,
107
+ help="Dimension of time embeddings.",
108
+ )
109
+ parser.add_argument(
110
+ "--condition_dim",
111
+ type=int,
112
+ default=256,
113
+ help="Dimension of condition embeddings.",
114
+ )
115
+ parser.add_argument(
116
+ "--pos_embed_max_size",
117
+ type=int,
118
+ default=128,
119
+ help="Maximum size for positional embeddings.",
120
+ )
121
+ parser.add_argument(
122
+ "--control",
123
+ action="store_true",
124
+ default=False,
125
+ help="Whether to use control model.",
126
+ )
127
+
128
+ args = parser.parse_args()
129
+
130
+
131
+ def swap_scale_shift(weight, dim):
132
+ """
133
+ Swap the scale and shift components in the weight tensor.
134
+
135
+ Args:
136
+ weight (torch.Tensor): The original weight tensor.
137
+ dim (int): The dimension along which to split.
138
+
139
+ Returns:
140
+ torch.Tensor: The modified weight tensor with scale and shift swapped.
141
+ """
142
+ shift, scale = weight.chunk(2, dim=dim)
143
+ new_weight = torch.cat([scale, shift], dim=dim)
144
+ return new_weight
145
+
146
+
147
+ def convert_megatron_transformer_checkpoint_to_diffusers(
148
+ ckpt_path: str,
149
+ num_layers: int,
150
+ num_heads: int,
151
+ hidden_size: int,
152
+ ):
153
+ """
154
+ Convert a Megatron Transformer checkpoint to Diffusers format.
155
+
156
+ Args:
157
+ ckpt_path (str): Path to the Megatron Transformer checkpoint.
158
+ num_layers (int): Number of Transformer layers.
159
+ num_heads (int): Number of attention heads.
160
+ hidden_size (int): Hidden size of the Transformer.
161
+
162
+ Returns:
163
+ dict: The converted state dictionary compatible with Diffusers.
164
+ """
165
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
166
+ mega = ckpt["model"]
167
+
168
+ new_state_dict = {}
169
+
170
+ # Patch Embedding
171
+ new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
172
+ hidden_size, 128 if args.control else 64
173
+ )
174
+ new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
175
+ new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
176
+ new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
177
+
178
+ # Time Condition Embedding
179
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = mega[
180
+ "time_embedding.time_embed.0.weight"
181
+ ]
182
+ new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = mega["time_embedding.time_embed.0.bias"]
183
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = mega[
184
+ "time_embedding.time_embed.2.weight"
185
+ ]
186
+ new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = mega["time_embedding.time_embed.2.bias"]
187
+
188
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = mega[
189
+ "label_embedding.label_embed.0.weight"
190
+ ]
191
+ new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = mega[
192
+ "label_embedding.label_embed.0.bias"
193
+ ]
194
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = mega[
195
+ "label_embedding.label_embed.2.weight"
196
+ ]
197
+ new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = mega[
198
+ "label_embedding.label_embed.2.bias"
199
+ ]
200
+
201
+ # Convert each Transformer layer
202
+ for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"):
203
+ block_prefix = f"transformer_blocks.{i}."
204
+
205
+ # AdaLayerNorm
206
+ new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"]
207
+ new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"]
208
+ qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
209
+ qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
210
+
211
+ # Reshape to match SAT logic
212
+ qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size)
213
+ qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size)
214
+
215
+ qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads)
216
+ qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size)
217
+
218
+ # Assign to Diffusers keys
219
+ q, k, v = torch.chunk(qkv_weight, 3, dim=0)
220
+ qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0)
221
+
222
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
223
+ new_state_dict[block_prefix + "attn1.to_q.bias"] = qb
224
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
225
+ new_state_dict[block_prefix + "attn1.to_k.bias"] = kb
226
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
227
+ new_state_dict[block_prefix + "attn1.to_v.bias"] = vb
228
+
229
+ # Attention Output
230
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
231
+ f"decoder.layers.{i}.self_attention.linear_proj.weight"
232
+ ]
233
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
234
+ f"decoder.layers.{i}.self_attention.linear_proj.bias"
235
+ ]
236
+
237
+ # MLP
238
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.weight"]
239
+ new_state_dict[block_prefix + "ff.net.0.proj.bias"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.bias"]
240
+ new_state_dict[block_prefix + "ff.net.2.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc2.weight"]
241
+ new_state_dict[block_prefix + "ff.net.2.bias"] = mega[f"decoder.layers.{i}.mlp.linear_fc2.bias"]
242
+
243
+ # Final Layers
244
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(mega["adaln_final.weight"], dim=0)
245
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(mega["adaln_final.bias"], dim=0)
246
+ new_state_dict["proj_out.weight"] = mega["output_projector.weight"]
247
+ new_state_dict["proj_out.bias"] = mega["output_projector.bias"]
248
+
249
+ return new_state_dict
250
+
251
+
252
+ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
253
+ """
254
+ Convert a CogView4 VAE checkpoint to Diffusers format.
255
+
256
+ Args:
257
+ ckpt_path (str): Path to the VAE checkpoint.
258
+ vae_config (dict): Configuration dictionary for the VAE.
259
+
260
+ Returns:
261
+ dict: The converted VAE state dictionary compatible with Diffusers.
262
+ """
263
+ original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
264
+ return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
265
+
266
+
267
+ def main(args):
268
+ """
269
+ Main function to convert CogView4 checkpoints to Diffusers format.
270
+
271
+ Args:
272
+ args (argparse.Namespace): Parsed command-line arguments.
273
+ """
274
+ # Determine the desired data type
275
+ if args.dtype == "fp16":
276
+ dtype = torch.float16
277
+ elif args.dtype == "bf16":
278
+ dtype = torch.bfloat16
279
+ elif args.dtype == "fp32":
280
+ dtype = torch.float32
281
+ else:
282
+ raise ValueError(f"Unsupported dtype: {args.dtype}")
283
+
284
+ transformer = None
285
+ vae = None
286
+
287
+ # Convert Transformer checkpoint if provided
288
+ if args.transformer_checkpoint_path is not None:
289
+ converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers(
290
+ ckpt_path=args.transformer_checkpoint_path,
291
+ num_layers=args.num_layers,
292
+ num_heads=args.num_heads,
293
+ hidden_size=args.hidden_size,
294
+ )
295
+ transformer = CogView4Transformer2DModel(
296
+ patch_size=2,
297
+ in_channels=32 if args.control else 16,
298
+ num_layers=args.num_layers,
299
+ attention_head_dim=args.attention_head_dim,
300
+ num_attention_heads=args.num_heads,
301
+ out_channels=16,
302
+ text_embed_dim=args.hidden_size,
303
+ time_embed_dim=args.time_embed_dim,
304
+ condition_dim=args.condition_dim,
305
+ pos_embed_max_size=args.pos_embed_max_size,
306
+ )
307
+
308
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
309
+
310
+ # Convert to the specified dtype
311
+ if dtype is not None:
312
+ transformer = transformer.to(dtype=dtype)
313
+
314
+ # Convert VAE checkpoint if provided
315
+ if args.vae_checkpoint_path is not None:
316
+ vae_config = {
317
+ "in_channels": 3,
318
+ "out_channels": 3,
319
+ "down_block_types": ("DownEncoderBlock2D",) * 4,
320
+ "up_block_types": ("UpDecoderBlock2D",) * 4,
321
+ "block_out_channels": (128, 512, 1024, 1024),
322
+ "layers_per_block": 3,
323
+ "act_fn": "silu",
324
+ "latent_channels": 16,
325
+ "norm_num_groups": 32,
326
+ "sample_size": 1024,
327
+ "scaling_factor": 1.0,
328
+ "shift_factor": 0.0,
329
+ "force_upcast": True,
330
+ "use_quant_conv": False,
331
+ "use_post_quant_conv": False,
332
+ "mid_block_add_attention": False,
333
+ }
334
+ converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
335
+ vae = AutoencoderKL(**vae_config)
336
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
337
+ if dtype is not None:
338
+ vae = vae.to(dtype=dtype)
339
+
340
+ # Load the text encoder and tokenizer
341
+ text_encoder_id = "THUDM/glm-4-9b-hf"
342
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
343
+ text_encoder = GlmModel.from_pretrained(
344
+ text_encoder_id,
345
+ cache_dir=args.text_encoder_cache_dir,
346
+ torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
347
+ )
348
+ for param in text_encoder.parameters():
349
+ param.data = param.data.contiguous()
350
+
351
+ # Initialize the scheduler
352
+ scheduler = FlowMatchEulerDiscreteScheduler(
353
+ base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
354
+ )
355
+
356
+ # Create the pipeline
357
+ if args.control:
358
+ pipe = CogView4ControlPipeline(
359
+ tokenizer=tokenizer,
360
+ text_encoder=text_encoder,
361
+ vae=vae,
362
+ transformer=transformer,
363
+ scheduler=scheduler,
364
+ )
365
+ else:
366
+ pipe = CogView4Pipeline(
367
+ tokenizer=tokenizer,
368
+ text_encoder=text_encoder,
369
+ vae=vae,
370
+ transformer=transformer,
371
+ scheduler=scheduler,
372
+ )
373
+
374
+ # Save the converted pipeline
375
+ pipe.save_pretrained(
376
+ args.output_path,
377
+ safe_serialization=True,
378
+ max_shard_size="5GB",
379
+ push_to_hub=args.push_to_hub,
380
+ )
381
+
382
+
383
+ if __name__ == "__main__":
384
+ main(args)
diffusers/scripts/convert_consistency_to_diffusers.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+
6
+ from diffusers import (
7
+ CMStochasticIterativeScheduler,
8
+ ConsistencyModelPipeline,
9
+ UNet2DModel,
10
+ )
11
+
12
+
13
+ TEST_UNET_CONFIG = {
14
+ "sample_size": 32,
15
+ "in_channels": 3,
16
+ "out_channels": 3,
17
+ "layers_per_block": 2,
18
+ "num_class_embeds": 1000,
19
+ "block_out_channels": [32, 64],
20
+ "attention_head_dim": 8,
21
+ "down_block_types": [
22
+ "ResnetDownsampleBlock2D",
23
+ "AttnDownBlock2D",
24
+ ],
25
+ "up_block_types": [
26
+ "AttnUpBlock2D",
27
+ "ResnetUpsampleBlock2D",
28
+ ],
29
+ "resnet_time_scale_shift": "scale_shift",
30
+ "attn_norm_num_groups": 32,
31
+ "upsample_type": "resnet",
32
+ "downsample_type": "resnet",
33
+ }
34
+
35
+ IMAGENET_64_UNET_CONFIG = {
36
+ "sample_size": 64,
37
+ "in_channels": 3,
38
+ "out_channels": 3,
39
+ "layers_per_block": 3,
40
+ "num_class_embeds": 1000,
41
+ "block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4],
42
+ "attention_head_dim": 64,
43
+ "down_block_types": [
44
+ "ResnetDownsampleBlock2D",
45
+ "AttnDownBlock2D",
46
+ "AttnDownBlock2D",
47
+ "AttnDownBlock2D",
48
+ ],
49
+ "up_block_types": [
50
+ "AttnUpBlock2D",
51
+ "AttnUpBlock2D",
52
+ "AttnUpBlock2D",
53
+ "ResnetUpsampleBlock2D",
54
+ ],
55
+ "resnet_time_scale_shift": "scale_shift",
56
+ "attn_norm_num_groups": 32,
57
+ "upsample_type": "resnet",
58
+ "downsample_type": "resnet",
59
+ }
60
+
61
+ LSUN_256_UNET_CONFIG = {
62
+ "sample_size": 256,
63
+ "in_channels": 3,
64
+ "out_channels": 3,
65
+ "layers_per_block": 2,
66
+ "num_class_embeds": None,
67
+ "block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4],
68
+ "attention_head_dim": 64,
69
+ "down_block_types": [
70
+ "ResnetDownsampleBlock2D",
71
+ "ResnetDownsampleBlock2D",
72
+ "ResnetDownsampleBlock2D",
73
+ "AttnDownBlock2D",
74
+ "AttnDownBlock2D",
75
+ "AttnDownBlock2D",
76
+ ],
77
+ "up_block_types": [
78
+ "AttnUpBlock2D",
79
+ "AttnUpBlock2D",
80
+ "AttnUpBlock2D",
81
+ "ResnetUpsampleBlock2D",
82
+ "ResnetUpsampleBlock2D",
83
+ "ResnetUpsampleBlock2D",
84
+ ],
85
+ "resnet_time_scale_shift": "default",
86
+ "upsample_type": "resnet",
87
+ "downsample_type": "resnet",
88
+ }
89
+
90
+ CD_SCHEDULER_CONFIG = {
91
+ "num_train_timesteps": 40,
92
+ "sigma_min": 0.002,
93
+ "sigma_max": 80.0,
94
+ }
95
+
96
+ CT_IMAGENET_64_SCHEDULER_CONFIG = {
97
+ "num_train_timesteps": 201,
98
+ "sigma_min": 0.002,
99
+ "sigma_max": 80.0,
100
+ }
101
+
102
+ CT_LSUN_256_SCHEDULER_CONFIG = {
103
+ "num_train_timesteps": 151,
104
+ "sigma_min": 0.002,
105
+ "sigma_max": 80.0,
106
+ }
107
+
108
+
109
+ def str2bool(v):
110
+ """
111
+ https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
112
+ """
113
+ if isinstance(v, bool):
114
+ return v
115
+ if v.lower() in ("yes", "true", "t", "y", "1"):
116
+ return True
117
+ elif v.lower() in ("no", "false", "f", "n", "0"):
118
+ return False
119
+ else:
120
+ raise argparse.ArgumentTypeError("boolean value expected")
121
+
122
+
123
+ def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False):
124
+ new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"]
125
+ new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"]
126
+ new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"]
127
+ new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"]
128
+ new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"]
129
+ new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"]
130
+ new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"]
131
+ new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"]
132
+ new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"]
133
+ new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"]
134
+
135
+ if has_skip:
136
+ new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"]
137
+ new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"]
138
+
139
+ return new_checkpoint
140
+
141
+
142
+ def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None):
143
+ weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0)
144
+ bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0)
145
+
146
+ new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"]
147
+ new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"]
148
+
149
+ new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1)
150
+ new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1)
151
+ new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1)
152
+ new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1)
153
+ new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1)
154
+ new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1)
155
+
156
+ new_checkpoint[f"{new_prefix}.to_out.0.weight"] = (
157
+ checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1)
158
+ )
159
+ new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1)
160
+
161
+ return new_checkpoint
162
+
163
+
164
+ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
165
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
166
+ new_checkpoint = {}
167
+
168
+ new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
169
+ new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
170
+ new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
171
+ new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
172
+
173
+ if unet_config["num_class_embeds"] is not None:
174
+ new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"]
175
+
176
+ new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
177
+ new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
178
+
179
+ down_block_types = unet_config["down_block_types"]
180
+ layers_per_block = unet_config["layers_per_block"]
181
+ attention_head_dim = unet_config["attention_head_dim"]
182
+ channels_list = unet_config["block_out_channels"]
183
+ current_layer = 1
184
+ prev_channels = channels_list[0]
185
+
186
+ for i, layer_type in enumerate(down_block_types):
187
+ current_channels = channels_list[i]
188
+ downsample_block_has_skip = current_channels != prev_channels
189
+ if layer_type == "ResnetDownsampleBlock2D":
190
+ for j in range(layers_per_block):
191
+ new_prefix = f"down_blocks.{i}.resnets.{j}"
192
+ old_prefix = f"input_blocks.{current_layer}.0"
193
+ has_skip = True if j == 0 and downsample_block_has_skip else False
194
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
195
+ current_layer += 1
196
+
197
+ elif layer_type == "AttnDownBlock2D":
198
+ for j in range(layers_per_block):
199
+ new_prefix = f"down_blocks.{i}.resnets.{j}"
200
+ old_prefix = f"input_blocks.{current_layer}.0"
201
+ has_skip = True if j == 0 and downsample_block_has_skip else False
202
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
203
+ new_prefix = f"down_blocks.{i}.attentions.{j}"
204
+ old_prefix = f"input_blocks.{current_layer}.1"
205
+ new_checkpoint = convert_attention(
206
+ checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
207
+ )
208
+ current_layer += 1
209
+
210
+ if i != len(down_block_types) - 1:
211
+ new_prefix = f"down_blocks.{i}.downsamplers.0"
212
+ old_prefix = f"input_blocks.{current_layer}.0"
213
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
214
+ current_layer += 1
215
+
216
+ prev_channels = current_channels
217
+
218
+ # hardcoded the mid-block for now
219
+ new_prefix = "mid_block.resnets.0"
220
+ old_prefix = "middle_block.0"
221
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
222
+ new_prefix = "mid_block.attentions.0"
223
+ old_prefix = "middle_block.1"
224
+ new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim)
225
+ new_prefix = "mid_block.resnets.1"
226
+ old_prefix = "middle_block.2"
227
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
228
+
229
+ current_layer = 0
230
+ up_block_types = unet_config["up_block_types"]
231
+
232
+ for i, layer_type in enumerate(up_block_types):
233
+ if layer_type == "ResnetUpsampleBlock2D":
234
+ for j in range(layers_per_block + 1):
235
+ new_prefix = f"up_blocks.{i}.resnets.{j}"
236
+ old_prefix = f"output_blocks.{current_layer}.0"
237
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
238
+ current_layer += 1
239
+
240
+ if i != len(up_block_types) - 1:
241
+ new_prefix = f"up_blocks.{i}.upsamplers.0"
242
+ old_prefix = f"output_blocks.{current_layer - 1}.1"
243
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
244
+ elif layer_type == "AttnUpBlock2D":
245
+ for j in range(layers_per_block + 1):
246
+ new_prefix = f"up_blocks.{i}.resnets.{j}"
247
+ old_prefix = f"output_blocks.{current_layer}.0"
248
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
249
+ new_prefix = f"up_blocks.{i}.attentions.{j}"
250
+ old_prefix = f"output_blocks.{current_layer}.1"
251
+ new_checkpoint = convert_attention(
252
+ checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
253
+ )
254
+ current_layer += 1
255
+
256
+ if i != len(up_block_types) - 1:
257
+ new_prefix = f"up_blocks.{i}.upsamplers.0"
258
+ old_prefix = f"output_blocks.{current_layer - 1}.2"
259
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
260
+
261
+ new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
262
+ new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
263
+ new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
264
+ new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
265
+
266
+ return new_checkpoint
267
+
268
+
269
+ if __name__ == "__main__":
270
+ parser = argparse.ArgumentParser()
271
+
272
+ parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.")
273
+ parser.add_argument(
274
+ "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model."
275
+ )
276
+ parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.")
277
+
278
+ args = parser.parse_args()
279
+ args.class_cond = str2bool(args.class_cond)
280
+
281
+ ckpt_name = os.path.basename(args.unet_path)
282
+ print(f"Checkpoint: {ckpt_name}")
283
+
284
+ # Get U-Net config
285
+ if "imagenet64" in ckpt_name:
286
+ unet_config = IMAGENET_64_UNET_CONFIG
287
+ elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
288
+ unet_config = LSUN_256_UNET_CONFIG
289
+ elif "test" in ckpt_name:
290
+ unet_config = TEST_UNET_CONFIG
291
+ else:
292
+ raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
293
+
294
+ if not args.class_cond:
295
+ unet_config["num_class_embeds"] = None
296
+
297
+ converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config)
298
+
299
+ image_unet = UNet2DModel(**unet_config)
300
+ image_unet.load_state_dict(converted_unet_ckpt)
301
+
302
+ # Get scheduler config
303
+ if "cd" in ckpt_name or "test" in ckpt_name:
304
+ scheduler_config = CD_SCHEDULER_CONFIG
305
+ elif "ct" in ckpt_name and "imagenet64" in ckpt_name:
306
+ scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG
307
+ elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
308
+ scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG
309
+ else:
310
+ raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
311
+
312
+ cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config)
313
+
314
+ consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler)
315
+ consistency_model.save_pretrained(args.dump_path)
diffusers/scripts/convert_cosmos_to_diffusers.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pathlib
3
+ from typing import Any, Dict
4
+
5
+ import torch
6
+ from accelerate import init_empty_weights
7
+ from huggingface_hub import snapshot_download
8
+ from transformers import T5EncoderModel, T5TokenizerFast
9
+
10
+ from diffusers import (
11
+ AutoencoderKLCosmos,
12
+ AutoencoderKLWan,
13
+ Cosmos2TextToImagePipeline,
14
+ Cosmos2VideoToWorldPipeline,
15
+ CosmosTextToWorldPipeline,
16
+ CosmosTransformer3DModel,
17
+ CosmosVideoToWorldPipeline,
18
+ EDMEulerScheduler,
19
+ FlowMatchEulerDiscreteScheduler,
20
+ )
21
+
22
+
23
+ def remove_keys_(key: str, state_dict: Dict[str, Any]):
24
+ state_dict.pop(key)
25
+
26
+
27
+ def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
28
+ state_dict[new_key] = state_dict.pop(old_key)
29
+
30
+
31
+ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
32
+ block_index = int(key.split(".")[1].removeprefix("block"))
33
+ new_key = key
34
+
35
+ old_prefix = f"blocks.block{block_index}"
36
+ new_prefix = f"transformer_blocks.{block_index}"
37
+ new_key = new_prefix + new_key.removeprefix(old_prefix)
38
+
39
+ state_dict[new_key] = state_dict.pop(key)
40
+
41
+
42
+ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
43
+ "t_embedder.1": "time_embed.t_embedder",
44
+ "affline_norm": "time_embed.norm",
45
+ ".blocks.0.block.attn": ".attn1",
46
+ ".blocks.1.block.attn": ".attn2",
47
+ ".blocks.2.block": ".ff",
48
+ ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
49
+ ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
50
+ ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
51
+ ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
52
+ ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
53
+ ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
54
+ "to_q.0": "to_q",
55
+ "to_q.1": "norm_q",
56
+ "to_k.0": "to_k",
57
+ "to_k.1": "norm_k",
58
+ "to_v.0": "to_v",
59
+ "layer1": "net.0.proj",
60
+ "layer2": "net.2",
61
+ "proj.1": "proj",
62
+ "x_embedder": "patch_embed",
63
+ "extra_pos_embedder": "learnable_pos_embed",
64
+ "final_layer.adaLN_modulation.1": "norm_out.linear_1",
65
+ "final_layer.adaLN_modulation.2": "norm_out.linear_2",
66
+ "final_layer.linear": "proj_out",
67
+ }
68
+
69
+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
70
+ "blocks.block": rename_transformer_blocks_,
71
+ "logvar.0.freqs": remove_keys_,
72
+ "logvar.0.phases": remove_keys_,
73
+ "logvar.1.weight": remove_keys_,
74
+ "pos_embedder.seq": remove_keys_,
75
+ }
76
+
77
+ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
78
+ "t_embedder.1": "time_embed.t_embedder",
79
+ "t_embedding_norm": "time_embed.norm",
80
+ "blocks": "transformer_blocks",
81
+ "adaln_modulation_self_attn.1": "norm1.linear_1",
82
+ "adaln_modulation_self_attn.2": "norm1.linear_2",
83
+ "adaln_modulation_cross_attn.1": "norm2.linear_1",
84
+ "adaln_modulation_cross_attn.2": "norm2.linear_2",
85
+ "adaln_modulation_mlp.1": "norm3.linear_1",
86
+ "adaln_modulation_mlp.2": "norm3.linear_2",
87
+ "self_attn": "attn1",
88
+ "cross_attn": "attn2",
89
+ "q_proj": "to_q",
90
+ "k_proj": "to_k",
91
+ "v_proj": "to_v",
92
+ "output_proj": "to_out.0",
93
+ "q_norm": "norm_q",
94
+ "k_norm": "norm_k",
95
+ "mlp.layer1": "ff.net.0.proj",
96
+ "mlp.layer2": "ff.net.2",
97
+ "x_embedder.proj.1": "patch_embed.proj",
98
+ "final_layer.adaln_modulation.1": "norm_out.linear_1",
99
+ "final_layer.adaln_modulation.2": "norm_out.linear_2",
100
+ "final_layer.linear": "proj_out",
101
+ }
102
+
103
+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
104
+ "accum_video_sample_counter": remove_keys_,
105
+ "accum_image_sample_counter": remove_keys_,
106
+ "accum_iteration": remove_keys_,
107
+ "accum_train_in_hours": remove_keys_,
108
+ "pos_embedder.seq": remove_keys_,
109
+ "pos_embedder.dim_spatial_range": remove_keys_,
110
+ "pos_embedder.dim_temporal_range": remove_keys_,
111
+ "_extra_state": remove_keys_,
112
+ }
113
+
114
+
115
+ TRANSFORMER_CONFIGS = {
116
+ "Cosmos-1.0-Diffusion-7B-Text2World": {
117
+ "in_channels": 16,
118
+ "out_channels": 16,
119
+ "num_attention_heads": 32,
120
+ "attention_head_dim": 128,
121
+ "num_layers": 28,
122
+ "mlp_ratio": 4.0,
123
+ "text_embed_dim": 1024,
124
+ "adaln_lora_dim": 256,
125
+ "max_size": (128, 240, 240),
126
+ "patch_size": (1, 2, 2),
127
+ "rope_scale": (2.0, 1.0, 1.0),
128
+ "concat_padding_mask": True,
129
+ "extra_pos_embed_type": "learnable",
130
+ },
131
+ "Cosmos-1.0-Diffusion-7B-Video2World": {
132
+ "in_channels": 16 + 1,
133
+ "out_channels": 16,
134
+ "num_attention_heads": 32,
135
+ "attention_head_dim": 128,
136
+ "num_layers": 28,
137
+ "mlp_ratio": 4.0,
138
+ "text_embed_dim": 1024,
139
+ "adaln_lora_dim": 256,
140
+ "max_size": (128, 240, 240),
141
+ "patch_size": (1, 2, 2),
142
+ "rope_scale": (2.0, 1.0, 1.0),
143
+ "concat_padding_mask": True,
144
+ "extra_pos_embed_type": "learnable",
145
+ },
146
+ "Cosmos-1.0-Diffusion-14B-Text2World": {
147
+ "in_channels": 16,
148
+ "out_channels": 16,
149
+ "num_attention_heads": 40,
150
+ "attention_head_dim": 128,
151
+ "num_layers": 36,
152
+ "mlp_ratio": 4.0,
153
+ "text_embed_dim": 1024,
154
+ "adaln_lora_dim": 256,
155
+ "max_size": (128, 240, 240),
156
+ "patch_size": (1, 2, 2),
157
+ "rope_scale": (2.0, 2.0, 2.0),
158
+ "concat_padding_mask": True,
159
+ "extra_pos_embed_type": "learnable",
160
+ },
161
+ "Cosmos-1.0-Diffusion-14B-Video2World": {
162
+ "in_channels": 16 + 1,
163
+ "out_channels": 16,
164
+ "num_attention_heads": 40,
165
+ "attention_head_dim": 128,
166
+ "num_layers": 36,
167
+ "mlp_ratio": 4.0,
168
+ "text_embed_dim": 1024,
169
+ "adaln_lora_dim": 256,
170
+ "max_size": (128, 240, 240),
171
+ "patch_size": (1, 2, 2),
172
+ "rope_scale": (2.0, 2.0, 2.0),
173
+ "concat_padding_mask": True,
174
+ "extra_pos_embed_type": "learnable",
175
+ },
176
+ "Cosmos-2.0-Diffusion-2B-Text2Image": {
177
+ "in_channels": 16,
178
+ "out_channels": 16,
179
+ "num_attention_heads": 16,
180
+ "attention_head_dim": 128,
181
+ "num_layers": 28,
182
+ "mlp_ratio": 4.0,
183
+ "text_embed_dim": 1024,
184
+ "adaln_lora_dim": 256,
185
+ "max_size": (128, 240, 240),
186
+ "patch_size": (1, 2, 2),
187
+ "rope_scale": (1.0, 4.0, 4.0),
188
+ "concat_padding_mask": True,
189
+ "extra_pos_embed_type": None,
190
+ },
191
+ "Cosmos-2.0-Diffusion-14B-Text2Image": {
192
+ "in_channels": 16,
193
+ "out_channels": 16,
194
+ "num_attention_heads": 40,
195
+ "attention_head_dim": 128,
196
+ "num_layers": 36,
197
+ "mlp_ratio": 4.0,
198
+ "text_embed_dim": 1024,
199
+ "adaln_lora_dim": 256,
200
+ "max_size": (128, 240, 240),
201
+ "patch_size": (1, 2, 2),
202
+ "rope_scale": (1.0, 4.0, 4.0),
203
+ "concat_padding_mask": True,
204
+ "extra_pos_embed_type": None,
205
+ },
206
+ "Cosmos-2.0-Diffusion-2B-Video2World": {
207
+ "in_channels": 16 + 1,
208
+ "out_channels": 16,
209
+ "num_attention_heads": 16,
210
+ "attention_head_dim": 128,
211
+ "num_layers": 28,
212
+ "mlp_ratio": 4.0,
213
+ "text_embed_dim": 1024,
214
+ "adaln_lora_dim": 256,
215
+ "max_size": (128, 240, 240),
216
+ "patch_size": (1, 2, 2),
217
+ "rope_scale": (1.0, 3.0, 3.0),
218
+ "concat_padding_mask": True,
219
+ "extra_pos_embed_type": None,
220
+ },
221
+ "Cosmos-2.0-Diffusion-14B-Video2World": {
222
+ "in_channels": 16 + 1,
223
+ "out_channels": 16,
224
+ "num_attention_heads": 40,
225
+ "attention_head_dim": 128,
226
+ "num_layers": 36,
227
+ "mlp_ratio": 4.0,
228
+ "text_embed_dim": 1024,
229
+ "adaln_lora_dim": 256,
230
+ "max_size": (128, 240, 240),
231
+ "patch_size": (1, 2, 2),
232
+ "rope_scale": (20 / 24, 2.0, 2.0),
233
+ "concat_padding_mask": True,
234
+ "extra_pos_embed_type": None,
235
+ },
236
+ }
237
+
238
+ VAE_KEYS_RENAME_DICT = {
239
+ "down.0": "down_blocks.0",
240
+ "down.1": "down_blocks.1",
241
+ "down.2": "down_blocks.2",
242
+ "up.0": "up_blocks.2",
243
+ "up.1": "up_blocks.1",
244
+ "up.2": "up_blocks.0",
245
+ ".block.": ".resnets.",
246
+ "downsample": "downsamplers.0",
247
+ "upsample": "upsamplers.0",
248
+ "mid.block_1": "mid_block.resnets.0",
249
+ "mid.attn_1.0": "mid_block.attentions.0",
250
+ "mid.attn_1.1": "mid_block.temp_attentions.0",
251
+ "mid.block_2": "mid_block.resnets.1",
252
+ ".q.conv3d": ".to_q",
253
+ ".k.conv3d": ".to_k",
254
+ ".v.conv3d": ".to_v",
255
+ ".proj_out.conv3d": ".to_out.0",
256
+ ".0.conv3d": ".conv_s",
257
+ ".1.conv3d": ".conv_t",
258
+ "conv1.conv3d": "conv1",
259
+ "conv2.conv3d": "conv2",
260
+ "conv3.conv3d": "conv3",
261
+ "nin_shortcut.conv3d": "conv_shortcut",
262
+ "quant_conv.conv3d": "quant_conv",
263
+ "post_quant_conv.conv3d": "post_quant_conv",
264
+ }
265
+
266
+ VAE_SPECIAL_KEYS_REMAP = {
267
+ "wavelets": remove_keys_,
268
+ "_arange": remove_keys_,
269
+ "patch_size_buffer": remove_keys_,
270
+ }
271
+
272
+ VAE_CONFIGS = {
273
+ "CV8x8x8-0.1": {
274
+ "name": "nvidia/Cosmos-0.1-Tokenizer-CV8x8x8",
275
+ "diffusers_config": {
276
+ "in_channels": 3,
277
+ "out_channels": 3,
278
+ "latent_channels": 16,
279
+ "encoder_block_out_channels": (128, 256, 512, 512),
280
+ "decode_block_out_channels": (256, 512, 512, 512),
281
+ "attention_resolutions": (32,),
282
+ "resolution": 1024,
283
+ "num_layers": 2,
284
+ "patch_size": 4,
285
+ "patch_type": "haar",
286
+ "scaling_factor": 1.0,
287
+ "spatial_compression_ratio": 8,
288
+ "temporal_compression_ratio": 8,
289
+ "latents_mean": None,
290
+ "latents_std": None,
291
+ },
292
+ },
293
+ "CV8x8x8-1.0": {
294
+ "name": "nvidia/Cosmos-1.0-Tokenizer-CV8x8x8",
295
+ "diffusers_config": {
296
+ "in_channels": 3,
297
+ "out_channels": 3,
298
+ "latent_channels": 16,
299
+ "encoder_block_out_channels": (128, 256, 512, 512),
300
+ "decode_block_out_channels": (256, 512, 512, 512),
301
+ "attention_resolutions": (32,),
302
+ "resolution": 1024,
303
+ "num_layers": 2,
304
+ "patch_size": 4,
305
+ "patch_type": "haar",
306
+ "scaling_factor": 1.0,
307
+ "spatial_compression_ratio": 8,
308
+ "temporal_compression_ratio": 8,
309
+ "latents_mean": None,
310
+ "latents_std": None,
311
+ },
312
+ },
313
+ }
314
+
315
+
316
+ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
317
+ state_dict = saved_dict
318
+ if "model" in saved_dict.keys():
319
+ state_dict = state_dict["model"]
320
+ if "module" in saved_dict.keys():
321
+ state_dict = state_dict["module"]
322
+ if "state_dict" in saved_dict.keys():
323
+ state_dict = state_dict["state_dict"]
324
+ return state_dict
325
+
326
+
327
+ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
328
+ PREFIX_KEY = "net."
329
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
330
+
331
+ if "Cosmos-1.0" in transformer_type:
332
+ TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
333
+ TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
334
+ elif "Cosmos-2.0" in transformer_type:
335
+ TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
336
+ TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
337
+ else:
338
+ assert False
339
+
340
+ with init_empty_weights():
341
+ config = TRANSFORMER_CONFIGS[transformer_type]
342
+ transformer = CosmosTransformer3DModel(**config)
343
+
344
+ for key in list(original_state_dict.keys()):
345
+ new_key = key[:]
346
+ if new_key.startswith(PREFIX_KEY):
347
+ new_key = new_key.removeprefix(PREFIX_KEY)
348
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
349
+ new_key = new_key.replace(replace_key, rename_key)
350
+ update_state_dict_(original_state_dict, key, new_key)
351
+
352
+ for key in list(original_state_dict.keys()):
353
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
354
+ if special_key not in key:
355
+ continue
356
+ handler_fn_inplace(key, original_state_dict)
357
+
358
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
359
+ return transformer
360
+
361
+
362
+ def convert_vae(vae_type: str):
363
+ model_name = VAE_CONFIGS[vae_type]["name"]
364
+ snapshot_directory = snapshot_download(model_name, repo_type="model")
365
+ directory = pathlib.Path(snapshot_directory)
366
+
367
+ autoencoder_file = directory / "autoencoder.jit"
368
+ mean_std_file = directory / "mean_std.pt"
369
+
370
+ original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()
371
+ if mean_std_file.exists():
372
+ mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
373
+ else:
374
+ mean_std = (None, None)
375
+
376
+ config = VAE_CONFIGS[vae_type]["diffusers_config"]
377
+ config.update(
378
+ {
379
+ "latents_mean": mean_std[0].detach().cpu().numpy().tolist(),
380
+ "latents_std": mean_std[1].detach().cpu().numpy().tolist(),
381
+ }
382
+ )
383
+ vae = AutoencoderKLCosmos(**config)
384
+
385
+ for key in list(original_state_dict.keys()):
386
+ new_key = key[:]
387
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
388
+ new_key = new_key.replace(replace_key, rename_key)
389
+ update_state_dict_(original_state_dict, key, new_key)
390
+
391
+ for key in list(original_state_dict.keys()):
392
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
393
+ if special_key not in key:
394
+ continue
395
+ handler_fn_inplace(key, original_state_dict)
396
+
397
+ vae.load_state_dict(original_state_dict, strict=True, assign=True)
398
+ return vae
399
+
400
+
401
+ def save_pipeline_cosmos_1_0(args, transformer, vae):
402
+ text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
403
+ tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
404
+ # The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
405
+ # So, the sigma_min values that is used is the default value of 0.002.
406
+ scheduler = EDMEulerScheduler(
407
+ sigma_min=0.002,
408
+ sigma_max=80,
409
+ sigma_data=0.5,
410
+ sigma_schedule="karras",
411
+ num_train_timesteps=1000,
412
+ prediction_type="epsilon",
413
+ rho=7.0,
414
+ final_sigmas_type="sigma_min",
415
+ )
416
+
417
+ pipe_cls = CosmosTextToWorldPipeline if "Text2World" in args.transformer_type else CosmosVideoToWorldPipeline
418
+ pipe = pipe_cls(
419
+ text_encoder=text_encoder,
420
+ tokenizer=tokenizer,
421
+ transformer=transformer,
422
+ vae=vae,
423
+ scheduler=scheduler,
424
+ safety_checker=lambda *args, **kwargs: None,
425
+ )
426
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
427
+
428
+
429
+ def save_pipeline_cosmos_2_0(args, transformer, vae):
430
+ text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
431
+ tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
432
+
433
+ scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
434
+
435
+ pipe_cls = Cosmos2TextToImagePipeline if "Text2Image" in args.transformer_type else Cosmos2VideoToWorldPipeline
436
+ pipe = pipe_cls(
437
+ text_encoder=text_encoder,
438
+ tokenizer=tokenizer,
439
+ transformer=transformer,
440
+ vae=vae,
441
+ scheduler=scheduler,
442
+ safety_checker=lambda *args, **kwargs: None,
443
+ )
444
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
445
+
446
+
447
+ def get_args():
448
+ parser = argparse.ArgumentParser()
449
+ parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
450
+ parser.add_argument(
451
+ "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
452
+ )
453
+ parser.add_argument(
454
+ "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
455
+ )
456
+ parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
457
+ parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
458
+ parser.add_argument("--save_pipeline", action="store_true")
459
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
460
+ parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
461
+ return parser.parse_args()
462
+
463
+
464
+ DTYPE_MAPPING = {
465
+ "fp32": torch.float32,
466
+ "fp16": torch.float16,
467
+ "bf16": torch.bfloat16,
468
+ }
469
+
470
+
471
+ if __name__ == "__main__":
472
+ args = get_args()
473
+
474
+ transformer = None
475
+ dtype = DTYPE_MAPPING[args.dtype]
476
+
477
+ if args.save_pipeline:
478
+ assert args.transformer_ckpt_path is not None
479
+ assert args.vae_type is not None
480
+ assert args.text_encoder_path is not None
481
+ assert args.tokenizer_path is not None
482
+
483
+ if args.transformer_ckpt_path is not None:
484
+ weights_only = "Cosmos-1.0" in args.transformer_type
485
+ transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
486
+ transformer = transformer.to(dtype=dtype)
487
+ if not args.save_pipeline:
488
+ transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
489
+
490
+ if args.vae_type is not None:
491
+ if "Cosmos-1.0" in args.transformer_type:
492
+ vae = convert_vae(args.vae_type)
493
+ else:
494
+ vae = AutoencoderKLWan.from_pretrained(
495
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
496
+ )
497
+ if not args.save_pipeline:
498
+ vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
499
+
500
+ if args.save_pipeline:
501
+ if "Cosmos-1.0" in args.transformer_type:
502
+ save_pipeline_cosmos_1_0(args, transformer, vae)
503
+ elif "Cosmos-2.0" in args.transformer_type:
504
+ save_pipeline_cosmos_2_0(args, transformer, vae)
505
+ else:
506
+ assert False
diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+
4
+ import torch
5
+
6
+ from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel
7
+
8
+
9
+ def shave_segments(path, n_shave_prefix_segments=1):
10
+ """
11
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
12
+ """
13
+ if n_shave_prefix_segments >= 0:
14
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
15
+ else:
16
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
17
+
18
+
19
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
20
+ mapping = []
21
+ for old_item in old_list:
22
+ new_item = old_item
23
+ new_item = new_item.replace("block.", "resnets.")
24
+ new_item = new_item.replace("conv_shorcut", "conv1")
25
+ new_item = new_item.replace("in_shortcut", "conv_shortcut")
26
+ new_item = new_item.replace("temb_proj", "time_emb_proj")
27
+
28
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
29
+
30
+ mapping.append({"old": old_item, "new": new_item})
31
+
32
+ return mapping
33
+
34
+
35
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
36
+ mapping = []
37
+ for old_item in old_list:
38
+ new_item = old_item
39
+
40
+ # In `model.mid`, the layer is called `attn`.
41
+ if not in_mid:
42
+ new_item = new_item.replace("attn", "attentions")
43
+ new_item = new_item.replace(".k.", ".key.")
44
+ new_item = new_item.replace(".v.", ".value.")
45
+ new_item = new_item.replace(".q.", ".query.")
46
+
47
+ new_item = new_item.replace("proj_out", "proj_attn")
48
+ new_item = new_item.replace("norm", "group_norm")
49
+
50
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
51
+ mapping.append({"old": old_item, "new": new_item})
52
+
53
+ return mapping
54
+
55
+
56
+ def assign_to_checkpoint(
57
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
58
+ ):
59
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
60
+
61
+ if attention_paths_to_split is not None:
62
+ if config is None:
63
+ raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")
64
+
65
+ for path, path_map in attention_paths_to_split.items():
66
+ old_tensor = old_checkpoint[path]
67
+ channels = old_tensor.shape[0] // 3
68
+
69
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
70
+
71
+ num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3
72
+
73
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
74
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
75
+
76
+ checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
77
+ checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
78
+ checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()
79
+
80
+ for path in paths:
81
+ new_path = path["new"]
82
+
83
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
84
+ continue
85
+
86
+ new_path = new_path.replace("down.", "down_blocks.")
87
+ new_path = new_path.replace("up.", "up_blocks.")
88
+
89
+ if additional_replacements is not None:
90
+ for replacement in additional_replacements:
91
+ new_path = new_path.replace(replacement["old"], replacement["new"])
92
+
93
+ if "attentions" in new_path:
94
+ checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
95
+ else:
96
+ checkpoint[new_path] = old_checkpoint[path["old"]]
97
+
98
+
99
+ def convert_ddpm_checkpoint(checkpoint, config):
100
+ """
101
+ Takes a state dict and a config, and returns a converted checkpoint.
102
+ """
103
+ new_checkpoint = {}
104
+
105
+ new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
106
+ new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
107
+ new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
108
+ new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]
109
+
110
+ new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
111
+ new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]
112
+
113
+ new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
114
+ new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
115
+ new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
116
+ new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]
117
+
118
+ num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
119
+ down_blocks = {
120
+ layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
121
+ }
122
+
123
+ num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
124
+ up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
125
+
126
+ for i in range(num_down_blocks):
127
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
128
+
129
+ if any("downsample" in layer for layer in down_blocks[i]):
130
+ new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
131
+ f"down.{i}.downsample.op.weight"
132
+ ]
133
+ new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
134
+ # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
135
+ # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
136
+
137
+ if any("block" in layer for layer in down_blocks[i]):
138
+ num_blocks = len(
139
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
140
+ )
141
+ blocks = {
142
+ layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
143
+ for layer_id in range(num_blocks)
144
+ }
145
+
146
+ if num_blocks > 0:
147
+ for j in range(config["layers_per_block"]):
148
+ paths = renew_resnet_paths(blocks[j])
149
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint)
150
+
151
+ if any("attn" in layer for layer in down_blocks[i]):
152
+ num_attn = len(
153
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
154
+ )
155
+ attns = {
156
+ layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
157
+ for layer_id in range(num_blocks)
158
+ }
159
+
160
+ if num_attn > 0:
161
+ for j in range(config["layers_per_block"]):
162
+ paths = renew_attention_paths(attns[j])
163
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
164
+
165
+ mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
166
+ mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
167
+ mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
168
+
169
+ # Mid new 2
170
+ paths = renew_resnet_paths(mid_block_1_layers)
171
+ assign_to_checkpoint(
172
+ paths,
173
+ new_checkpoint,
174
+ checkpoint,
175
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
176
+ )
177
+
178
+ paths = renew_resnet_paths(mid_block_2_layers)
179
+ assign_to_checkpoint(
180
+ paths,
181
+ new_checkpoint,
182
+ checkpoint,
183
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
184
+ )
185
+
186
+ paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
187
+ assign_to_checkpoint(
188
+ paths,
189
+ new_checkpoint,
190
+ checkpoint,
191
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
192
+ )
193
+
194
+ for i in range(num_up_blocks):
195
+ block_id = num_up_blocks - 1 - i
196
+
197
+ if any("upsample" in layer for layer in up_blocks[i]):
198
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
199
+ f"up.{i}.upsample.conv.weight"
200
+ ]
201
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]
202
+
203
+ if any("block" in layer for layer in up_blocks[i]):
204
+ num_blocks = len(
205
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
206
+ )
207
+ blocks = {
208
+ layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
209
+ }
210
+
211
+ if num_blocks > 0:
212
+ for j in range(config["layers_per_block"] + 1):
213
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
214
+ paths = renew_resnet_paths(blocks[j])
215
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
216
+
217
+ if any("attn" in layer for layer in up_blocks[i]):
218
+ num_attn = len(
219
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
220
+ )
221
+ attns = {
222
+ layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
223
+ }
224
+
225
+ if num_attn > 0:
226
+ for j in range(config["layers_per_block"] + 1):
227
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
228
+ paths = renew_attention_paths(attns[j])
229
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
230
+
231
+ new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
232
+ return new_checkpoint
233
+
234
+
235
+ def convert_vq_autoenc_checkpoint(checkpoint, config):
236
+ """
237
+ Takes a state dict and a config, and returns a converted checkpoint.
238
+ """
239
+ new_checkpoint = {}
240
+
241
+ new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
242
+ new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]
243
+
244
+ new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
245
+ new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
246
+ new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
247
+ new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]
248
+
249
+ new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
250
+ new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]
251
+
252
+ new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
253
+ new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
254
+ new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
255
+ new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]
256
+
257
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
258
+ down_blocks = {
259
+ layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
260
+ }
261
+
262
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
263
+ up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
264
+
265
+ for i in range(num_down_blocks):
266
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
267
+
268
+ if any("downsample" in layer for layer in down_blocks[i]):
269
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
270
+ f"encoder.down.{i}.downsample.conv.weight"
271
+ ]
272
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
273
+ f"encoder.down.{i}.downsample.conv.bias"
274
+ ]
275
+
276
+ if any("block" in layer for layer in down_blocks[i]):
277
+ num_blocks = len(
278
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
279
+ )
280
+ blocks = {
281
+ layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
282
+ for layer_id in range(num_blocks)
283
+ }
284
+
285
+ if num_blocks > 0:
286
+ for j in range(config["layers_per_block"]):
287
+ paths = renew_resnet_paths(blocks[j])
288
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint)
289
+
290
+ if any("attn" in layer for layer in down_blocks[i]):
291
+ num_attn = len(
292
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
293
+ )
294
+ attns = {
295
+ layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
296
+ for layer_id in range(num_blocks)
297
+ }
298
+
299
+ if num_attn > 0:
300
+ for j in range(config["layers_per_block"]):
301
+ paths = renew_attention_paths(attns[j])
302
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
303
+
304
+ mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
305
+ mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
306
+ mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
307
+
308
+ # Mid new 2
309
+ paths = renew_resnet_paths(mid_block_1_layers)
310
+ assign_to_checkpoint(
311
+ paths,
312
+ new_checkpoint,
313
+ checkpoint,
314
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
315
+ )
316
+
317
+ paths = renew_resnet_paths(mid_block_2_layers)
318
+ assign_to_checkpoint(
319
+ paths,
320
+ new_checkpoint,
321
+ checkpoint,
322
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
323
+ )
324
+
325
+ paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
326
+ assign_to_checkpoint(
327
+ paths,
328
+ new_checkpoint,
329
+ checkpoint,
330
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
331
+ )
332
+
333
+ for i in range(num_up_blocks):
334
+ block_id = num_up_blocks - 1 - i
335
+
336
+ if any("upsample" in layer for layer in up_blocks[i]):
337
+ new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
338
+ f"decoder.up.{i}.upsample.conv.weight"
339
+ ]
340
+ new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
341
+ f"decoder.up.{i}.upsample.conv.bias"
342
+ ]
343
+
344
+ if any("block" in layer for layer in up_blocks[i]):
345
+ num_blocks = len(
346
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
347
+ )
348
+ blocks = {
349
+ layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
350
+ }
351
+
352
+ if num_blocks > 0:
353
+ for j in range(config["layers_per_block"] + 1):
354
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
355
+ paths = renew_resnet_paths(blocks[j])
356
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
357
+
358
+ if any("attn" in layer for layer in up_blocks[i]):
359
+ num_attn = len(
360
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
361
+ )
362
+ attns = {
363
+ layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
364
+ }
365
+
366
+ if num_attn > 0:
367
+ for j in range(config["layers_per_block"] + 1):
368
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
369
+ paths = renew_attention_paths(attns[j])
370
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
371
+
372
+ new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
373
+ new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
374
+ new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
375
+ if "quantize.embedding.weight" in checkpoint:
376
+ new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
377
+ new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
378
+ new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]
379
+
380
+ return new_checkpoint
381
+
382
+
383
+ if __name__ == "__main__":
384
+ parser = argparse.ArgumentParser()
385
+
386
+ parser.add_argument(
387
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
388
+ )
389
+
390
+ parser.add_argument(
391
+ "--config_file",
392
+ default=None,
393
+ type=str,
394
+ required=True,
395
+ help="The config json file corresponding to the architecture.",
396
+ )
397
+
398
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
399
+
400
+ args = parser.parse_args()
401
+ checkpoint = torch.load(args.checkpoint_path)
402
+
403
+ with open(args.config_file) as f:
404
+ config = json.loads(f.read())
405
+
406
+ # unet case
407
+ key_prefix_set = {key.split(".")[0] for key in checkpoint.keys()}
408
+ if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
409
+ converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
410
+ else:
411
+ converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
412
+
413
+ if "ddpm" in config:
414
+ del config["ddpm"]
415
+
416
+ if config["_class_name"] == "VQModel":
417
+ model = VQModel(**config)
418
+ model.load_state_dict(converted_checkpoint)
419
+ model.save_pretrained(args.dump_path)
420
+ elif config["_class_name"] == "AutoencoderKL":
421
+ model = AutoencoderKL(**config)
422
+ model.load_state_dict(converted_checkpoint)
423
+ model.save_pretrained(args.dump_path)
424
+ else:
425
+ model = UNet2DModel(**config)
426
+ model.load_state_dict(converted_checkpoint)
427
+
428
+ scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
429
+
430
+ pipe = DDPMPipeline(unet=model, scheduler=scheduler)
431
+ pipe.save_pretrained(args.dump_path)
diffusers/scripts/convert_diffusers_to_original_sdxl.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
+ # *Only* converts the UNet, VAE, and Text Encoder.
3
+ # Does not convert optimizer state or any other thing.
4
+
5
+ import argparse
6
+ import os.path as osp
7
+ import re
8
+
9
+ import torch
10
+ from safetensors.torch import load_file, save_file
11
+
12
+
13
+ # =================#
14
+ # UNet Conversion #
15
+ # =================#
16
+
17
+ unet_conversion_map = [
18
+ # (stable-diffusion, HF Diffusers)
19
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
20
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
21
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
22
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
23
+ ("input_blocks.0.0.weight", "conv_in.weight"),
24
+ ("input_blocks.0.0.bias", "conv_in.bias"),
25
+ ("out.0.weight", "conv_norm_out.weight"),
26
+ ("out.0.bias", "conv_norm_out.bias"),
27
+ ("out.2.weight", "conv_out.weight"),
28
+ ("out.2.bias", "conv_out.bias"),
29
+ # the following are for sdxl
30
+ ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
31
+ ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
32
+ ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
33
+ ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
34
+ ]
35
+
36
+ unet_conversion_map_resnet = [
37
+ # (stable-diffusion, HF Diffusers)
38
+ ("in_layers.0", "norm1"),
39
+ ("in_layers.2", "conv1"),
40
+ ("out_layers.0", "norm2"),
41
+ ("out_layers.3", "conv2"),
42
+ ("emb_layers.1", "time_emb_proj"),
43
+ ("skip_connection", "conv_shortcut"),
44
+ ]
45
+
46
+ unet_conversion_map_layer = []
47
+ # hardcoded number of downblocks and resnets/attentions...
48
+ # would need smarter logic for other networks.
49
+ for i in range(3):
50
+ # loop over downblocks/upblocks
51
+
52
+ for j in range(2):
53
+ # loop over resnets/attentions for downblocks
54
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
55
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
56
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
57
+
58
+ if i > 0:
59
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
60
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
61
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
62
+
63
+ for j in range(4):
64
+ # loop over resnets/attentions for upblocks
65
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
66
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
67
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
68
+
69
+ if i < 2:
70
+ # no attention layers in up_blocks.0
71
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
72
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
73
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
74
+
75
+ if i < 3:
76
+ # no downsample in down_blocks.3
77
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
78
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
79
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
80
+
81
+ # no upsample in up_blocks.3
82
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
83
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
84
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
85
+ unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
86
+
87
+ hf_mid_atn_prefix = "mid_block.attentions.0."
88
+ sd_mid_atn_prefix = "middle_block.1."
89
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
90
+ for j in range(2):
91
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
92
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
93
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
94
+
95
+
96
+ def convert_unet_state_dict(unet_state_dict):
97
+ # buyer beware: this is a *brittle* function,
98
+ # and correct output requires that all of these pieces interact in
99
+ # the exact order in which I have arranged them.
100
+ mapping = {k: k for k in unet_state_dict.keys()}
101
+ for sd_name, hf_name in unet_conversion_map:
102
+ mapping[hf_name] = sd_name
103
+ for k, v in mapping.items():
104
+ if "resnets" in k:
105
+ for sd_part, hf_part in unet_conversion_map_resnet:
106
+ v = v.replace(hf_part, sd_part)
107
+ mapping[k] = v
108
+ for k, v in mapping.items():
109
+ for sd_part, hf_part in unet_conversion_map_layer:
110
+ v = v.replace(hf_part, sd_part)
111
+ mapping[k] = v
112
+ new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()}
113
+ return new_state_dict
114
+
115
+
116
+ # ================#
117
+ # VAE Conversion #
118
+ # ================#
119
+
120
+ vae_conversion_map = [
121
+ # (stable-diffusion, HF Diffusers)
122
+ ("nin_shortcut", "conv_shortcut"),
123
+ ("norm_out", "conv_norm_out"),
124
+ ("mid.attn_1.", "mid_block.attentions.0."),
125
+ ]
126
+
127
+ for i in range(4):
128
+ # down_blocks have two resnets
129
+ for j in range(2):
130
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
131
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
132
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
133
+
134
+ if i < 3:
135
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
136
+ sd_downsample_prefix = f"down.{i}.downsample."
137
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
138
+
139
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
140
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
141
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
142
+
143
+ # up_blocks have three resnets
144
+ # also, up blocks in hf are numbered in reverse from sd
145
+ for j in range(3):
146
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
147
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
148
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
149
+
150
+ # this part accounts for mid blocks in both the encoder and the decoder
151
+ for i in range(2):
152
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
153
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
154
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
155
+
156
+
157
+ vae_conversion_map_attn = [
158
+ # (stable-diffusion, HF Diffusers)
159
+ ("norm.", "group_norm."),
160
+ # the following are for SDXL
161
+ ("q.", "to_q."),
162
+ ("k.", "to_k."),
163
+ ("v.", "to_v."),
164
+ ("proj_out.", "to_out.0."),
165
+ ]
166
+
167
+
168
+ def reshape_weight_for_sd(w):
169
+ # convert HF linear weights to SD conv2d weights
170
+ if not w.ndim == 1:
171
+ return w.reshape(*w.shape, 1, 1)
172
+ else:
173
+ return w
174
+
175
+
176
+ def convert_vae_state_dict(vae_state_dict):
177
+ mapping = {k: k for k in vae_state_dict.keys()}
178
+ for k, v in mapping.items():
179
+ for sd_part, hf_part in vae_conversion_map:
180
+ v = v.replace(hf_part, sd_part)
181
+ mapping[k] = v
182
+ for k, v in mapping.items():
183
+ if "attentions" in k:
184
+ for sd_part, hf_part in vae_conversion_map_attn:
185
+ v = v.replace(hf_part, sd_part)
186
+ mapping[k] = v
187
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
188
+ weights_to_convert = ["q", "k", "v", "proj_out"]
189
+ for k, v in new_state_dict.items():
190
+ for weight_name in weights_to_convert:
191
+ if f"mid.attn_1.{weight_name}.weight" in k:
192
+ print(f"Reshaping {k} for SD format")
193
+ new_state_dict[k] = reshape_weight_for_sd(v)
194
+ return new_state_dict
195
+
196
+
197
+ # =========================#
198
+ # Text Encoder Conversion #
199
+ # =========================#
200
+
201
+
202
+ textenc_conversion_lst = [
203
+ # (stable-diffusion, HF Diffusers)
204
+ ("transformer.resblocks.", "text_model.encoder.layers."),
205
+ ("ln_1", "layer_norm1"),
206
+ ("ln_2", "layer_norm2"),
207
+ (".c_fc.", ".fc1."),
208
+ (".c_proj.", ".fc2."),
209
+ (".attn", ".self_attn"),
210
+ ("ln_final.", "text_model.final_layer_norm."),
211
+ ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
212
+ ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
213
+ ]
214
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
215
+ textenc_pattern = re.compile("|".join(protected.keys()))
216
+
217
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
218
+ code2idx = {"q": 0, "k": 1, "v": 2}
219
+
220
+
221
+ def convert_openclip_text_enc_state_dict(text_enc_dict):
222
+ new_state_dict = {}
223
+ capture_qkv_weight = {}
224
+ capture_qkv_bias = {}
225
+ for k, v in text_enc_dict.items():
226
+ if (
227
+ k.endswith(".self_attn.q_proj.weight")
228
+ or k.endswith(".self_attn.k_proj.weight")
229
+ or k.endswith(".self_attn.v_proj.weight")
230
+ ):
231
+ k_pre = k[: -len(".q_proj.weight")]
232
+ k_code = k[-len("q_proj.weight")]
233
+ if k_pre not in capture_qkv_weight:
234
+ capture_qkv_weight[k_pre] = [None, None, None]
235
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
236
+ continue
237
+
238
+ if (
239
+ k.endswith(".self_attn.q_proj.bias")
240
+ or k.endswith(".self_attn.k_proj.bias")
241
+ or k.endswith(".self_attn.v_proj.bias")
242
+ ):
243
+ k_pre = k[: -len(".q_proj.bias")]
244
+ k_code = k[-len("q_proj.bias")]
245
+ if k_pre not in capture_qkv_bias:
246
+ capture_qkv_bias[k_pre] = [None, None, None]
247
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
248
+ continue
249
+
250
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
251
+ new_state_dict[relabelled_key] = v
252
+
253
+ for k_pre, tensors in capture_qkv_weight.items():
254
+ if None in tensors:
255
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
256
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
257
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
258
+
259
+ for k_pre, tensors in capture_qkv_bias.items():
260
+ if None in tensors:
261
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
262
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
263
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
264
+
265
+ return new_state_dict
266
+
267
+
268
+ def convert_openai_text_enc_state_dict(text_enc_dict):
269
+ return text_enc_dict
270
+
271
+
272
+ if __name__ == "__main__":
273
+ parser = argparse.ArgumentParser()
274
+
275
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
276
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
277
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
278
+ parser.add_argument(
279
+ "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
280
+ )
281
+
282
+ args = parser.parse_args()
283
+
284
+ assert args.model_path is not None, "Must provide a model path!"
285
+
286
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
287
+
288
+ # Path for safetensors
289
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
290
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
291
+ text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
292
+ text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "model.safetensors")
293
+
294
+ # Load models from safetensors if it exists, if it doesn't pytorch
295
+ if osp.exists(unet_path):
296
+ unet_state_dict = load_file(unet_path, device="cpu")
297
+ else:
298
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
299
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
300
+
301
+ if osp.exists(vae_path):
302
+ vae_state_dict = load_file(vae_path, device="cpu")
303
+ else:
304
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
305
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
306
+
307
+ if osp.exists(text_enc_path):
308
+ text_enc_dict = load_file(text_enc_path, device="cpu")
309
+ else:
310
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
311
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
312
+
313
+ if osp.exists(text_enc_2_path):
314
+ text_enc_2_dict = load_file(text_enc_2_path, device="cpu")
315
+ else:
316
+ text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "pytorch_model.bin")
317
+ text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu")
318
+
319
+ # Convert the UNet model
320
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
321
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
322
+
323
+ # Convert the VAE model
324
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
325
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
326
+
327
+ # Convert text encoder 1
328
+ text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
329
+ text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
330
+
331
+ # Convert text encoder 2
332
+ text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
333
+ text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
334
+ # We call the `.T.contiguous()` to match what's done in
335
+ # https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085
336
+ text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop(
337
+ "conditioner.embedders.1.model.text_projection.weight"
338
+ ).T.contiguous()
339
+
340
+ # Put together new checkpoint
341
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
342
+
343
+ if args.half:
344
+ state_dict = {k: v.half() for k, v in state_dict.items()}
345
+
346
+ if args.use_safetensors:
347
+ save_file(state_dict, args.checkpoint_path)
348
+ else:
349
+ state_dict = {"state_dict": state_dict}
350
+ torch.save(state_dict, args.checkpoint_path)
diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2
+ # *Only* converts the UNet, VAE, and Text Encoder.
3
+ # Does not convert optimizer state or any other thing.
4
+
5
+ import argparse
6
+ import os.path as osp
7
+ import re
8
+
9
+ import torch
10
+ from safetensors.torch import load_file, save_file
11
+
12
+
13
+ # =================#
14
+ # UNet Conversion #
15
+ # =================#
16
+
17
+ unet_conversion_map = [
18
+ # (stable-diffusion, HF Diffusers)
19
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
20
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
21
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
22
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
23
+ ("input_blocks.0.0.weight", "conv_in.weight"),
24
+ ("input_blocks.0.0.bias", "conv_in.bias"),
25
+ ("out.0.weight", "conv_norm_out.weight"),
26
+ ("out.0.bias", "conv_norm_out.bias"),
27
+ ("out.2.weight", "conv_out.weight"),
28
+ ("out.2.bias", "conv_out.bias"),
29
+ ]
30
+
31
+ unet_conversion_map_resnet = [
32
+ # (stable-diffusion, HF Diffusers)
33
+ ("in_layers.0", "norm1"),
34
+ ("in_layers.2", "conv1"),
35
+ ("out_layers.0", "norm2"),
36
+ ("out_layers.3", "conv2"),
37
+ ("emb_layers.1", "time_emb_proj"),
38
+ ("skip_connection", "conv_shortcut"),
39
+ ]
40
+
41
+ unet_conversion_map_layer = []
42
+ # hardcoded number of downblocks and resnets/attentions...
43
+ # would need smarter logic for other networks.
44
+ for i in range(4):
45
+ # loop over downblocks/upblocks
46
+
47
+ for j in range(2):
48
+ # loop over resnets/attentions for downblocks
49
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
50
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
51
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
52
+
53
+ if i < 3:
54
+ # no attention layers in down_blocks.3
55
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
56
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
57
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
58
+
59
+ for j in range(3):
60
+ # loop over resnets/attentions for upblocks
61
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
62
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
63
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
64
+
65
+ if i > 0:
66
+ # no attention layers in up_blocks.0
67
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
68
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
69
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
70
+
71
+ if i < 3:
72
+ # no downsample in down_blocks.3
73
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
74
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
75
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
76
+
77
+ # no upsample in up_blocks.3
78
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
79
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
80
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
81
+
82
+ hf_mid_atn_prefix = "mid_block.attentions.0."
83
+ sd_mid_atn_prefix = "middle_block.1."
84
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
85
+
86
+ for j in range(2):
87
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
88
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
89
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
90
+
91
+
92
+ def convert_unet_state_dict(unet_state_dict):
93
+ # buyer beware: this is a *brittle* function,
94
+ # and correct output requires that all of these pieces interact in
95
+ # the exact order in which I have arranged them.
96
+ mapping = {k: k for k in unet_state_dict.keys()}
97
+ for sd_name, hf_name in unet_conversion_map:
98
+ mapping[hf_name] = sd_name
99
+ for k, v in mapping.items():
100
+ if "resnets" in k:
101
+ for sd_part, hf_part in unet_conversion_map_resnet:
102
+ v = v.replace(hf_part, sd_part)
103
+ mapping[k] = v
104
+ for k, v in mapping.items():
105
+ for sd_part, hf_part in unet_conversion_map_layer:
106
+ v = v.replace(hf_part, sd_part)
107
+ mapping[k] = v
108
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
109
+ return new_state_dict
110
+
111
+
112
+ # ================#
113
+ # VAE Conversion #
114
+ # ================#
115
+
116
+ vae_conversion_map = [
117
+ # (stable-diffusion, HF Diffusers)
118
+ ("nin_shortcut", "conv_shortcut"),
119
+ ("norm_out", "conv_norm_out"),
120
+ ("mid.attn_1.", "mid_block.attentions.0."),
121
+ ]
122
+
123
+ for i in range(4):
124
+ # down_blocks have two resnets
125
+ for j in range(2):
126
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
127
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
128
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
129
+
130
+ if i < 3:
131
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
132
+ sd_downsample_prefix = f"down.{i}.downsample."
133
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
134
+
135
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
136
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
137
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
138
+
139
+ # up_blocks have three resnets
140
+ # also, up blocks in hf are numbered in reverse from sd
141
+ for j in range(3):
142
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
143
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
144
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
145
+
146
+ # this part accounts for mid blocks in both the encoder and the decoder
147
+ for i in range(2):
148
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
149
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
150
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
151
+
152
+
153
+ vae_conversion_map_attn = [
154
+ # (stable-diffusion, HF Diffusers)
155
+ ("norm.", "group_norm."),
156
+ ("q.", "query."),
157
+ ("k.", "key."),
158
+ ("v.", "value."),
159
+ ("proj_out.", "proj_attn."),
160
+ ]
161
+
162
+ # This is probably not the most ideal solution, but it does work.
163
+ vae_extra_conversion_map = [
164
+ ("to_q", "q"),
165
+ ("to_k", "k"),
166
+ ("to_v", "v"),
167
+ ("to_out.0", "proj_out"),
168
+ ]
169
+
170
+
171
+ def reshape_weight_for_sd(w):
172
+ # convert HF linear weights to SD conv2d weights
173
+ if not w.ndim == 1:
174
+ return w.reshape(*w.shape, 1, 1)
175
+ else:
176
+ return w
177
+
178
+
179
+ def convert_vae_state_dict(vae_state_dict):
180
+ mapping = {k: k for k in vae_state_dict.keys()}
181
+ for k, v in mapping.items():
182
+ for sd_part, hf_part in vae_conversion_map:
183
+ v = v.replace(hf_part, sd_part)
184
+ mapping[k] = v
185
+ for k, v in mapping.items():
186
+ if "attentions" in k:
187
+ for sd_part, hf_part in vae_conversion_map_attn:
188
+ v = v.replace(hf_part, sd_part)
189
+ mapping[k] = v
190
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
191
+ weights_to_convert = ["q", "k", "v", "proj_out"]
192
+ keys_to_rename = {}
193
+ for k, v in new_state_dict.items():
194
+ for weight_name in weights_to_convert:
195
+ if f"mid.attn_1.{weight_name}.weight" in k:
196
+ print(f"Reshaping {k} for SD format")
197
+ new_state_dict[k] = reshape_weight_for_sd(v)
198
+ for weight_name, real_weight_name in vae_extra_conversion_map:
199
+ if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
200
+ keys_to_rename[k] = k.replace(weight_name, real_weight_name)
201
+ for k, v in keys_to_rename.items():
202
+ if k in new_state_dict:
203
+ print(f"Renaming {k} to {v}")
204
+ new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
205
+ del new_state_dict[k]
206
+ return new_state_dict
207
+
208
+
209
+ # =========================#
210
+ # Text Encoder Conversion #
211
+ # =========================#
212
+
213
+
214
+ textenc_conversion_lst = [
215
+ # (stable-diffusion, HF Diffusers)
216
+ ("resblocks.", "text_model.encoder.layers."),
217
+ ("ln_1", "layer_norm1"),
218
+ ("ln_2", "layer_norm2"),
219
+ (".c_fc.", ".fc1."),
220
+ (".c_proj.", ".fc2."),
221
+ (".attn", ".self_attn"),
222
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
223
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
224
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
225
+ ]
226
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
227
+ textenc_pattern = re.compile("|".join(protected.keys()))
228
+
229
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
230
+ code2idx = {"q": 0, "k": 1, "v": 2}
231
+
232
+
233
+ def convert_text_enc_state_dict_v20(text_enc_dict):
234
+ new_state_dict = {}
235
+ capture_qkv_weight = {}
236
+ capture_qkv_bias = {}
237
+ for k, v in text_enc_dict.items():
238
+ if (
239
+ k.endswith(".self_attn.q_proj.weight")
240
+ or k.endswith(".self_attn.k_proj.weight")
241
+ or k.endswith(".self_attn.v_proj.weight")
242
+ ):
243
+ k_pre = k[: -len(".q_proj.weight")]
244
+ k_code = k[-len("q_proj.weight")]
245
+ if k_pre not in capture_qkv_weight:
246
+ capture_qkv_weight[k_pre] = [None, None, None]
247
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
248
+ continue
249
+
250
+ if (
251
+ k.endswith(".self_attn.q_proj.bias")
252
+ or k.endswith(".self_attn.k_proj.bias")
253
+ or k.endswith(".self_attn.v_proj.bias")
254
+ ):
255
+ k_pre = k[: -len(".q_proj.bias")]
256
+ k_code = k[-len("q_proj.bias")]
257
+ if k_pre not in capture_qkv_bias:
258
+ capture_qkv_bias[k_pre] = [None, None, None]
259
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
260
+ continue
261
+
262
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
263
+ new_state_dict[relabelled_key] = v
264
+
265
+ for k_pre, tensors in capture_qkv_weight.items():
266
+ if None in tensors:
267
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
268
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
269
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
270
+
271
+ for k_pre, tensors in capture_qkv_bias.items():
272
+ if None in tensors:
273
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
274
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
275
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
276
+
277
+ return new_state_dict
278
+
279
+
280
+ def convert_text_enc_state_dict(text_enc_dict):
281
+ return text_enc_dict
282
+
283
+
284
+ if __name__ == "__main__":
285
+ parser = argparse.ArgumentParser()
286
+
287
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
288
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
289
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
290
+ parser.add_argument(
291
+ "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
292
+ )
293
+
294
+ args = parser.parse_args()
295
+
296
+ assert args.model_path is not None, "Must provide a model path!"
297
+
298
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
299
+
300
+ # Path for safetensors
301
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
302
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
303
+ text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
304
+
305
+ # Load models from safetensors if it exists, if it doesn't pytorch
306
+ if osp.exists(unet_path):
307
+ unet_state_dict = load_file(unet_path, device="cpu")
308
+ else:
309
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
310
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
311
+
312
+ if osp.exists(vae_path):
313
+ vae_state_dict = load_file(vae_path, device="cpu")
314
+ else:
315
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
316
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
317
+
318
+ if osp.exists(text_enc_path):
319
+ text_enc_dict = load_file(text_enc_path, device="cpu")
320
+ else:
321
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
322
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
323
+
324
+ # Convert the UNet model
325
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
326
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
327
+
328
+ # Convert the VAE model
329
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
330
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
331
+
332
+ # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
333
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
334
+
335
+ if is_v20_model:
336
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
337
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
338
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
339
+ text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
340
+ else:
341
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
342
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
343
+
344
+ # Put together new checkpoint
345
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
346
+ if args.half:
347
+ state_dict = {k: v.half() for k, v in state_dict.items()}
348
+
349
+ if args.use_safetensors:
350
+ save_file(state_dict, args.checkpoint_path)
351
+ else:
352
+ state_dict = {"state_dict": state_dict}
353
+ torch.save(state_dict, args.checkpoint_path)
diffusers/scripts/convert_dit_to_diffusers.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from torchvision.datasets.utils import download_url
6
+
7
+ from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
8
+
9
+
10
+ pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
11
+
12
+
13
+ def download_model(model_name):
14
+ """
15
+ Downloads a pre-trained DiT model from the web.
16
+ """
17
+ local_path = f"pretrained_models/{model_name}"
18
+ if not os.path.isfile(local_path):
19
+ os.makedirs("pretrained_models", exist_ok=True)
20
+ web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
21
+ download_url(web_path, "pretrained_models")
22
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
23
+ return model
24
+
25
+
26
+ def main(args):
27
+ state_dict = download_model(pretrained_models[args.image_size])
28
+
29
+ state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
30
+ state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
31
+ state_dict.pop("x_embedder.proj.weight")
32
+ state_dict.pop("x_embedder.proj.bias")
33
+
34
+ for depth in range(28):
35
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
36
+ "t_embedder.mlp.0.weight"
37
+ ]
38
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
39
+ "t_embedder.mlp.0.bias"
40
+ ]
41
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
42
+ "t_embedder.mlp.2.weight"
43
+ ]
44
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
45
+ "t_embedder.mlp.2.bias"
46
+ ]
47
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
48
+ "y_embedder.embedding_table.weight"
49
+ ]
50
+
51
+ state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
52
+ f"blocks.{depth}.adaLN_modulation.1.weight"
53
+ ]
54
+ state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
55
+ f"blocks.{depth}.adaLN_modulation.1.bias"
56
+ ]
57
+
58
+ q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
59
+ q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)
60
+
61
+ state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
62
+ state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
63
+ state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
64
+ state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
65
+ state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
66
+ state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
67
+
68
+ state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
69
+ f"blocks.{depth}.attn.proj.weight"
70
+ ]
71
+ state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]
72
+
73
+ state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
74
+ state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
75
+ state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
76
+ state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]
77
+
78
+ state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
79
+ state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
80
+ state_dict.pop(f"blocks.{depth}.attn.proj.weight")
81
+ state_dict.pop(f"blocks.{depth}.attn.proj.bias")
82
+ state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
83
+ state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
84
+ state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
85
+ state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
86
+ state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
87
+ state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")
88
+
89
+ state_dict.pop("t_embedder.mlp.0.weight")
90
+ state_dict.pop("t_embedder.mlp.0.bias")
91
+ state_dict.pop("t_embedder.mlp.2.weight")
92
+ state_dict.pop("t_embedder.mlp.2.bias")
93
+ state_dict.pop("y_embedder.embedding_table.weight")
94
+
95
+ state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
96
+ state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
97
+ state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
98
+ state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]
99
+
100
+ state_dict.pop("final_layer.linear.weight")
101
+ state_dict.pop("final_layer.linear.bias")
102
+ state_dict.pop("final_layer.adaLN_modulation.1.weight")
103
+ state_dict.pop("final_layer.adaLN_modulation.1.bias")
104
+
105
+ # DiT XL/2
106
+ transformer = Transformer2DModel(
107
+ sample_size=args.image_size // 8,
108
+ num_layers=28,
109
+ attention_head_dim=72,
110
+ in_channels=4,
111
+ out_channels=8,
112
+ patch_size=2,
113
+ attention_bias=True,
114
+ num_attention_heads=16,
115
+ activation_fn="gelu-approximate",
116
+ num_embeds_ada_norm=1000,
117
+ norm_type="ada_norm_zero",
118
+ norm_elementwise_affine=False,
119
+ )
120
+ transformer.load_state_dict(state_dict, strict=True)
121
+
122
+ scheduler = DDIMScheduler(
123
+ num_train_timesteps=1000,
124
+ beta_schedule="linear",
125
+ prediction_type="epsilon",
126
+ clip_sample=False,
127
+ )
128
+
129
+ vae = AutoencoderKL.from_pretrained(args.vae_model)
130
+
131
+ pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)
132
+
133
+ if args.save:
134
+ pipeline.save_pretrained(args.checkpoint_path)
135
+
136
+
137
+ if __name__ == "__main__":
138
+ parser = argparse.ArgumentParser()
139
+
140
+ parser.add_argument(
141
+ "--image_size",
142
+ default=256,
143
+ type=int,
144
+ required=False,
145
+ help="Image size of pretrained model, either 256 or 512.",
146
+ )
147
+ parser.add_argument(
148
+ "--vae_model",
149
+ default="stabilityai/sd-vae-ft-ema",
150
+ type=str,
151
+ required=False,
152
+ help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
153
+ )
154
+ parser.add_argument(
155
+ "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
156
+ )
157
+ parser.add_argument(
158
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
159
+ )
160
+
161
+ args = parser.parse_args()
162
+ main(args)
diffusers/scripts/convert_flux_to_diffusers.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from contextlib import nullcontext
3
+
4
+ import safetensors.torch
5
+ import torch
6
+ from accelerate import init_empty_weights
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ from diffusers import AutoencoderKL, FluxTransformer2DModel
10
+ from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
11
+ from diffusers.utils.import_utils import is_accelerate_available
12
+
13
+
14
+ """
15
+ # Transformer
16
+
17
+ python scripts/convert_flux_to_diffusers.py \
18
+ --original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
19
+ --filename "flux1-schnell.sft"
20
+ --output_path "flux-schnell" \
21
+ --transformer
22
+ """
23
+
24
+ """
25
+ # VAE
26
+
27
+ python scripts/convert_flux_to_diffusers.py \
28
+ --original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
29
+ --filename "ae.sft"
30
+ --output_path "flux-schnell" \
31
+ --vae
32
+ """
33
+
34
+ CTX = init_empty_weights if is_accelerate_available() else nullcontext
35
+
36
+ parser = argparse.ArgumentParser()
37
+ parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
38
+ parser.add_argument("--filename", default="flux.safetensors", type=str)
39
+ parser.add_argument("--checkpoint_path", default=None, type=str)
40
+ parser.add_argument("--in_channels", type=int, default=64)
41
+ parser.add_argument("--out_channels", type=int, default=None)
42
+ parser.add_argument("--vae", action="store_true")
43
+ parser.add_argument("--transformer", action="store_true")
44
+ parser.add_argument("--output_path", type=str)
45
+ parser.add_argument("--dtype", type=str, default="bf16")
46
+
47
+ args = parser.parse_args()
48
+ dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
49
+
50
+
51
+ def load_original_checkpoint(args):
52
+ if args.original_state_dict_repo_id is not None:
53
+ ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
54
+ elif args.checkpoint_path is not None:
55
+ ckpt_path = args.checkpoint_path
56
+ else:
57
+ raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
58
+
59
+ original_state_dict = safetensors.torch.load_file(ckpt_path)
60
+ return original_state_dict
61
+
62
+
63
+ # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
64
+ # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
65
+ def swap_scale_shift(weight):
66
+ shift, scale = weight.chunk(2, dim=0)
67
+ new_weight = torch.cat([scale, shift], dim=0)
68
+ return new_weight
69
+
70
+
71
+ def convert_flux_transformer_checkpoint_to_diffusers(
72
+ original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
73
+ ):
74
+ converted_state_dict = {}
75
+
76
+ ## time_text_embed.timestep_embedder <- time_in
77
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
78
+ "time_in.in_layer.weight"
79
+ )
80
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
81
+ "time_in.in_layer.bias"
82
+ )
83
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
84
+ "time_in.out_layer.weight"
85
+ )
86
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
87
+ "time_in.out_layer.bias"
88
+ )
89
+
90
+ ## time_text_embed.text_embedder <- vector_in
91
+ converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
92
+ "vector_in.in_layer.weight"
93
+ )
94
+ converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
95
+ "vector_in.in_layer.bias"
96
+ )
97
+ converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
98
+ "vector_in.out_layer.weight"
99
+ )
100
+ converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
101
+ "vector_in.out_layer.bias"
102
+ )
103
+
104
+ # guidance
105
+ has_guidance = any("guidance" in k for k in original_state_dict)
106
+ if has_guidance:
107
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
108
+ "guidance_in.in_layer.weight"
109
+ )
110
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
111
+ "guidance_in.in_layer.bias"
112
+ )
113
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
114
+ "guidance_in.out_layer.weight"
115
+ )
116
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
117
+ "guidance_in.out_layer.bias"
118
+ )
119
+
120
+ # context_embedder
121
+ converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight")
122
+ converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias")
123
+
124
+ # x_embedder
125
+ converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
126
+ converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")
127
+
128
+ # double transformer blocks
129
+ for i in range(num_layers):
130
+ block_prefix = f"transformer_blocks.{i}."
131
+ # norms.
132
+ ## norm1
133
+ converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
134
+ f"double_blocks.{i}.img_mod.lin.weight"
135
+ )
136
+ converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
137
+ f"double_blocks.{i}.img_mod.lin.bias"
138
+ )
139
+ ## norm1_context
140
+ converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
141
+ f"double_blocks.{i}.txt_mod.lin.weight"
142
+ )
143
+ converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
144
+ f"double_blocks.{i}.txt_mod.lin.bias"
145
+ )
146
+ # Q, K, V
147
+ sample_q, sample_k, sample_v = torch.chunk(
148
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
149
+ )
150
+ context_q, context_k, context_v = torch.chunk(
151
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
152
+ )
153
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
154
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
155
+ )
156
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
157
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
158
+ )
159
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
160
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
161
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
162
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
163
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
164
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
165
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
166
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
167
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
168
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
169
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
170
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
171
+ # qk_norm
172
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
173
+ f"double_blocks.{i}.img_attn.norm.query_norm.scale"
174
+ )
175
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
176
+ f"double_blocks.{i}.img_attn.norm.key_norm.scale"
177
+ )
178
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
179
+ f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
180
+ )
181
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
182
+ f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
183
+ )
184
+ # ff img_mlp
185
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
186
+ f"double_blocks.{i}.img_mlp.0.weight"
187
+ )
188
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
189
+ f"double_blocks.{i}.img_mlp.0.bias"
190
+ )
191
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
192
+ f"double_blocks.{i}.img_mlp.2.weight"
193
+ )
194
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
195
+ f"double_blocks.{i}.img_mlp.2.bias"
196
+ )
197
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
198
+ f"double_blocks.{i}.txt_mlp.0.weight"
199
+ )
200
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
201
+ f"double_blocks.{i}.txt_mlp.0.bias"
202
+ )
203
+ converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
204
+ f"double_blocks.{i}.txt_mlp.2.weight"
205
+ )
206
+ converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
207
+ f"double_blocks.{i}.txt_mlp.2.bias"
208
+ )
209
+ # output projections.
210
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
211
+ f"double_blocks.{i}.img_attn.proj.weight"
212
+ )
213
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
214
+ f"double_blocks.{i}.img_attn.proj.bias"
215
+ )
216
+ converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
217
+ f"double_blocks.{i}.txt_attn.proj.weight"
218
+ )
219
+ converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
220
+ f"double_blocks.{i}.txt_attn.proj.bias"
221
+ )
222
+
223
+ # single transformer blocks
224
+ for i in range(num_single_layers):
225
+ block_prefix = f"single_transformer_blocks.{i}."
226
+ # norm.linear <- single_blocks.0.modulation.lin
227
+ converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
228
+ f"single_blocks.{i}.modulation.lin.weight"
229
+ )
230
+ converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
231
+ f"single_blocks.{i}.modulation.lin.bias"
232
+ )
233
+ # Q, K, V, mlp
234
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
235
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
236
+ q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
237
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
238
+ original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
239
+ )
240
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
241
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
242
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
243
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
244
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
245
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
246
+ converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
247
+ converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
248
+ # qk norm
249
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
250
+ f"single_blocks.{i}.norm.query_norm.scale"
251
+ )
252
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
253
+ f"single_blocks.{i}.norm.key_norm.scale"
254
+ )
255
+ # output projections.
256
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
257
+ f"single_blocks.{i}.linear2.weight"
258
+ )
259
+ converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
260
+ f"single_blocks.{i}.linear2.bias"
261
+ )
262
+
263
+ converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
264
+ converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
265
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
266
+ original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
267
+ )
268
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
269
+ original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
270
+ )
271
+
272
+ return converted_state_dict
273
+
274
+
275
+ def main(args):
276
+ original_ckpt = load_original_checkpoint(args)
277
+ has_guidance = any("guidance" in k for k in original_ckpt)
278
+
279
+ if args.transformer:
280
+ num_layers = 19
281
+ num_single_layers = 38
282
+ inner_dim = 3072
283
+ mlp_ratio = 4.0
284
+
285
+ converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
286
+ original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
287
+ )
288
+ transformer = FluxTransformer2DModel(
289
+ in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance
290
+ )
291
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
292
+
293
+ print(
294
+ f"Saving Flux Transformer in Diffusers format. Variant: {'guidance-distilled' if has_guidance else 'timestep-distilled'}"
295
+ )
296
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
297
+
298
+ if args.vae:
299
+ config = AutoencoderKL.load_config("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
300
+ vae = AutoencoderKL.from_config(config, scaling_factor=0.3611, shift_factor=0.1159).to(torch.bfloat16)
301
+
302
+ converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
303
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
304
+ vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
305
+
306
+
307
+ if __name__ == "__main__":
308
+ main(args)
diffusers/scripts/convert_gligen_to_diffusers.py ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import re
3
+
4
+ import torch
5
+ import yaml
6
+ from transformers import (
7
+ CLIPProcessor,
8
+ CLIPTextModel,
9
+ CLIPTokenizer,
10
+ CLIPVisionModelWithProjection,
11
+ )
12
+
13
+ from diffusers import (
14
+ AutoencoderKL,
15
+ DDIMScheduler,
16
+ StableDiffusionGLIGENPipeline,
17
+ StableDiffusionGLIGENTextImagePipeline,
18
+ UNet2DConditionModel,
19
+ )
20
+ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
21
+ assign_to_checkpoint,
22
+ conv_attn_to_linear,
23
+ protected,
24
+ renew_attention_paths,
25
+ renew_resnet_paths,
26
+ renew_vae_attention_paths,
27
+ renew_vae_resnet_paths,
28
+ shave_segments,
29
+ textenc_conversion_map,
30
+ textenc_pattern,
31
+ )
32
+
33
+
34
+ def convert_open_clip_checkpoint(checkpoint):
35
+ checkpoint = checkpoint["text_encoder"]
36
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
37
+
38
+ keys = list(checkpoint.keys())
39
+
40
+ text_model_dict = {}
41
+
42
+ if "cond_stage_model.model.text_projection" in checkpoint:
43
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
44
+ else:
45
+ d_model = 1024
46
+
47
+ for key in keys:
48
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
49
+ continue
50
+ if key in textenc_conversion_map:
51
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
52
+ # if key.startswith("cond_stage_model.model.transformer."):
53
+ new_key = key[len("transformer.") :]
54
+ if new_key.endswith(".in_proj_weight"):
55
+ new_key = new_key[: -len(".in_proj_weight")]
56
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
57
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
58
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
59
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
60
+ elif new_key.endswith(".in_proj_bias"):
61
+ new_key = new_key[: -len(".in_proj_bias")]
62
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
63
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
64
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
65
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
66
+ else:
67
+ if key != "transformer.text_model.embeddings.position_ids":
68
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
69
+
70
+ text_model_dict[new_key] = checkpoint[key]
71
+
72
+ if key == "transformer.text_model.embeddings.token_embedding.weight":
73
+ text_model_dict["text_model.embeddings.token_embedding.weight"] = checkpoint[key]
74
+
75
+ text_model_dict.pop("text_model.embeddings.transformer.text_model.embeddings.token_embedding.weight")
76
+
77
+ text_model.load_state_dict(text_model_dict)
78
+
79
+ return text_model
80
+
81
+
82
+ def convert_gligen_vae_checkpoint(checkpoint, config):
83
+ checkpoint = checkpoint["autoencoder"]
84
+ vae_state_dict = {}
85
+ vae_key = "first_stage_model."
86
+ keys = list(checkpoint.keys())
87
+ for key in keys:
88
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
89
+
90
+ new_checkpoint = {}
91
+
92
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
93
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
94
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
95
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
96
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
97
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
98
+
99
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
100
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
101
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
102
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
103
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
104
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
105
+
106
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
107
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
108
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
109
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
110
+
111
+ # Retrieves the keys for the encoder down blocks only
112
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
113
+ down_blocks = {
114
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
115
+ }
116
+
117
+ # Retrieves the keys for the decoder up blocks only
118
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
119
+ up_blocks = {
120
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
121
+ }
122
+
123
+ for i in range(num_down_blocks):
124
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
125
+
126
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
127
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
128
+ f"encoder.down.{i}.downsample.conv.weight"
129
+ )
130
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
131
+ f"encoder.down.{i}.downsample.conv.bias"
132
+ )
133
+
134
+ paths = renew_vae_resnet_paths(resnets)
135
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
136
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
137
+
138
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
139
+ num_mid_res_blocks = 2
140
+ for i in range(1, num_mid_res_blocks + 1):
141
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
142
+
143
+ paths = renew_vae_resnet_paths(resnets)
144
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
145
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
146
+
147
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
148
+ paths = renew_vae_attention_paths(mid_attentions)
149
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
150
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
151
+ conv_attn_to_linear(new_checkpoint)
152
+
153
+ for i in range(num_up_blocks):
154
+ block_id = num_up_blocks - 1 - i
155
+ resnets = [
156
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
157
+ ]
158
+
159
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
160
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
161
+ f"decoder.up.{block_id}.upsample.conv.weight"
162
+ ]
163
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
164
+ f"decoder.up.{block_id}.upsample.conv.bias"
165
+ ]
166
+
167
+ paths = renew_vae_resnet_paths(resnets)
168
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
169
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
170
+
171
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
172
+ num_mid_res_blocks = 2
173
+ for i in range(1, num_mid_res_blocks + 1):
174
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
175
+
176
+ paths = renew_vae_resnet_paths(resnets)
177
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
178
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
179
+
180
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
181
+ paths = renew_vae_attention_paths(mid_attentions)
182
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
183
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
184
+ conv_attn_to_linear(new_checkpoint)
185
+
186
+ for key in new_checkpoint.keys():
187
+ if "encoder.mid_block.attentions.0" in key or "decoder.mid_block.attentions.0" in key:
188
+ if "query" in key:
189
+ new_checkpoint[key.replace("query", "to_q")] = new_checkpoint.pop(key)
190
+ if "value" in key:
191
+ new_checkpoint[key.replace("value", "to_v")] = new_checkpoint.pop(key)
192
+ if "key" in key:
193
+ new_checkpoint[key.replace("key", "to_k")] = new_checkpoint.pop(key)
194
+ if "proj_attn" in key:
195
+ new_checkpoint[key.replace("proj_attn", "to_out.0")] = new_checkpoint.pop(key)
196
+
197
+ return new_checkpoint
198
+
199
+
200
+ def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
201
+ unet_state_dict = {}
202
+ checkpoint = checkpoint["model"]
203
+ keys = list(checkpoint.keys())
204
+
205
+ unet_key = "model.diffusion_model."
206
+
207
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
208
+ print(f"Checkpoint {path} has bot EMA and non-EMA weights.")
209
+ print(
210
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
211
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
212
+ )
213
+ for key in keys:
214
+ if key.startswith("model.diffusion_model"):
215
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
216
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
217
+ else:
218
+ if sum(k.startswith("model_ema") for k in keys) > 100:
219
+ print(
220
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
221
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
222
+ )
223
+ for key in keys:
224
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
225
+
226
+ new_checkpoint = {}
227
+
228
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
229
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
230
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
231
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
232
+
233
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
234
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
235
+
236
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
237
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
238
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
239
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
240
+
241
+ # Retrieves the keys for the input blocks only
242
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
243
+ input_blocks = {
244
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
245
+ for layer_id in range(num_input_blocks)
246
+ }
247
+
248
+ # Retrieves the keys for the middle blocks only
249
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
250
+ middle_blocks = {
251
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
252
+ for layer_id in range(num_middle_blocks)
253
+ }
254
+
255
+ # Retrieves the keys for the output blocks only
256
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
257
+ output_blocks = {
258
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
259
+ for layer_id in range(num_output_blocks)
260
+ }
261
+
262
+ for i in range(1, num_input_blocks):
263
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
264
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
265
+
266
+ resnets = [
267
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
268
+ ]
269
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
270
+
271
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
272
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
273
+ f"input_blocks.{i}.0.op.weight"
274
+ )
275
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
276
+ f"input_blocks.{i}.0.op.bias"
277
+ )
278
+
279
+ paths = renew_resnet_paths(resnets)
280
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
281
+ assign_to_checkpoint(
282
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
283
+ )
284
+
285
+ if len(attentions):
286
+ paths = renew_attention_paths(attentions)
287
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
288
+ assign_to_checkpoint(
289
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
290
+ )
291
+
292
+ resnet_0 = middle_blocks[0]
293
+ attentions = middle_blocks[1]
294
+ resnet_1 = middle_blocks[2]
295
+
296
+ resnet_0_paths = renew_resnet_paths(resnet_0)
297
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
298
+
299
+ resnet_1_paths = renew_resnet_paths(resnet_1)
300
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
301
+
302
+ attentions_paths = renew_attention_paths(attentions)
303
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
304
+ assign_to_checkpoint(
305
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
306
+ )
307
+
308
+ for i in range(num_output_blocks):
309
+ block_id = i // (config["layers_per_block"] + 1)
310
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
311
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
312
+ output_block_list = {}
313
+
314
+ for layer in output_block_layers:
315
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
316
+ if layer_id in output_block_list:
317
+ output_block_list[layer_id].append(layer_name)
318
+ else:
319
+ output_block_list[layer_id] = [layer_name]
320
+
321
+ if len(output_block_list) > 1:
322
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
323
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
324
+
325
+ resnet_0_paths = renew_resnet_paths(resnets)
326
+ paths = renew_resnet_paths(resnets)
327
+
328
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
329
+ assign_to_checkpoint(
330
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
331
+ )
332
+
333
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
334
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
335
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
336
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
337
+ f"output_blocks.{i}.{index}.conv.weight"
338
+ ]
339
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
340
+ f"output_blocks.{i}.{index}.conv.bias"
341
+ ]
342
+
343
+ # Clear attentions as they have been attributed above.
344
+ if len(attentions) == 2:
345
+ attentions = []
346
+
347
+ if len(attentions):
348
+ paths = renew_attention_paths(attentions)
349
+ meta_path = {
350
+ "old": f"output_blocks.{i}.1",
351
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
352
+ }
353
+ assign_to_checkpoint(
354
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
355
+ )
356
+ else:
357
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
358
+ for path in resnet_0_paths:
359
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
360
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
361
+
362
+ new_checkpoint[new_path] = unet_state_dict[old_path]
363
+
364
+ for key in keys:
365
+ if "position_net" in key:
366
+ new_checkpoint[key] = unet_state_dict[key]
367
+
368
+ return new_checkpoint
369
+
370
+
371
+ def create_vae_config(original_config, image_size: int):
372
+ vae_params = original_config["autoencoder"]["params"]["ddconfig"]
373
+ _ = original_config["autoencoder"]["params"]["embed_dim"]
374
+
375
+ block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
376
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
377
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
378
+
379
+ config = {
380
+ "sample_size": image_size,
381
+ "in_channels": vae_params["in_channels"],
382
+ "out_channels": vae_params["out_ch"],
383
+ "down_block_types": tuple(down_block_types),
384
+ "up_block_types": tuple(up_block_types),
385
+ "block_out_channels": tuple(block_out_channels),
386
+ "latent_channels": vae_params["z_channels"],
387
+ "layers_per_block": vae_params["num_res_blocks"],
388
+ }
389
+
390
+ return config
391
+
392
+
393
+ def create_unet_config(original_config, image_size: int, attention_type):
394
+ unet_params = original_config["model"]["params"]
395
+ vae_params = original_config["autoencoder"]["params"]["ddconfig"]
396
+
397
+ block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
398
+
399
+ down_block_types = []
400
+ resolution = 1
401
+ for i in range(len(block_out_channels)):
402
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
403
+ down_block_types.append(block_type)
404
+ if i != len(block_out_channels) - 1:
405
+ resolution *= 2
406
+
407
+ up_block_types = []
408
+ for i in range(len(block_out_channels)):
409
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
410
+ up_block_types.append(block_type)
411
+ resolution //= 2
412
+
413
+ vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
414
+
415
+ head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
416
+ use_linear_projection = (
417
+ unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
418
+ )
419
+ if use_linear_projection:
420
+ if head_dim is None:
421
+ head_dim = [5, 10, 20, 20]
422
+
423
+ config = {
424
+ "sample_size": image_size // vae_scale_factor,
425
+ "in_channels": unet_params["in_channels"],
426
+ "down_block_types": tuple(down_block_types),
427
+ "block_out_channels": tuple(block_out_channels),
428
+ "layers_per_block": unet_params["num_res_blocks"],
429
+ "cross_attention_dim": unet_params["context_dim"],
430
+ "attention_head_dim": head_dim,
431
+ "use_linear_projection": use_linear_projection,
432
+ "attention_type": attention_type,
433
+ }
434
+
435
+ return config
436
+
437
+
438
+ def convert_gligen_to_diffusers(
439
+ checkpoint_path: str,
440
+ original_config_file: str,
441
+ attention_type: str,
442
+ image_size: int = 512,
443
+ extract_ema: bool = False,
444
+ num_in_channels: int = None,
445
+ device: str = None,
446
+ ):
447
+ if device is None:
448
+ device = "cuda" if torch.cuda.is_available() else "cpu"
449
+ checkpoint = torch.load(checkpoint_path, map_location=device)
450
+ else:
451
+ checkpoint = torch.load(checkpoint_path, map_location=device)
452
+
453
+ if "global_step" in checkpoint:
454
+ checkpoint["global_step"]
455
+ else:
456
+ print("global_step key not found in model")
457
+
458
+ original_config = yaml.safe_load(original_config_file)
459
+
460
+ if num_in_channels is not None:
461
+ original_config["model"]["params"]["in_channels"] = num_in_channels
462
+
463
+ num_train_timesteps = original_config["diffusion"]["params"]["timesteps"]
464
+ beta_start = original_config["diffusion"]["params"]["linear_start"]
465
+ beta_end = original_config["diffusion"]["params"]["linear_end"]
466
+
467
+ scheduler = DDIMScheduler(
468
+ beta_end=beta_end,
469
+ beta_schedule="scaled_linear",
470
+ beta_start=beta_start,
471
+ num_train_timesteps=num_train_timesteps,
472
+ steps_offset=1,
473
+ clip_sample=False,
474
+ set_alpha_to_one=False,
475
+ prediction_type="epsilon",
476
+ )
477
+
478
+ # Convert the UNet2DConditionalModel model
479
+ unet_config = create_unet_config(original_config, image_size, attention_type)
480
+ unet = UNet2DConditionModel(**unet_config)
481
+
482
+ converted_unet_checkpoint = convert_gligen_unet_checkpoint(
483
+ checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
484
+ )
485
+
486
+ unet.load_state_dict(converted_unet_checkpoint)
487
+
488
+ # Convert the VAE model
489
+ vae_config = create_vae_config(original_config, image_size)
490
+ converted_vae_checkpoint = convert_gligen_vae_checkpoint(checkpoint, vae_config)
491
+
492
+ vae = AutoencoderKL(**vae_config)
493
+ vae.load_state_dict(converted_vae_checkpoint)
494
+
495
+ # Convert the text model
496
+ text_encoder = convert_open_clip_checkpoint(checkpoint)
497
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
498
+
499
+ if attention_type == "gated-text-image":
500
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
501
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
502
+
503
+ pipe = StableDiffusionGLIGENTextImagePipeline(
504
+ vae=vae,
505
+ text_encoder=text_encoder,
506
+ tokenizer=tokenizer,
507
+ image_encoder=image_encoder,
508
+ processor=processor,
509
+ unet=unet,
510
+ scheduler=scheduler,
511
+ safety_checker=None,
512
+ feature_extractor=None,
513
+ )
514
+ elif attention_type == "gated":
515
+ pipe = StableDiffusionGLIGENPipeline(
516
+ vae=vae,
517
+ text_encoder=text_encoder,
518
+ tokenizer=tokenizer,
519
+ unet=unet,
520
+ scheduler=scheduler,
521
+ safety_checker=None,
522
+ feature_extractor=None,
523
+ )
524
+
525
+ return pipe
526
+
527
+
528
+ if __name__ == "__main__":
529
+ parser = argparse.ArgumentParser()
530
+
531
+ parser.add_argument(
532
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
533
+ )
534
+ parser.add_argument(
535
+ "--original_config_file",
536
+ default=None,
537
+ type=str,
538
+ required=True,
539
+ help="The YAML config file corresponding to the gligen architecture.",
540
+ )
541
+ parser.add_argument(
542
+ "--num_in_channels",
543
+ default=None,
544
+ type=int,
545
+ help="The number of input channels. If `None` number of input channels will be automatically inferred.",
546
+ )
547
+ parser.add_argument(
548
+ "--extract_ema",
549
+ action="store_true",
550
+ help=(
551
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
552
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
553
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
554
+ ),
555
+ )
556
+ parser.add_argument(
557
+ "--attention_type",
558
+ default=None,
559
+ type=str,
560
+ required=True,
561
+ help="Type of attention ex: gated or gated-text-image",
562
+ )
563
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
564
+ parser.add_argument("--device", type=str, help="Device to use.")
565
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
566
+
567
+ args = parser.parse_args()
568
+
569
+ pipe = convert_gligen_to_diffusers(
570
+ checkpoint_path=args.checkpoint_path,
571
+ original_config_file=args.original_config_file,
572
+ attention_type=args.attention_type,
573
+ extract_ema=args.extract_ema,
574
+ num_in_channels=args.num_in_channels,
575
+ device=args.device,
576
+ )
577
+
578
+ if args.half:
579
+ pipe.to(dtype=torch.float16)
580
+
581
+ pipe.save_pretrained(args.dump_path)
diffusers/scripts/convert_hunyuan_video_to_diffusers.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from typing import Any, Dict
3
+
4
+ import torch
5
+ from accelerate import init_empty_weights
6
+ from transformers import (
7
+ AutoModel,
8
+ AutoTokenizer,
9
+ CLIPImageProcessor,
10
+ CLIPTextModel,
11
+ CLIPTokenizer,
12
+ LlavaForConditionalGeneration,
13
+ )
14
+
15
+ from diffusers import (
16
+ AutoencoderKLHunyuanVideo,
17
+ FlowMatchEulerDiscreteScheduler,
18
+ HunyuanVideoImageToVideoPipeline,
19
+ HunyuanVideoPipeline,
20
+ HunyuanVideoTransformer3DModel,
21
+ )
22
+
23
+
24
+ def remap_norm_scale_shift_(key, state_dict):
25
+ weight = state_dict.pop(key)
26
+ shift, scale = weight.chunk(2, dim=0)
27
+ new_weight = torch.cat([scale, shift], dim=0)
28
+ state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
29
+
30
+
31
+ def remap_txt_in_(key, state_dict):
32
+ def rename_key(key):
33
+ new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
34
+ new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
35
+ new_key = new_key.replace("txt_in", "context_embedder")
36
+ new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
37
+ new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
38
+ new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
39
+ new_key = new_key.replace("mlp", "ff")
40
+ return new_key
41
+
42
+ if "self_attn_qkv" in key:
43
+ weight = state_dict.pop(key)
44
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
45
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
46
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
47
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
48
+ else:
49
+ state_dict[rename_key(key)] = state_dict.pop(key)
50
+
51
+
52
+ def remap_img_attn_qkv_(key, state_dict):
53
+ weight = state_dict.pop(key)
54
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
55
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
56
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
57
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
58
+
59
+
60
+ def remap_txt_attn_qkv_(key, state_dict):
61
+ weight = state_dict.pop(key)
62
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
63
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
64
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
65
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
66
+
67
+
68
+ def remap_single_transformer_blocks_(key, state_dict):
69
+ hidden_size = 3072
70
+
71
+ if "linear1.weight" in key:
72
+ linear1_weight = state_dict.pop(key)
73
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
74
+ q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
75
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
76
+ state_dict[f"{new_key}.attn.to_q.weight"] = q
77
+ state_dict[f"{new_key}.attn.to_k.weight"] = k
78
+ state_dict[f"{new_key}.attn.to_v.weight"] = v
79
+ state_dict[f"{new_key}.proj_mlp.weight"] = mlp
80
+
81
+ elif "linear1.bias" in key:
82
+ linear1_bias = state_dict.pop(key)
83
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
84
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
85
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
86
+ state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
87
+ state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
88
+ state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
89
+ state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
90
+
91
+ else:
92
+ new_key = key.replace("single_blocks", "single_transformer_blocks")
93
+ new_key = new_key.replace("linear2", "proj_out")
94
+ new_key = new_key.replace("q_norm", "attn.norm_q")
95
+ new_key = new_key.replace("k_norm", "attn.norm_k")
96
+ state_dict[new_key] = state_dict.pop(key)
97
+
98
+
99
+ TRANSFORMER_KEYS_RENAME_DICT = {
100
+ "img_in": "x_embedder",
101
+ "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
102
+ "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
103
+ "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
104
+ "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
105
+ "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
106
+ "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
107
+ "double_blocks": "transformer_blocks",
108
+ "img_attn_q_norm": "attn.norm_q",
109
+ "img_attn_k_norm": "attn.norm_k",
110
+ "img_attn_proj": "attn.to_out.0",
111
+ "txt_attn_q_norm": "attn.norm_added_q",
112
+ "txt_attn_k_norm": "attn.norm_added_k",
113
+ "txt_attn_proj": "attn.to_add_out",
114
+ "img_mod.linear": "norm1.linear",
115
+ "img_norm1": "norm1.norm",
116
+ "img_norm2": "norm2",
117
+ "img_mlp": "ff",
118
+ "txt_mod.linear": "norm1_context.linear",
119
+ "txt_norm1": "norm1.norm",
120
+ "txt_norm2": "norm2_context",
121
+ "txt_mlp": "ff_context",
122
+ "self_attn_proj": "attn.to_out.0",
123
+ "modulation.linear": "norm.linear",
124
+ "pre_norm": "norm.norm",
125
+ "final_layer.norm_final": "norm_out.norm",
126
+ "final_layer.linear": "proj_out",
127
+ "fc1": "net.0.proj",
128
+ "fc2": "net.2",
129
+ "input_embedder": "proj_in",
130
+ }
131
+
132
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
133
+ "txt_in": remap_txt_in_,
134
+ "img_attn_qkv": remap_img_attn_qkv_,
135
+ "txt_attn_qkv": remap_txt_attn_qkv_,
136
+ "single_blocks": remap_single_transformer_blocks_,
137
+ "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
138
+ }
139
+
140
+ VAE_KEYS_RENAME_DICT = {}
141
+
142
+ VAE_SPECIAL_KEYS_REMAP = {}
143
+
144
+
145
+ TRANSFORMER_CONFIGS = {
146
+ "HYVideo-T/2-cfgdistill": {
147
+ "in_channels": 16,
148
+ "out_channels": 16,
149
+ "num_attention_heads": 24,
150
+ "attention_head_dim": 128,
151
+ "num_layers": 20,
152
+ "num_single_layers": 40,
153
+ "num_refiner_layers": 2,
154
+ "mlp_ratio": 4.0,
155
+ "patch_size": 2,
156
+ "patch_size_t": 1,
157
+ "qk_norm": "rms_norm",
158
+ "guidance_embeds": True,
159
+ "text_embed_dim": 4096,
160
+ "pooled_projection_dim": 768,
161
+ "rope_theta": 256.0,
162
+ "rope_axes_dim": (16, 56, 56),
163
+ "image_condition_type": None,
164
+ },
165
+ "HYVideo-T/2-I2V-33ch": {
166
+ "in_channels": 16 * 2 + 1,
167
+ "out_channels": 16,
168
+ "num_attention_heads": 24,
169
+ "attention_head_dim": 128,
170
+ "num_layers": 20,
171
+ "num_single_layers": 40,
172
+ "num_refiner_layers": 2,
173
+ "mlp_ratio": 4.0,
174
+ "patch_size": 2,
175
+ "patch_size_t": 1,
176
+ "qk_norm": "rms_norm",
177
+ "guidance_embeds": False,
178
+ "text_embed_dim": 4096,
179
+ "pooled_projection_dim": 768,
180
+ "rope_theta": 256.0,
181
+ "rope_axes_dim": (16, 56, 56),
182
+ "image_condition_type": "latent_concat",
183
+ },
184
+ "HYVideo-T/2-I2V-16ch": {
185
+ "in_channels": 16,
186
+ "out_channels": 16,
187
+ "num_attention_heads": 24,
188
+ "attention_head_dim": 128,
189
+ "num_layers": 20,
190
+ "num_single_layers": 40,
191
+ "num_refiner_layers": 2,
192
+ "mlp_ratio": 4.0,
193
+ "patch_size": 2,
194
+ "patch_size_t": 1,
195
+ "qk_norm": "rms_norm",
196
+ "guidance_embeds": True,
197
+ "text_embed_dim": 4096,
198
+ "pooled_projection_dim": 768,
199
+ "rope_theta": 256.0,
200
+ "rope_axes_dim": (16, 56, 56),
201
+ "image_condition_type": "token_replace",
202
+ },
203
+ }
204
+
205
+
206
+ def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
207
+ state_dict[new_key] = state_dict.pop(old_key)
208
+
209
+
210
+ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
211
+ state_dict = saved_dict
212
+ if "model" in saved_dict.keys():
213
+ state_dict = state_dict["model"]
214
+ if "module" in saved_dict.keys():
215
+ state_dict = state_dict["module"]
216
+ if "state_dict" in saved_dict.keys():
217
+ state_dict = state_dict["state_dict"]
218
+ return state_dict
219
+
220
+
221
+ def convert_transformer(ckpt_path: str, transformer_type: str):
222
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
223
+ config = TRANSFORMER_CONFIGS[transformer_type]
224
+
225
+ with init_empty_weights():
226
+ transformer = HunyuanVideoTransformer3DModel(**config)
227
+
228
+ for key in list(original_state_dict.keys()):
229
+ new_key = key[:]
230
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
231
+ new_key = new_key.replace(replace_key, rename_key)
232
+ update_state_dict_(original_state_dict, key, new_key)
233
+
234
+ for key in list(original_state_dict.keys()):
235
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
236
+ if special_key not in key:
237
+ continue
238
+ handler_fn_inplace(key, original_state_dict)
239
+
240
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
241
+ return transformer
242
+
243
+
244
+ def convert_vae(ckpt_path: str):
245
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
246
+
247
+ with init_empty_weights():
248
+ vae = AutoencoderKLHunyuanVideo()
249
+
250
+ for key in list(original_state_dict.keys()):
251
+ new_key = key[:]
252
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
253
+ new_key = new_key.replace(replace_key, rename_key)
254
+ update_state_dict_(original_state_dict, key, new_key)
255
+
256
+ for key in list(original_state_dict.keys()):
257
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
258
+ if special_key not in key:
259
+ continue
260
+ handler_fn_inplace(key, original_state_dict)
261
+
262
+ vae.load_state_dict(original_state_dict, strict=True, assign=True)
263
+ return vae
264
+
265
+
266
+ def get_args():
267
+ parser = argparse.ArgumentParser()
268
+ parser.add_argument(
269
+ "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
270
+ )
271
+ parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
272
+ parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
273
+ parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
274
+ parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
275
+ parser.add_argument("--save_pipeline", action="store_true")
276
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
277
+ parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
278
+ parser.add_argument(
279
+ "--transformer_type", type=str, default="HYVideo-T/2-cfgdistill", choices=list(TRANSFORMER_CONFIGS.keys())
280
+ )
281
+ parser.add_argument("--flow_shift", type=float, default=7.0)
282
+ return parser.parse_args()
283
+
284
+
285
+ DTYPE_MAPPING = {
286
+ "fp32": torch.float32,
287
+ "fp16": torch.float16,
288
+ "bf16": torch.bfloat16,
289
+ }
290
+
291
+
292
+ if __name__ == "__main__":
293
+ args = get_args()
294
+
295
+ transformer = None
296
+ dtype = DTYPE_MAPPING[args.dtype]
297
+
298
+ if args.save_pipeline:
299
+ assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
300
+ assert args.text_encoder_path is not None
301
+ assert args.tokenizer_path is not None
302
+ assert args.text_encoder_2_path is not None
303
+
304
+ if args.transformer_ckpt_path is not None:
305
+ transformer = convert_transformer(args.transformer_ckpt_path, args.transformer_type)
306
+ transformer = transformer.to(dtype=dtype)
307
+ if not args.save_pipeline:
308
+ transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
309
+
310
+ if args.vae_ckpt_path is not None:
311
+ vae = convert_vae(args.vae_ckpt_path)
312
+ if not args.save_pipeline:
313
+ vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
314
+
315
+ if args.save_pipeline:
316
+ if args.transformer_type == "HYVideo-T/2-cfgdistill":
317
+ text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
318
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
319
+ text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
320
+ tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
321
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
322
+
323
+ pipe = HunyuanVideoPipeline(
324
+ transformer=transformer,
325
+ vae=vae,
326
+ text_encoder=text_encoder,
327
+ tokenizer=tokenizer,
328
+ text_encoder_2=text_encoder_2,
329
+ tokenizer_2=tokenizer_2,
330
+ scheduler=scheduler,
331
+ )
332
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
333
+ else:
334
+ text_encoder = LlavaForConditionalGeneration.from_pretrained(
335
+ args.text_encoder_path, torch_dtype=torch.float16
336
+ )
337
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
338
+ text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
339
+ tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
340
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
341
+ image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_path)
342
+
343
+ pipe = HunyuanVideoImageToVideoPipeline(
344
+ transformer=transformer,
345
+ vae=vae,
346
+ text_encoder=text_encoder,
347
+ tokenizer=tokenizer,
348
+ text_encoder_2=text_encoder_2,
349
+ tokenizer_2=tokenizer_2,
350
+ scheduler=scheduler,
351
+ image_processor=image_processor,
352
+ )
353
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
diffusers/scripts/convert_hunyuandit_to_diffusers.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+
5
+ from diffusers import HunyuanDiT2DModel
6
+
7
+
8
+ def main(args):
9
+ state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
10
+
11
+ if args.load_key != "none":
12
+ try:
13
+ state_dict = state_dict[args.load_key]
14
+ except KeyError:
15
+ raise KeyError(
16
+ f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
17
+ )
18
+
19
+ device = "cuda"
20
+ model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
21
+ model_config["use_style_cond_and_image_meta_size"] = (
22
+ args.use_style_cond_and_image_meta_size
23
+ ) ### version <= v1.1: True; version >= v1.2: False
24
+
25
+ # input_size -> sample_size, text_dim -> cross_attention_dim
26
+ for key in state_dict:
27
+ print("local:", key)
28
+
29
+ model = HunyuanDiT2DModel.from_config(model_config).to(device)
30
+
31
+ for key in model.state_dict():
32
+ print("diffusers:", key)
33
+
34
+ num_layers = 40
35
+ for i in range(num_layers):
36
+ # attn1
37
+ # Wkqv -> to_q, to_k, to_v
38
+ q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
39
+ q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
40
+ state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
41
+ state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
42
+ state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
43
+ state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
44
+ state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
45
+ state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
46
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
47
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
48
+
49
+ # q_norm, k_norm -> norm_q, norm_k
50
+ state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
51
+ state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
52
+ state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
53
+ state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
54
+
55
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
56
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
57
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
58
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
59
+
60
+ # out_proj -> to_out
61
+ state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
62
+ state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
63
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
64
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
65
+
66
+ # attn2
67
+ # kq_proj -> to_k, to_v
68
+ k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
69
+ k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
70
+ state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
71
+ state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
72
+ state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
73
+ state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
74
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
75
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
76
+
77
+ # q_proj -> to_q
78
+ state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
79
+ state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
80
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
81
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
82
+
83
+ # q_norm, k_norm -> norm_q, norm_k
84
+ state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
85
+ state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
86
+ state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
87
+ state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
88
+
89
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
90
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
91
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
92
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
93
+
94
+ # out_proj -> to_out
95
+ state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
96
+ state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
97
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
98
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
99
+
100
+ # switch norm 2 and norm 3
101
+ norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
102
+ norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
103
+ state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
104
+ state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
105
+ state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
106
+ state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
107
+
108
+ # norm1 -> norm1.norm
109
+ # default_modulation.1 -> norm1.linear
110
+ state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
111
+ state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
112
+ state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
113
+ state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
114
+ state_dict.pop(f"blocks.{i}.norm1.weight")
115
+ state_dict.pop(f"blocks.{i}.norm1.bias")
116
+ state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
117
+ state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
118
+
119
+ # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
120
+ state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
121
+ state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
122
+ state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
123
+ state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
124
+ state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
125
+ state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
126
+ state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
127
+ state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
128
+
129
+ # pooler -> time_extra_emb
130
+ state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
131
+ state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
132
+ state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
133
+ state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
134
+ state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
135
+ state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
136
+ state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
137
+ state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
138
+ state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
139
+ state_dict.pop("pooler.k_proj.weight")
140
+ state_dict.pop("pooler.k_proj.bias")
141
+ state_dict.pop("pooler.q_proj.weight")
142
+ state_dict.pop("pooler.q_proj.bias")
143
+ state_dict.pop("pooler.v_proj.weight")
144
+ state_dict.pop("pooler.v_proj.bias")
145
+ state_dict.pop("pooler.c_proj.weight")
146
+ state_dict.pop("pooler.c_proj.bias")
147
+ state_dict.pop("pooler.positional_embedding")
148
+
149
+ # t_embedder -> time_embedding (`TimestepEmbedding`)
150
+ state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
151
+ state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
152
+ state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
153
+ state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
154
+
155
+ state_dict.pop("t_embedder.mlp.0.bias")
156
+ state_dict.pop("t_embedder.mlp.0.weight")
157
+ state_dict.pop("t_embedder.mlp.2.bias")
158
+ state_dict.pop("t_embedder.mlp.2.weight")
159
+
160
+ # x_embedder -> pos_embd (`PatchEmbed`)
161
+ state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
162
+ state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
163
+ state_dict.pop("x_embedder.proj.weight")
164
+ state_dict.pop("x_embedder.proj.bias")
165
+
166
+ # mlp_t5 -> text_embedder
167
+ state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
168
+ state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
169
+ state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
170
+ state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
171
+ state_dict.pop("mlp_t5.0.bias")
172
+ state_dict.pop("mlp_t5.0.weight")
173
+ state_dict.pop("mlp_t5.2.bias")
174
+ state_dict.pop("mlp_t5.2.weight")
175
+
176
+ # extra_embedder -> extra_embedder
177
+ state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
178
+ state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
179
+ state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
180
+ state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
181
+ state_dict.pop("extra_embedder.0.bias")
182
+ state_dict.pop("extra_embedder.0.weight")
183
+ state_dict.pop("extra_embedder.2.bias")
184
+ state_dict.pop("extra_embedder.2.weight")
185
+
186
+ # model.final_adaLN_modulation.1 -> norm_out.linear
187
+ def swap_scale_shift(weight):
188
+ shift, scale = weight.chunk(2, dim=0)
189
+ new_weight = torch.cat([scale, shift], dim=0)
190
+ return new_weight
191
+
192
+ state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"])
193
+ state_dict["norm_out.linear.bias"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"])
194
+ state_dict.pop("final_layer.adaLN_modulation.1.weight")
195
+ state_dict.pop("final_layer.adaLN_modulation.1.bias")
196
+
197
+ # final_linear -> proj_out
198
+ state_dict["proj_out.weight"] = state_dict["final_layer.linear.weight"]
199
+ state_dict["proj_out.bias"] = state_dict["final_layer.linear.bias"]
200
+ state_dict.pop("final_layer.linear.weight")
201
+ state_dict.pop("final_layer.linear.bias")
202
+
203
+ # style_embedder
204
+ if model_config["use_style_cond_and_image_meta_size"]:
205
+ print(state_dict["style_embedder.weight"])
206
+ print(state_dict["style_embedder.weight"].shape)
207
+ state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
208
+ state_dict.pop("style_embedder.weight")
209
+
210
+ model.load_state_dict(state_dict)
211
+
212
+ from diffusers import HunyuanDiTPipeline
213
+
214
+ if args.use_style_cond_and_image_meta_size:
215
+ pipe = HunyuanDiTPipeline.from_pretrained(
216
+ "Tencent-Hunyuan/HunyuanDiT-Diffusers", transformer=model, torch_dtype=torch.float32
217
+ )
218
+ else:
219
+ pipe = HunyuanDiTPipeline.from_pretrained(
220
+ "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", transformer=model, torch_dtype=torch.float32
221
+ )
222
+ pipe.to("cuda")
223
+ pipe.to(dtype=torch.float32)
224
+
225
+ if args.save:
226
+ pipe.save_pretrained(args.output_checkpoint_path)
227
+
228
+ # ### NOTE: HunyuanDiT supports both Chinese and English inputs
229
+ prompt = "一个宇航员在骑马"
230
+ # prompt = "An astronaut riding a horse"
231
+ generator = torch.Generator(device="cuda").manual_seed(0)
232
+ image = pipe(
233
+ height=1024, width=1024, prompt=prompt, generator=generator, num_inference_steps=25, guidance_scale=5.0
234
+ ).images[0]
235
+
236
+ image.save("img.png")
237
+
238
+
239
+ if __name__ == "__main__":
240
+ parser = argparse.ArgumentParser()
241
+
242
+ parser.add_argument(
243
+ "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
244
+ )
245
+ parser.add_argument(
246
+ "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
247
+ )
248
+ parser.add_argument(
249
+ "--output_checkpoint_path",
250
+ default=None,
251
+ type=str,
252
+ required=False,
253
+ help="Path to the output converted diffusers pipeline.",
254
+ )
255
+ parser.add_argument(
256
+ "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
257
+ )
258
+ parser.add_argument(
259
+ "--use_style_cond_and_image_meta_size",
260
+ type=bool,
261
+ default=False,
262
+ help="version <= v1.1: True; version >= v1.2: False",
263
+ )
264
+
265
+ args = parser.parse_args()
266
+ main(args)
diffusers/scripts/convert_k_upscaler_to_diffusers.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import huggingface_hub
4
+ import k_diffusion as K
5
+ import torch
6
+
7
+ from diffusers import UNet2DConditionModel
8
+
9
+
10
+ UPSCALER_REPO = "pcuenq/k-upscaler"
11
+
12
+
13
+ def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
14
+ rv = {
15
+ # norm1
16
+ f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
17
+ f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
18
+ # conv1
19
+ f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
20
+ f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
21
+ # norm2
22
+ f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
23
+ f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
24
+ # conv2
25
+ f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
26
+ f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
27
+ }
28
+
29
+ if resnet.conv_shortcut is not None:
30
+ rv.update(
31
+ {
32
+ f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
33
+ }
34
+ )
35
+
36
+ return rv
37
+
38
+
39
+ def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
40
+ weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
41
+ bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)
42
+ rv = {
43
+ # norm
44
+ f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
45
+ f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
46
+ # to_q
47
+ f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
48
+ f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
49
+ # to_k
50
+ f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
51
+ f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
52
+ # to_v
53
+ f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
54
+ f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
55
+ # to_out
56
+ f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
57
+ .squeeze(-1)
58
+ .squeeze(-1),
59
+ f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
60
+ }
61
+
62
+ return rv
63
+
64
+
65
+ def cross_attn_to_diffusers_checkpoint(
66
+ checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
67
+ ):
68
+ weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
69
+ bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)
70
+
71
+ rv = {
72
+ # norm2 (ada groupnorm)
73
+ f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
74
+ f"{attention_prefix}.norm_dec.mapper.weight"
75
+ ],
76
+ f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
77
+ f"{attention_prefix}.norm_dec.mapper.bias"
78
+ ],
79
+ # layernorm on encoder_hidden_state
80
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
81
+ f"{attention_prefix}.norm_enc.weight"
82
+ ],
83
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
84
+ f"{attention_prefix}.norm_enc.bias"
85
+ ],
86
+ # to_q
87
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
88
+ f"{attention_prefix}.q_proj.weight"
89
+ ]
90
+ .squeeze(-1)
91
+ .squeeze(-1),
92
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
93
+ f"{attention_prefix}.q_proj.bias"
94
+ ],
95
+ # to_k
96
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
97
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
98
+ # to_v
99
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
100
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
101
+ # to_out
102
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
103
+ f"{attention_prefix}.out_proj.weight"
104
+ ]
105
+ .squeeze(-1)
106
+ .squeeze(-1),
107
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
108
+ f"{attention_prefix}.out_proj.bias"
109
+ ],
110
+ }
111
+
112
+ return rv
113
+
114
+
115
+ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
116
+ block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
117
+ block_prefix = f"{block_prefix}.{block_idx}"
118
+
119
+ diffusers_checkpoint = {}
120
+
121
+ if not hasattr(block, "attentions"):
122
+ n = 1 # resnet only
123
+ elif not block.attentions[0].add_self_attention:
124
+ n = 2 # resnet -> cross-attention
125
+ else:
126
+ n = 3 # resnet -> self-attention -> cross-attention)
127
+
128
+ for resnet_idx, resnet in enumerate(block.resnets):
129
+ # diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
130
+ diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
131
+ idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
132
+ resnet_prefix = f"{block_prefix}.{idx}" if block_type == "up" else f"{block_prefix}.{idx}"
133
+
134
+ diffusers_checkpoint.update(
135
+ resnet_to_diffusers_checkpoint(
136
+ resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
137
+ )
138
+ )
139
+
140
+ if hasattr(block, "attentions"):
141
+ for attention_idx, attention in enumerate(block.attentions):
142
+ diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
143
+ idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
144
+ self_attention_prefix = f"{block_prefix}.{idx}"
145
+ cross_attention_prefix = f"{block_prefix}.{idx}"
146
+ cross_attention_index = 1 if not attention.add_self_attention else 2
147
+ idx = (
148
+ n * attention_idx + cross_attention_index
149
+ if block_type == "up"
150
+ else n * attention_idx + cross_attention_index + 1
151
+ )
152
+ cross_attention_prefix = f"{block_prefix}.{idx}"
153
+
154
+ diffusers_checkpoint.update(
155
+ cross_attn_to_diffusers_checkpoint(
156
+ checkpoint,
157
+ diffusers_attention_prefix=diffusers_attention_prefix,
158
+ diffusers_attention_index=2,
159
+ attention_prefix=cross_attention_prefix,
160
+ )
161
+ )
162
+
163
+ if attention.add_self_attention is True:
164
+ diffusers_checkpoint.update(
165
+ self_attn_to_diffusers_checkpoint(
166
+ checkpoint,
167
+ diffusers_attention_prefix=diffusers_attention_prefix,
168
+ attention_prefix=self_attention_prefix,
169
+ )
170
+ )
171
+
172
+ return diffusers_checkpoint
173
+
174
+
175
+ def unet_to_diffusers_checkpoint(model, checkpoint):
176
+ diffusers_checkpoint = {}
177
+
178
+ # pre-processing
179
+ diffusers_checkpoint.update(
180
+ {
181
+ "conv_in.weight": checkpoint["inner_model.proj_in.weight"],
182
+ "conv_in.bias": checkpoint["inner_model.proj_in.bias"],
183
+ }
184
+ )
185
+
186
+ # timestep and class embedding
187
+ diffusers_checkpoint.update(
188
+ {
189
+ "time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
190
+ "time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
191
+ "time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
192
+ "time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
193
+ "time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
194
+ "time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
195
+ }
196
+ )
197
+
198
+ # down_blocks
199
+ for down_block_idx, down_block in enumerate(model.down_blocks):
200
+ diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))
201
+
202
+ # up_blocks
203
+ for up_block_idx, up_block in enumerate(model.up_blocks):
204
+ diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))
205
+
206
+ # post-processing
207
+ diffusers_checkpoint.update(
208
+ {
209
+ "conv_out.weight": checkpoint["inner_model.proj_out.weight"],
210
+ "conv_out.bias": checkpoint["inner_model.proj_out.bias"],
211
+ }
212
+ )
213
+
214
+ return diffusers_checkpoint
215
+
216
+
217
+ def unet_model_from_original_config(original_config):
218
+ in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
219
+ out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)
220
+
221
+ block_out_channels = original_config["channels"]
222
+
223
+ assert len(set(original_config["depths"])) == 1, (
224
+ "UNet2DConditionModel currently do not support blocks with different number of layers"
225
+ )
226
+ layers_per_block = original_config["depths"][0]
227
+
228
+ class_labels_dim = original_config["mapping_cond_dim"]
229
+ cross_attention_dim = original_config["cross_cond_dim"]
230
+
231
+ attn1_types = []
232
+ attn2_types = []
233
+ for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
234
+ if s:
235
+ a1 = "self"
236
+ a2 = "cross" if c else None
237
+ elif c:
238
+ a1 = "cross"
239
+ a2 = None
240
+ else:
241
+ a1 = None
242
+ a2 = None
243
+ attn1_types.append(a1)
244
+ attn2_types.append(a2)
245
+
246
+ unet = UNet2DConditionModel(
247
+ in_channels=in_channels,
248
+ out_channels=out_channels,
249
+ down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
250
+ mid_block_type=None,
251
+ up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
252
+ block_out_channels=block_out_channels,
253
+ layers_per_block=layers_per_block,
254
+ act_fn="gelu",
255
+ norm_num_groups=None,
256
+ cross_attention_dim=cross_attention_dim,
257
+ attention_head_dim=64,
258
+ time_cond_proj_dim=class_labels_dim,
259
+ resnet_time_scale_shift="scale_shift",
260
+ time_embedding_type="fourier",
261
+ timestep_post_act="gelu",
262
+ conv_in_kernel=1,
263
+ conv_out_kernel=1,
264
+ )
265
+
266
+ return unet
267
+
268
+
269
+ def main(args):
270
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
271
+
272
+ orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
273
+ orig_weights_path = huggingface_hub.hf_hub_download(
274
+ UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
275
+ )
276
+ print(f"loading original model configuration from {orig_config_path}")
277
+ print(f"loading original model checkpoint from {orig_weights_path}")
278
+
279
+ print("converting to diffusers unet")
280
+ orig_config = K.config.load_config(open(orig_config_path))["model"]
281
+ model = unet_model_from_original_config(orig_config)
282
+
283
+ orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
284
+ converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)
285
+
286
+ model.load_state_dict(converted_checkpoint, strict=True)
287
+ model.save_pretrained(args.dump_path)
288
+ print(f"saving converted unet model in {args.dump_path}")
289
+
290
+
291
+ if __name__ == "__main__":
292
+ parser = argparse.ArgumentParser()
293
+
294
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
295
+ args = parser.parse_args()
296
+
297
+ main(args)
diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py ADDED
@@ -0,0 +1,1159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import tempfile
3
+
4
+ import torch
5
+ from accelerate import load_checkpoint_and_dispatch
6
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
7
+
8
+ from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
9
+ from diffusers.models.transformers.prior_transformer import PriorTransformer
10
+ from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
11
+ from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
12
+
13
+
14
+ r"""
15
+ Example - From the diffusers root directory:
16
+
17
+ Download weights:
18
+ ```sh
19
+ $ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt
20
+ $ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt
21
+ $ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
22
+ $ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th
23
+ ```
24
+
25
+ Convert the model:
26
+ ```sh
27
+ $ python scripts/convert_kakao_brain_unclip_to_diffusers.py \
28
+ --decoder_checkpoint_path ./decoder-ckpt-step\=01000000-of-01000000.ckpt \
29
+ --super_res_unet_checkpoint_path ./improved-sr-ckpt-step\=1.2M.ckpt \
30
+ --prior_checkpoint_path ./prior-ckpt-step\=01000000-of-01000000.ckpt \
31
+ --clip_stat_path ./ViT-L-14_stats.th \
32
+ --dump_path <path where to save model>
33
+ ```
34
+ """
35
+
36
+
37
+ # prior
38
+
39
+ PRIOR_ORIGINAL_PREFIX = "model"
40
+
41
+ # Uses default arguments
42
+ PRIOR_CONFIG = {}
43
+
44
+
45
+ def prior_model_from_original_config():
46
+ model = PriorTransformer(**PRIOR_CONFIG)
47
+
48
+ return model
49
+
50
+
51
+ def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint):
52
+ diffusers_checkpoint = {}
53
+
54
+ # <original>.time_embed.0 -> <diffusers>.time_embedding.linear_1
55
+ diffusers_checkpoint.update(
56
+ {
57
+ "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"],
58
+ "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"],
59
+ }
60
+ )
61
+
62
+ # <original>.clip_img_proj -> <diffusers>.proj_in
63
+ diffusers_checkpoint.update(
64
+ {
65
+ "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"],
66
+ "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"],
67
+ }
68
+ )
69
+
70
+ # <original>.text_emb_proj -> <diffusers>.embedding_proj
71
+ diffusers_checkpoint.update(
72
+ {
73
+ "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"],
74
+ "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"],
75
+ }
76
+ )
77
+
78
+ # <original>.text_enc_proj -> <diffusers>.encoder_hidden_states_proj
79
+ diffusers_checkpoint.update(
80
+ {
81
+ "encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"],
82
+ "encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"],
83
+ }
84
+ )
85
+
86
+ # <original>.positional_embedding -> <diffusers>.positional_embedding
87
+ diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]})
88
+
89
+ # <original>.prd_emb -> <diffusers>.prd_embedding
90
+ diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]})
91
+
92
+ # <original>.time_embed.2 -> <diffusers>.time_embedding.linear_2
93
+ diffusers_checkpoint.update(
94
+ {
95
+ "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"],
96
+ "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"],
97
+ }
98
+ )
99
+
100
+ # <original>.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
101
+ for idx in range(len(model.transformer_blocks)):
102
+ diffusers_transformer_prefix = f"transformer_blocks.{idx}"
103
+ original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}"
104
+
105
+ # <original>.attn -> <diffusers>.attn1
106
+ diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
107
+ original_attention_prefix = f"{original_transformer_prefix}.attn"
108
+ diffusers_checkpoint.update(
109
+ prior_attention_to_diffusers(
110
+ checkpoint,
111
+ diffusers_attention_prefix=diffusers_attention_prefix,
112
+ original_attention_prefix=original_attention_prefix,
113
+ attention_head_dim=model.attention_head_dim,
114
+ )
115
+ )
116
+
117
+ # <original>.mlp -> <diffusers>.ff
118
+ diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
119
+ original_ff_prefix = f"{original_transformer_prefix}.mlp"
120
+ diffusers_checkpoint.update(
121
+ prior_ff_to_diffusers(
122
+ checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
123
+ )
124
+ )
125
+
126
+ # <original>.ln_1 -> <diffusers>.norm1
127
+ diffusers_checkpoint.update(
128
+ {
129
+ f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
130
+ f"{original_transformer_prefix}.ln_1.weight"
131
+ ],
132
+ f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
133
+ }
134
+ )
135
+
136
+ # <original>.ln_2 -> <diffusers>.norm3
137
+ diffusers_checkpoint.update(
138
+ {
139
+ f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
140
+ f"{original_transformer_prefix}.ln_2.weight"
141
+ ],
142
+ f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
143
+ }
144
+ )
145
+
146
+ # <original>.final_ln -> <diffusers>.norm_out
147
+ diffusers_checkpoint.update(
148
+ {
149
+ "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"],
150
+ "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"],
151
+ }
152
+ )
153
+
154
+ # <original>.out_proj -> <diffusers>.proj_to_clip_embeddings
155
+ diffusers_checkpoint.update(
156
+ {
157
+ "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"],
158
+ "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"],
159
+ }
160
+ )
161
+
162
+ # clip stats
163
+ clip_mean, clip_std = clip_stats_checkpoint
164
+ clip_mean = clip_mean[None, :]
165
+ clip_std = clip_std[None, :]
166
+
167
+ diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std})
168
+
169
+ return diffusers_checkpoint
170
+
171
+
172
+ def prior_attention_to_diffusers(
173
+ checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
174
+ ):
175
+ diffusers_checkpoint = {}
176
+
177
+ # <original>.c_qkv -> <diffusers>.{to_q, to_k, to_v}
178
+ [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
179
+ weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
180
+ bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
181
+ split=3,
182
+ chunk_size=attention_head_dim,
183
+ )
184
+
185
+ diffusers_checkpoint.update(
186
+ {
187
+ f"{diffusers_attention_prefix}.to_q.weight": q_weight,
188
+ f"{diffusers_attention_prefix}.to_q.bias": q_bias,
189
+ f"{diffusers_attention_prefix}.to_k.weight": k_weight,
190
+ f"{diffusers_attention_prefix}.to_k.bias": k_bias,
191
+ f"{diffusers_attention_prefix}.to_v.weight": v_weight,
192
+ f"{diffusers_attention_prefix}.to_v.bias": v_bias,
193
+ }
194
+ )
195
+
196
+ # <original>.c_proj -> <diffusers>.to_out.0
197
+ diffusers_checkpoint.update(
198
+ {
199
+ f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
200
+ f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
201
+ }
202
+ )
203
+
204
+ return diffusers_checkpoint
205
+
206
+
207
+ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
208
+ diffusers_checkpoint = {
209
+ # <original>.c_fc -> <diffusers>.net.0.proj
210
+ f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
211
+ f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
212
+ # <original>.c_proj -> <diffusers>.net.2
213
+ f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
214
+ f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
215
+ }
216
+
217
+ return diffusers_checkpoint
218
+
219
+
220
+ # done prior
221
+
222
+
223
+ # decoder
224
+
225
+ DECODER_ORIGINAL_PREFIX = "model"
226
+
227
+ # We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can
228
+ # update then.
229
+ DECODER_CONFIG = {
230
+ "sample_size": 64,
231
+ "layers_per_block": 3,
232
+ "down_block_types": (
233
+ "ResnetDownsampleBlock2D",
234
+ "SimpleCrossAttnDownBlock2D",
235
+ "SimpleCrossAttnDownBlock2D",
236
+ "SimpleCrossAttnDownBlock2D",
237
+ ),
238
+ "up_block_types": (
239
+ "SimpleCrossAttnUpBlock2D",
240
+ "SimpleCrossAttnUpBlock2D",
241
+ "SimpleCrossAttnUpBlock2D",
242
+ "ResnetUpsampleBlock2D",
243
+ ),
244
+ "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
245
+ "block_out_channels": (320, 640, 960, 1280),
246
+ "in_channels": 3,
247
+ "out_channels": 6,
248
+ "cross_attention_dim": 1536,
249
+ "class_embed_type": "identity",
250
+ "attention_head_dim": 64,
251
+ "resnet_time_scale_shift": "scale_shift",
252
+ }
253
+
254
+
255
+ def decoder_model_from_original_config():
256
+ model = UNet2DConditionModel(**DECODER_CONFIG)
257
+
258
+ return model
259
+
260
+
261
+ def decoder_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
262
+ diffusers_checkpoint = {}
263
+
264
+ original_unet_prefix = DECODER_ORIGINAL_PREFIX
265
+ num_head_channels = DECODER_CONFIG["attention_head_dim"]
266
+
267
+ diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
268
+ diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))
269
+
270
+ # <original>.input_blocks -> <diffusers>.down_blocks
271
+
272
+ original_down_block_idx = 1
273
+
274
+ for diffusers_down_block_idx in range(len(model.down_blocks)):
275
+ checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
276
+ model,
277
+ checkpoint,
278
+ diffusers_down_block_idx=diffusers_down_block_idx,
279
+ original_down_block_idx=original_down_block_idx,
280
+ original_unet_prefix=original_unet_prefix,
281
+ num_head_channels=num_head_channels,
282
+ )
283
+
284
+ original_down_block_idx += num_original_down_blocks
285
+
286
+ diffusers_checkpoint.update(checkpoint_update)
287
+
288
+ # done <original>.input_blocks -> <diffusers>.down_blocks
289
+
290
+ diffusers_checkpoint.update(
291
+ unet_midblock_to_diffusers_checkpoint(
292
+ model,
293
+ checkpoint,
294
+ original_unet_prefix=original_unet_prefix,
295
+ num_head_channels=num_head_channels,
296
+ )
297
+ )
298
+
299
+ # <original>.output_blocks -> <diffusers>.up_blocks
300
+
301
+ original_up_block_idx = 0
302
+
303
+ for diffusers_up_block_idx in range(len(model.up_blocks)):
304
+ checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
305
+ model,
306
+ checkpoint,
307
+ diffusers_up_block_idx=diffusers_up_block_idx,
308
+ original_up_block_idx=original_up_block_idx,
309
+ original_unet_prefix=original_unet_prefix,
310
+ num_head_channels=num_head_channels,
311
+ )
312
+
313
+ original_up_block_idx += num_original_up_blocks
314
+
315
+ diffusers_checkpoint.update(checkpoint_update)
316
+
317
+ # done <original>.output_blocks -> <diffusers>.up_blocks
318
+
319
+ diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
320
+ diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))
321
+
322
+ return diffusers_checkpoint
323
+
324
+
325
+ # done decoder
326
+
327
+ # text proj
328
+
329
+
330
+ def text_proj_from_original_config():
331
+ # From the conditional unet constructor where the dimension of the projected time embeddings is
332
+ # constructed
333
+ time_embed_dim = DECODER_CONFIG["block_out_channels"][0] * 4
334
+
335
+ cross_attention_dim = DECODER_CONFIG["cross_attention_dim"]
336
+
337
+ model = UnCLIPTextProjModel(time_embed_dim=time_embed_dim, cross_attention_dim=cross_attention_dim)
338
+
339
+ return model
340
+
341
+
342
+ # Note that the input checkpoint is the original decoder checkpoint
343
+ def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
344
+ diffusers_checkpoint = {
345
+ # <original>.text_seq_proj.0 -> <diffusers>.encoder_hidden_states_proj
346
+ "encoder_hidden_states_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.weight"],
347
+ "encoder_hidden_states_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.bias"],
348
+ # <original>.text_seq_proj.1 -> <diffusers>.text_encoder_hidden_states_norm
349
+ "text_encoder_hidden_states_norm.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.weight"],
350
+ "text_encoder_hidden_states_norm.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.bias"],
351
+ # <original>.clip_tok_proj -> <diffusers>.clip_extra_context_tokens_proj
352
+ "clip_extra_context_tokens_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.weight"],
353
+ "clip_extra_context_tokens_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.bias"],
354
+ # <original>.text_feat_proj -> <diffusers>.embedding_proj
355
+ "embedding_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.weight"],
356
+ "embedding_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.bias"],
357
+ # <original>.cf_param -> <diffusers>.learned_classifier_free_guidance_embeddings
358
+ "learned_classifier_free_guidance_embeddings": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.cf_param"],
359
+ # <original>.clip_emb -> <diffusers>.clip_image_embeddings_project_to_time_embeddings
360
+ "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint[
361
+ f"{DECODER_ORIGINAL_PREFIX}.clip_emb.weight"
362
+ ],
363
+ "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint[
364
+ f"{DECODER_ORIGINAL_PREFIX}.clip_emb.bias"
365
+ ],
366
+ }
367
+
368
+ return diffusers_checkpoint
369
+
370
+
371
+ # done text proj
372
+
373
+ # super res unet first steps
374
+
375
+ SUPER_RES_UNET_FIRST_STEPS_PREFIX = "model_first_steps"
376
+
377
+ SUPER_RES_UNET_FIRST_STEPS_CONFIG = {
378
+ "sample_size": 256,
379
+ "layers_per_block": 3,
380
+ "down_block_types": (
381
+ "ResnetDownsampleBlock2D",
382
+ "ResnetDownsampleBlock2D",
383
+ "ResnetDownsampleBlock2D",
384
+ "ResnetDownsampleBlock2D",
385
+ ),
386
+ "up_block_types": (
387
+ "ResnetUpsampleBlock2D",
388
+ "ResnetUpsampleBlock2D",
389
+ "ResnetUpsampleBlock2D",
390
+ "ResnetUpsampleBlock2D",
391
+ ),
392
+ "block_out_channels": (320, 640, 960, 1280),
393
+ "in_channels": 6,
394
+ "out_channels": 3,
395
+ "add_attention": False,
396
+ }
397
+
398
+
399
+ def super_res_unet_first_steps_model_from_original_config():
400
+ model = UNet2DModel(**SUPER_RES_UNET_FIRST_STEPS_CONFIG)
401
+
402
+ return model
403
+
404
+
405
+ def super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
406
+ diffusers_checkpoint = {}
407
+
408
+ original_unet_prefix = SUPER_RES_UNET_FIRST_STEPS_PREFIX
409
+
410
+ diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
411
+ diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))
412
+
413
+ # <original>.input_blocks -> <diffusers>.down_blocks
414
+
415
+ original_down_block_idx = 1
416
+
417
+ for diffusers_down_block_idx in range(len(model.down_blocks)):
418
+ checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
419
+ model,
420
+ checkpoint,
421
+ diffusers_down_block_idx=diffusers_down_block_idx,
422
+ original_down_block_idx=original_down_block_idx,
423
+ original_unet_prefix=original_unet_prefix,
424
+ num_head_channels=None,
425
+ )
426
+
427
+ original_down_block_idx += num_original_down_blocks
428
+
429
+ diffusers_checkpoint.update(checkpoint_update)
430
+
431
+ diffusers_checkpoint.update(
432
+ unet_midblock_to_diffusers_checkpoint(
433
+ model,
434
+ checkpoint,
435
+ original_unet_prefix=original_unet_prefix,
436
+ num_head_channels=None,
437
+ )
438
+ )
439
+
440
+ # <original>.output_blocks -> <diffusers>.up_blocks
441
+
442
+ original_up_block_idx = 0
443
+
444
+ for diffusers_up_block_idx in range(len(model.up_blocks)):
445
+ checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
446
+ model,
447
+ checkpoint,
448
+ diffusers_up_block_idx=diffusers_up_block_idx,
449
+ original_up_block_idx=original_up_block_idx,
450
+ original_unet_prefix=original_unet_prefix,
451
+ num_head_channels=None,
452
+ )
453
+
454
+ original_up_block_idx += num_original_up_blocks
455
+
456
+ diffusers_checkpoint.update(checkpoint_update)
457
+
458
+ # done <original>.output_blocks -> <diffusers>.up_blocks
459
+
460
+ diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
461
+ diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))
462
+
463
+ return diffusers_checkpoint
464
+
465
+
466
+ # done super res unet first steps
467
+
468
+ # super res unet last step
469
+
470
+ SUPER_RES_UNET_LAST_STEP_PREFIX = "model_last_step"
471
+
472
+ SUPER_RES_UNET_LAST_STEP_CONFIG = {
473
+ "sample_size": 256,
474
+ "layers_per_block": 3,
475
+ "down_block_types": (
476
+ "ResnetDownsampleBlock2D",
477
+ "ResnetDownsampleBlock2D",
478
+ "ResnetDownsampleBlock2D",
479
+ "ResnetDownsampleBlock2D",
480
+ ),
481
+ "up_block_types": (
482
+ "ResnetUpsampleBlock2D",
483
+ "ResnetUpsampleBlock2D",
484
+ "ResnetUpsampleBlock2D",
485
+ "ResnetUpsampleBlock2D",
486
+ ),
487
+ "block_out_channels": (320, 640, 960, 1280),
488
+ "in_channels": 6,
489
+ "out_channels": 3,
490
+ "add_attention": False,
491
+ }
492
+
493
+
494
+ def super_res_unet_last_step_model_from_original_config():
495
+ model = UNet2DModel(**SUPER_RES_UNET_LAST_STEP_CONFIG)
496
+
497
+ return model
498
+
499
+
500
+ def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
501
+ diffusers_checkpoint = {}
502
+
503
+ original_unet_prefix = SUPER_RES_UNET_LAST_STEP_PREFIX
504
+
505
+ diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
506
+ diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))
507
+
508
+ # <original>.input_blocks -> <diffusers>.down_blocks
509
+
510
+ original_down_block_idx = 1
511
+
512
+ for diffusers_down_block_idx in range(len(model.down_blocks)):
513
+ checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
514
+ model,
515
+ checkpoint,
516
+ diffusers_down_block_idx=diffusers_down_block_idx,
517
+ original_down_block_idx=original_down_block_idx,
518
+ original_unet_prefix=original_unet_prefix,
519
+ num_head_channels=None,
520
+ )
521
+
522
+ original_down_block_idx += num_original_down_blocks
523
+
524
+ diffusers_checkpoint.update(checkpoint_update)
525
+
526
+ diffusers_checkpoint.update(
527
+ unet_midblock_to_diffusers_checkpoint(
528
+ model,
529
+ checkpoint,
530
+ original_unet_prefix=original_unet_prefix,
531
+ num_head_channels=None,
532
+ )
533
+ )
534
+
535
+ # <original>.output_blocks -> <diffusers>.up_blocks
536
+
537
+ original_up_block_idx = 0
538
+
539
+ for diffusers_up_block_idx in range(len(model.up_blocks)):
540
+ checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
541
+ model,
542
+ checkpoint,
543
+ diffusers_up_block_idx=diffusers_up_block_idx,
544
+ original_up_block_idx=original_up_block_idx,
545
+ original_unet_prefix=original_unet_prefix,
546
+ num_head_channels=None,
547
+ )
548
+
549
+ original_up_block_idx += num_original_up_blocks
550
+
551
+ diffusers_checkpoint.update(checkpoint_update)
552
+
553
+ # done <original>.output_blocks -> <diffusers>.up_blocks
554
+
555
+ diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
556
+ diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))
557
+
558
+ return diffusers_checkpoint
559
+
560
+
561
+ # done super res unet last step
562
+
563
+
564
+ # unet utils
565
+
566
+
567
+ # <original>.time_embed -> <diffusers>.time_embedding
568
+ def unet_time_embeddings(checkpoint, original_unet_prefix):
569
+ diffusers_checkpoint = {}
570
+
571
+ diffusers_checkpoint.update(
572
+ {
573
+ "time_embedding.linear_1.weight": checkpoint[f"{original_unet_prefix}.time_embed.0.weight"],
574
+ "time_embedding.linear_1.bias": checkpoint[f"{original_unet_prefix}.time_embed.0.bias"],
575
+ "time_embedding.linear_2.weight": checkpoint[f"{original_unet_prefix}.time_embed.2.weight"],
576
+ "time_embedding.linear_2.bias": checkpoint[f"{original_unet_prefix}.time_embed.2.bias"],
577
+ }
578
+ )
579
+
580
+ return diffusers_checkpoint
581
+
582
+
583
+ # <original>.input_blocks.0 -> <diffusers>.conv_in
584
+ def unet_conv_in(checkpoint, original_unet_prefix):
585
+ diffusers_checkpoint = {}
586
+
587
+ diffusers_checkpoint.update(
588
+ {
589
+ "conv_in.weight": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.weight"],
590
+ "conv_in.bias": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.bias"],
591
+ }
592
+ )
593
+
594
+ return diffusers_checkpoint
595
+
596
+
597
+ # <original>.out.0 -> <diffusers>.conv_norm_out
598
+ def unet_conv_norm_out(checkpoint, original_unet_prefix):
599
+ diffusers_checkpoint = {}
600
+
601
+ diffusers_checkpoint.update(
602
+ {
603
+ "conv_norm_out.weight": checkpoint[f"{original_unet_prefix}.out.0.weight"],
604
+ "conv_norm_out.bias": checkpoint[f"{original_unet_prefix}.out.0.bias"],
605
+ }
606
+ )
607
+
608
+ return diffusers_checkpoint
609
+
610
+
611
+ # <original>.out.2 -> <diffusers>.conv_out
612
+ def unet_conv_out(checkpoint, original_unet_prefix):
613
+ diffusers_checkpoint = {}
614
+
615
+ diffusers_checkpoint.update(
616
+ {
617
+ "conv_out.weight": checkpoint[f"{original_unet_prefix}.out.2.weight"],
618
+ "conv_out.bias": checkpoint[f"{original_unet_prefix}.out.2.bias"],
619
+ }
620
+ )
621
+
622
+ return diffusers_checkpoint
623
+
624
+
625
+ # <original>.input_blocks -> <diffusers>.down_blocks
626
+ def unet_downblock_to_diffusers_checkpoint(
627
+ model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, original_unet_prefix, num_head_channels
628
+ ):
629
+ diffusers_checkpoint = {}
630
+
631
+ diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets"
632
+ original_down_block_prefix = f"{original_unet_prefix}.input_blocks"
633
+
634
+ down_block = model.down_blocks[diffusers_down_block_idx]
635
+
636
+ num_resnets = len(down_block.resnets)
637
+
638
+ if down_block.downsamplers is None:
639
+ downsampler = False
640
+ else:
641
+ assert len(down_block.downsamplers) == 1
642
+ downsampler = True
643
+ # The downsample block is also a resnet
644
+ num_resnets += 1
645
+
646
+ for resnet_idx_inc in range(num_resnets):
647
+ full_resnet_prefix = f"{original_down_block_prefix}.{original_down_block_idx + resnet_idx_inc}.0"
648
+
649
+ if downsampler and resnet_idx_inc == num_resnets - 1:
650
+ # this is a downsample block
651
+ full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0"
652
+ else:
653
+ # this is a regular resnet block
654
+ full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"
655
+
656
+ diffusers_checkpoint.update(
657
+ resnet_to_diffusers_checkpoint(
658
+ checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
659
+ )
660
+ )
661
+
662
+ if hasattr(down_block, "attentions"):
663
+ num_attentions = len(down_block.attentions)
664
+ diffusers_attention_prefix = f"down_blocks.{diffusers_down_block_idx}.attentions"
665
+
666
+ for attention_idx_inc in range(num_attentions):
667
+ full_attention_prefix = f"{original_down_block_prefix}.{original_down_block_idx + attention_idx_inc}.1"
668
+ full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"
669
+
670
+ diffusers_checkpoint.update(
671
+ attention_to_diffusers_checkpoint(
672
+ checkpoint,
673
+ attention_prefix=full_attention_prefix,
674
+ diffusers_attention_prefix=full_diffusers_attention_prefix,
675
+ num_head_channels=num_head_channels,
676
+ )
677
+ )
678
+
679
+ num_original_down_blocks = num_resnets
680
+
681
+ return diffusers_checkpoint, num_original_down_blocks
682
+
683
+
684
+ # <original>.middle_block -> <diffusers>.mid_block
685
+ def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, original_unet_prefix, num_head_channels):
686
+ diffusers_checkpoint = {}
687
+
688
+ # block 0
689
+
690
+ original_block_idx = 0
691
+
692
+ diffusers_checkpoint.update(
693
+ resnet_to_diffusers_checkpoint(
694
+ checkpoint,
695
+ diffusers_resnet_prefix="mid_block.resnets.0",
696
+ resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
697
+ )
698
+ )
699
+
700
+ original_block_idx += 1
701
+
702
+ # optional block 1
703
+
704
+ if hasattr(model.mid_block, "attentions") and model.mid_block.attentions[0] is not None:
705
+ diffusers_checkpoint.update(
706
+ attention_to_diffusers_checkpoint(
707
+ checkpoint,
708
+ diffusers_attention_prefix="mid_block.attentions.0",
709
+ attention_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
710
+ num_head_channels=num_head_channels,
711
+ )
712
+ )
713
+ original_block_idx += 1
714
+
715
+ # block 1 or block 2
716
+
717
+ diffusers_checkpoint.update(
718
+ resnet_to_diffusers_checkpoint(
719
+ checkpoint,
720
+ diffusers_resnet_prefix="mid_block.resnets.1",
721
+ resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
722
+ )
723
+ )
724
+
725
+ return diffusers_checkpoint
726
+
727
+
728
+ # <original>.output_blocks -> <diffusers>.up_blocks
729
+ def unet_upblock_to_diffusers_checkpoint(
730
+ model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, original_unet_prefix, num_head_channels
731
+ ):
732
+ diffusers_checkpoint = {}
733
+
734
+ diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.resnets"
735
+ original_up_block_prefix = f"{original_unet_prefix}.output_blocks"
736
+
737
+ up_block = model.up_blocks[diffusers_up_block_idx]
738
+
739
+ num_resnets = len(up_block.resnets)
740
+
741
+ if up_block.upsamplers is None:
742
+ upsampler = False
743
+ else:
744
+ assert len(up_block.upsamplers) == 1
745
+ upsampler = True
746
+ # The upsample block is also a resnet
747
+ num_resnets += 1
748
+
749
+ has_attentions = hasattr(up_block, "attentions")
750
+
751
+ for resnet_idx_inc in range(num_resnets):
752
+ if upsampler and resnet_idx_inc == num_resnets - 1:
753
+ # this is an upsample block
754
+ if has_attentions:
755
+ # There is a middle attention block that we skip
756
+ original_resnet_block_idx = 2
757
+ else:
758
+ original_resnet_block_idx = 1
759
+
760
+ # we add the `minus 1` because the last two resnets are stuck together in the same output block
761
+ full_resnet_prefix = (
762
+ f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc - 1}.{original_resnet_block_idx}"
763
+ )
764
+
765
+ full_diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.upsamplers.0"
766
+ else:
767
+ # this is a regular resnet block
768
+ full_resnet_prefix = f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc}.0"
769
+ full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"
770
+
771
+ diffusers_checkpoint.update(
772
+ resnet_to_diffusers_checkpoint(
773
+ checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
774
+ )
775
+ )
776
+
777
+ if has_attentions:
778
+ num_attentions = len(up_block.attentions)
779
+ diffusers_attention_prefix = f"up_blocks.{diffusers_up_block_idx}.attentions"
780
+
781
+ for attention_idx_inc in range(num_attentions):
782
+ full_attention_prefix = f"{original_up_block_prefix}.{original_up_block_idx + attention_idx_inc}.1"
783
+ full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"
784
+
785
+ diffusers_checkpoint.update(
786
+ attention_to_diffusers_checkpoint(
787
+ checkpoint,
788
+ attention_prefix=full_attention_prefix,
789
+ diffusers_attention_prefix=full_diffusers_attention_prefix,
790
+ num_head_channels=num_head_channels,
791
+ )
792
+ )
793
+
794
+ num_original_down_blocks = num_resnets - 1 if upsampler else num_resnets
795
+
796
+ return diffusers_checkpoint, num_original_down_blocks
797
+
798
+
799
+ def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
800
+ diffusers_checkpoint = {
801
+ f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"],
802
+ f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"],
803
+ f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"],
804
+ f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"],
805
+ f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"],
806
+ f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"],
807
+ f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"],
808
+ f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"],
809
+ f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"],
810
+ f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"],
811
+ }
812
+
813
+ skip_connection_prefix = f"{resnet_prefix}.skip_connection"
814
+
815
+ if f"{skip_connection_prefix}.weight" in checkpoint:
816
+ diffusers_checkpoint.update(
817
+ {
818
+ f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{skip_connection_prefix}.weight"],
819
+ f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{skip_connection_prefix}.bias"],
820
+ }
821
+ )
822
+
823
+ return diffusers_checkpoint
824
+
825
+
826
+ def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels):
827
+ diffusers_checkpoint = {}
828
+
829
+ # <original>.norm -> <diffusers>.group_norm
830
+ diffusers_checkpoint.update(
831
+ {
832
+ f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
833
+ f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
834
+ }
835
+ )
836
+
837
+ # <original>.qkv -> <diffusers>.{query, key, value}
838
+ [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
839
+ weight=checkpoint[f"{attention_prefix}.qkv.weight"][:, :, 0],
840
+ bias=checkpoint[f"{attention_prefix}.qkv.bias"],
841
+ split=3,
842
+ chunk_size=num_head_channels,
843
+ )
844
+
845
+ diffusers_checkpoint.update(
846
+ {
847
+ f"{diffusers_attention_prefix}.to_q.weight": q_weight,
848
+ f"{diffusers_attention_prefix}.to_q.bias": q_bias,
849
+ f"{diffusers_attention_prefix}.to_k.weight": k_weight,
850
+ f"{diffusers_attention_prefix}.to_k.bias": k_bias,
851
+ f"{diffusers_attention_prefix}.to_v.weight": v_weight,
852
+ f"{diffusers_attention_prefix}.to_v.bias": v_bias,
853
+ }
854
+ )
855
+
856
+ # <original>.encoder_kv -> <diffusers>.{context_key, context_value}
857
+ [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
858
+ weight=checkpoint[f"{attention_prefix}.encoder_kv.weight"][:, :, 0],
859
+ bias=checkpoint[f"{attention_prefix}.encoder_kv.bias"],
860
+ split=2,
861
+ chunk_size=num_head_channels,
862
+ )
863
+
864
+ diffusers_checkpoint.update(
865
+ {
866
+ f"{diffusers_attention_prefix}.add_k_proj.weight": encoder_k_weight,
867
+ f"{diffusers_attention_prefix}.add_k_proj.bias": encoder_k_bias,
868
+ f"{diffusers_attention_prefix}.add_v_proj.weight": encoder_v_weight,
869
+ f"{diffusers_attention_prefix}.add_v_proj.bias": encoder_v_bias,
870
+ }
871
+ )
872
+
873
+ # <original>.proj_out (1d conv) -> <diffusers>.proj_attn (linear)
874
+ diffusers_checkpoint.update(
875
+ {
876
+ f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
877
+ :, :, 0
878
+ ],
879
+ f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
880
+ }
881
+ )
882
+
883
+ return diffusers_checkpoint
884
+
885
+
886
+ # TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
887
+ def split_attentions(*, weight, bias, split, chunk_size):
888
+ weights = [None] * split
889
+ biases = [None] * split
890
+
891
+ weights_biases_idx = 0
892
+
893
+ for starting_row_index in range(0, weight.shape[0], chunk_size):
894
+ row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)
895
+
896
+ weight_rows = weight[row_indices, :]
897
+ bias_rows = bias[row_indices]
898
+
899
+ if weights[weights_biases_idx] is None:
900
+ assert weights[weights_biases_idx] is None
901
+ weights[weights_biases_idx] = weight_rows
902
+ biases[weights_biases_idx] = bias_rows
903
+ else:
904
+ assert weights[weights_biases_idx] is not None
905
+ weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
906
+ biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])
907
+
908
+ weights_biases_idx = (weights_biases_idx + 1) % split
909
+
910
+ return weights, biases
911
+
912
+
913
+ # done unet utils
914
+
915
+
916
+ # Driver functions
917
+
918
+
919
+ def text_encoder():
920
+ print("loading CLIP text encoder")
921
+
922
+ clip_name = "openai/clip-vit-large-patch14"
923
+
924
+ # sets pad_value to 0
925
+ pad_token = "!"
926
+
927
+ tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")
928
+
929
+ assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0
930
+
931
+ text_encoder_model = CLIPTextModelWithProjection.from_pretrained(
932
+ clip_name,
933
+ # `CLIPTextModel` does not support device_map="auto"
934
+ # device_map="auto"
935
+ )
936
+
937
+ print("done loading CLIP text encoder")
938
+
939
+ return text_encoder_model, tokenizer_model
940
+
941
+
942
+ def prior(*, args, checkpoint_map_location):
943
+ print("loading prior")
944
+
945
+ prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)
946
+ prior_checkpoint = prior_checkpoint["state_dict"]
947
+
948
+ clip_stats_checkpoint = torch.load(args.clip_stat_path, map_location=checkpoint_map_location)
949
+
950
+ prior_model = prior_model_from_original_config()
951
+
952
+ prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(
953
+ prior_model, prior_checkpoint, clip_stats_checkpoint
954
+ )
955
+
956
+ del prior_checkpoint
957
+ del clip_stats_checkpoint
958
+
959
+ load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True)
960
+
961
+ print("done loading prior")
962
+
963
+ return prior_model
964
+
965
+
966
+ def decoder(*, args, checkpoint_map_location):
967
+ print("loading decoder")
968
+
969
+ decoder_checkpoint = torch.load(args.decoder_checkpoint_path, map_location=checkpoint_map_location)
970
+ decoder_checkpoint = decoder_checkpoint["state_dict"]
971
+
972
+ decoder_model = decoder_model_from_original_config()
973
+
974
+ decoder_diffusers_checkpoint = decoder_original_checkpoint_to_diffusers_checkpoint(
975
+ decoder_model, decoder_checkpoint
976
+ )
977
+
978
+ # text proj interlude
979
+
980
+ # The original decoder implementation includes a set of parameters that are used
981
+ # for creating the `encoder_hidden_states` which are what the U-net is conditioned
982
+ # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
983
+ # the parameters into the UnCLIPTextProjModel class
984
+ text_proj_model = text_proj_from_original_config()
985
+
986
+ text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(decoder_checkpoint)
987
+
988
+ load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
989
+
990
+ # done text proj interlude
991
+
992
+ del decoder_checkpoint
993
+
994
+ load_checkpoint_to_model(decoder_diffusers_checkpoint, decoder_model, strict=True)
995
+
996
+ print("done loading decoder")
997
+
998
+ return decoder_model, text_proj_model
999
+
1000
+
1001
+ def super_res_unet(*, args, checkpoint_map_location):
1002
+ print("loading super resolution unet")
1003
+
1004
+ super_res_checkpoint = torch.load(args.super_res_unet_checkpoint_path, map_location=checkpoint_map_location)
1005
+ super_res_checkpoint = super_res_checkpoint["state_dict"]
1006
+
1007
+ # model_first_steps
1008
+
1009
+ super_res_first_model = super_res_unet_first_steps_model_from_original_config()
1010
+
1011
+ super_res_first_steps_checkpoint = super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(
1012
+ super_res_first_model, super_res_checkpoint
1013
+ )
1014
+
1015
+ # model_last_step
1016
+ super_res_last_model = super_res_unet_last_step_model_from_original_config()
1017
+
1018
+ super_res_last_step_checkpoint = super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(
1019
+ super_res_last_model, super_res_checkpoint
1020
+ )
1021
+
1022
+ del super_res_checkpoint
1023
+
1024
+ load_checkpoint_to_model(super_res_first_steps_checkpoint, super_res_first_model, strict=True)
1025
+
1026
+ load_checkpoint_to_model(super_res_last_step_checkpoint, super_res_last_model, strict=True)
1027
+
1028
+ print("done loading super resolution unet")
1029
+
1030
+ return super_res_first_model, super_res_last_model
1031
+
1032
+
1033
+ def load_checkpoint_to_model(checkpoint, model, strict=False):
1034
+ with tempfile.NamedTemporaryFile() as file:
1035
+ torch.save(checkpoint, file.name)
1036
+ del checkpoint
1037
+ if strict:
1038
+ model.load_state_dict(torch.load(file.name), strict=True)
1039
+ else:
1040
+ load_checkpoint_and_dispatch(model, file.name, device_map="auto")
1041
+
1042
+
1043
+ if __name__ == "__main__":
1044
+ parser = argparse.ArgumentParser()
1045
+
1046
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
1047
+
1048
+ parser.add_argument(
1049
+ "--prior_checkpoint_path",
1050
+ default=None,
1051
+ type=str,
1052
+ required=True,
1053
+ help="Path to the prior checkpoint to convert.",
1054
+ )
1055
+
1056
+ parser.add_argument(
1057
+ "--decoder_checkpoint_path",
1058
+ default=None,
1059
+ type=str,
1060
+ required=True,
1061
+ help="Path to the decoder checkpoint to convert.",
1062
+ )
1063
+
1064
+ parser.add_argument(
1065
+ "--super_res_unet_checkpoint_path",
1066
+ default=None,
1067
+ type=str,
1068
+ required=True,
1069
+ help="Path to the super resolution checkpoint to convert.",
1070
+ )
1071
+
1072
+ parser.add_argument(
1073
+ "--clip_stat_path", default=None, type=str, required=True, help="Path to the clip stats checkpoint to convert."
1074
+ )
1075
+
1076
+ parser.add_argument(
1077
+ "--checkpoint_load_device",
1078
+ default="cpu",
1079
+ type=str,
1080
+ required=False,
1081
+ help="The device passed to `map_location` when loading checkpoints.",
1082
+ )
1083
+
1084
+ parser.add_argument(
1085
+ "--debug",
1086
+ default=None,
1087
+ type=str,
1088
+ required=False,
1089
+ help="Only run a specific stage of the convert script. Used for debugging",
1090
+ )
1091
+
1092
+ args = parser.parse_args()
1093
+
1094
+ print(f"loading checkpoints to {args.checkpoint_load_device}")
1095
+
1096
+ checkpoint_map_location = torch.device(args.checkpoint_load_device)
1097
+
1098
+ if args.debug is not None:
1099
+ print(f"debug: only executing {args.debug}")
1100
+
1101
+ if args.debug is None:
1102
+ text_encoder_model, tokenizer_model = text_encoder()
1103
+
1104
+ prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
1105
+
1106
+ decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location)
1107
+
1108
+ super_res_first_model, super_res_last_model = super_res_unet(
1109
+ args=args, checkpoint_map_location=checkpoint_map_location
1110
+ )
1111
+
1112
+ prior_scheduler = UnCLIPScheduler(
1113
+ variance_type="fixed_small_log",
1114
+ prediction_type="sample",
1115
+ num_train_timesteps=1000,
1116
+ clip_sample_range=5.0,
1117
+ )
1118
+
1119
+ decoder_scheduler = UnCLIPScheduler(
1120
+ variance_type="learned_range",
1121
+ prediction_type="epsilon",
1122
+ num_train_timesteps=1000,
1123
+ )
1124
+
1125
+ super_res_scheduler = UnCLIPScheduler(
1126
+ variance_type="fixed_small_log",
1127
+ prediction_type="epsilon",
1128
+ num_train_timesteps=1000,
1129
+ )
1130
+
1131
+ print(f"saving Kakao Brain unCLIP to {args.dump_path}")
1132
+
1133
+ pipe = UnCLIPPipeline(
1134
+ prior=prior_model,
1135
+ decoder=decoder_model,
1136
+ text_proj=text_proj_model,
1137
+ tokenizer=tokenizer_model,
1138
+ text_encoder=text_encoder_model,
1139
+ super_res_first=super_res_first_model,
1140
+ super_res_last=super_res_last_model,
1141
+ prior_scheduler=prior_scheduler,
1142
+ decoder_scheduler=decoder_scheduler,
1143
+ super_res_scheduler=super_res_scheduler,
1144
+ )
1145
+ pipe.save_pretrained(args.dump_path)
1146
+
1147
+ print("done writing Kakao Brain unCLIP")
1148
+ elif args.debug == "text_encoder":
1149
+ text_encoder_model, tokenizer_model = text_encoder()
1150
+ elif args.debug == "prior":
1151
+ prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
1152
+ elif args.debug == "decoder":
1153
+ decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location)
1154
+ elif args.debug == "super_res_unet":
1155
+ super_res_first_model, super_res_last_model = super_res_unet(
1156
+ args=args, checkpoint_map_location=checkpoint_map_location
1157
+ )
1158
+ else:
1159
+ raise ValueError(f"unknown debug value : {args.debug}")
diffusers/scripts/convert_kandinsky3_unet.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import fnmatch
4
+
5
+ from safetensors.torch import load_file
6
+
7
+ from diffusers import Kandinsky3UNet
8
+
9
+
10
+ MAPPING = {
11
+ "to_time_embed.1": "time_embedding.linear_1",
12
+ "to_time_embed.3": "time_embedding.linear_2",
13
+ "in_layer": "conv_in",
14
+ "out_layer.0": "conv_norm_out",
15
+ "out_layer.2": "conv_out",
16
+ "down_samples": "down_blocks",
17
+ "up_samples": "up_blocks",
18
+ "projection_lin": "encoder_hid_proj.projection_linear",
19
+ "projection_ln": "encoder_hid_proj.projection_norm",
20
+ "feature_pooling": "add_time_condition",
21
+ "to_query": "to_q",
22
+ "to_key": "to_k",
23
+ "to_value": "to_v",
24
+ "output_layer": "to_out.0",
25
+ "self_attention_block": "attentions.0",
26
+ }
27
+
28
+ DYNAMIC_MAP = {
29
+ "resnet_attn_blocks.*.0": "resnets_in.*",
30
+ "resnet_attn_blocks.*.1": ("attentions.*", 1),
31
+ "resnet_attn_blocks.*.2": "resnets_out.*",
32
+ }
33
+ # MAPPING = {}
34
+
35
+
36
+ def convert_state_dict(unet_state_dict):
37
+ """
38
+ Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model.
39
+ Args:
40
+ unet_model (torch.nn.Module): The original U-Net model.
41
+ unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with.
42
+
43
+ Returns:
44
+ OrderedDict: The converted state dictionary.
45
+ """
46
+ # Example of renaming logic (this will vary based on your model's architecture)
47
+ converted_state_dict = {}
48
+ for key in unet_state_dict:
49
+ new_key = key
50
+ for pattern, new_pattern in MAPPING.items():
51
+ new_key = new_key.replace(pattern, new_pattern)
52
+
53
+ for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items():
54
+ has_matched = False
55
+ if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched:
56
+ star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])
57
+
58
+ if isinstance(dyn_new_pattern, tuple):
59
+ new_star = star + dyn_new_pattern[-1]
60
+ dyn_new_pattern = dyn_new_pattern[0]
61
+ else:
62
+ new_star = star
63
+
64
+ pattern = dyn_pattern.replace("*", str(star))
65
+ new_pattern = dyn_new_pattern.replace("*", str(new_star))
66
+
67
+ new_key = new_key.replace(pattern, new_pattern)
68
+ has_matched = True
69
+
70
+ converted_state_dict[new_key] = unet_state_dict[key]
71
+
72
+ return converted_state_dict
73
+
74
+
75
+ def main(model_path, output_path):
76
+ # Load your original U-Net model
77
+ unet_state_dict = load_file(model_path)
78
+
79
+ # Initialize your Kandinsky3UNet model
80
+ config = {}
81
+
82
+ # Convert the state dict
83
+ converted_state_dict = convert_state_dict(unet_state_dict)
84
+
85
+ unet = Kandinsky3UNet(config)
86
+ unet.load_state_dict(converted_state_dict)
87
+
88
+ unet.save_pretrained(output_path)
89
+ print(f"Converted model saved to {output_path}")
90
+
91
+
92
+ if __name__ == "__main__":
93
+ parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format")
94
+ parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model")
95
+ parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
96
+
97
+ args = parser.parse_args()
98
+ main(args.model_path, args.output_path)
diffusers/scripts/convert_kandinsky_to_diffusers.py ADDED
@@ -0,0 +1,1411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import tempfile
4
+
5
+ import torch
6
+ from accelerate import load_checkpoint_and_dispatch
7
+
8
+ from diffusers import UNet2DConditionModel
9
+ from diffusers.models.transformers.prior_transformer import PriorTransformer
10
+ from diffusers.models.vq_model import VQModel
11
+
12
+
13
+ """
14
+ Example - From the diffusers root directory:
15
+
16
+ Download weights:
17
+ ```sh
18
+ $ wget https://huggingface.co/ai-forever/Kandinsky_2.1/blob/main/prior_fp16.ckpt
19
+ ```
20
+
21
+ Convert the model:
22
+ ```sh
23
+ python scripts/convert_kandinsky_to_diffusers.py \
24
+ --prior_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/prior_fp16.ckpt \
25
+ --clip_stat_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/ViT-L-14_stats.th \
26
+ --text2img_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/decoder_fp16.ckpt \
27
+ --inpaint_text2img_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/inpainting_fp16.ckpt \
28
+ --movq_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/movq_final.ckpt \
29
+ --dump_path /home/yiyi_huggingface_co/dump \
30
+ --debug decoder
31
+ ```
32
+ """
33
+
34
+
35
+ # prior
36
+
37
+ PRIOR_ORIGINAL_PREFIX = "model"
38
+
39
+ # Uses default arguments
40
+ PRIOR_CONFIG = {}
41
+
42
+
43
+ def prior_model_from_original_config():
44
+ model = PriorTransformer(**PRIOR_CONFIG)
45
+
46
+ return model
47
+
48
+
49
+ def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint):
50
+ diffusers_checkpoint = {}
51
+
52
+ # <original>.time_embed.0 -> <diffusers>.time_embedding.linear_1
53
+ diffusers_checkpoint.update(
54
+ {
55
+ "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"],
56
+ "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"],
57
+ }
58
+ )
59
+
60
+ # <original>.clip_img_proj -> <diffusers>.proj_in
61
+ diffusers_checkpoint.update(
62
+ {
63
+ "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"],
64
+ "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"],
65
+ }
66
+ )
67
+
68
+ # <original>.text_emb_proj -> <diffusers>.embedding_proj
69
+ diffusers_checkpoint.update(
70
+ {
71
+ "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"],
72
+ "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"],
73
+ }
74
+ )
75
+
76
+ # <original>.text_enc_proj -> <diffusers>.encoder_hidden_states_proj
77
+ diffusers_checkpoint.update(
78
+ {
79
+ "encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"],
80
+ "encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"],
81
+ }
82
+ )
83
+
84
+ # <original>.positional_embedding -> <diffusers>.positional_embedding
85
+ diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]})
86
+
87
+ # <original>.prd_emb -> <diffusers>.prd_embedding
88
+ diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]})
89
+
90
+ # <original>.time_embed.2 -> <diffusers>.time_embedding.linear_2
91
+ diffusers_checkpoint.update(
92
+ {
93
+ "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"],
94
+ "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"],
95
+ }
96
+ )
97
+
98
+ # <original>.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
99
+ for idx in range(len(model.transformer_blocks)):
100
+ diffusers_transformer_prefix = f"transformer_blocks.{idx}"
101
+ original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}"
102
+
103
+ # <original>.attn -> <diffusers>.attn1
104
+ diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
105
+ original_attention_prefix = f"{original_transformer_prefix}.attn"
106
+ diffusers_checkpoint.update(
107
+ prior_attention_to_diffusers(
108
+ checkpoint,
109
+ diffusers_attention_prefix=diffusers_attention_prefix,
110
+ original_attention_prefix=original_attention_prefix,
111
+ attention_head_dim=model.attention_head_dim,
112
+ )
113
+ )
114
+
115
+ # <original>.mlp -> <diffusers>.ff
116
+ diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
117
+ original_ff_prefix = f"{original_transformer_prefix}.mlp"
118
+ diffusers_checkpoint.update(
119
+ prior_ff_to_diffusers(
120
+ checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
121
+ )
122
+ )
123
+
124
+ # <original>.ln_1 -> <diffusers>.norm1
125
+ diffusers_checkpoint.update(
126
+ {
127
+ f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
128
+ f"{original_transformer_prefix}.ln_1.weight"
129
+ ],
130
+ f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
131
+ }
132
+ )
133
+
134
+ # <original>.ln_2 -> <diffusers>.norm3
135
+ diffusers_checkpoint.update(
136
+ {
137
+ f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
138
+ f"{original_transformer_prefix}.ln_2.weight"
139
+ ],
140
+ f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
141
+ }
142
+ )
143
+
144
+ # <original>.final_ln -> <diffusers>.norm_out
145
+ diffusers_checkpoint.update(
146
+ {
147
+ "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"],
148
+ "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"],
149
+ }
150
+ )
151
+
152
+ # <original>.out_proj -> <diffusers>.proj_to_clip_embeddings
153
+ diffusers_checkpoint.update(
154
+ {
155
+ "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"],
156
+ "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"],
157
+ }
158
+ )
159
+
160
+ # clip stats
161
+ clip_mean, clip_std = clip_stats_checkpoint
162
+ clip_mean = clip_mean[None, :]
163
+ clip_std = clip_std[None, :]
164
+
165
+ diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std})
166
+
167
+ return diffusers_checkpoint
168
+
169
+
170
+ def prior_attention_to_diffusers(
171
+ checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
172
+ ):
173
+ diffusers_checkpoint = {}
174
+
175
+ # <original>.c_qkv -> <diffusers>.{to_q, to_k, to_v}
176
+ [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
177
+ weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
178
+ bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
179
+ split=3,
180
+ chunk_size=attention_head_dim,
181
+ )
182
+
183
+ diffusers_checkpoint.update(
184
+ {
185
+ f"{diffusers_attention_prefix}.to_q.weight": q_weight,
186
+ f"{diffusers_attention_prefix}.to_q.bias": q_bias,
187
+ f"{diffusers_attention_prefix}.to_k.weight": k_weight,
188
+ f"{diffusers_attention_prefix}.to_k.bias": k_bias,
189
+ f"{diffusers_attention_prefix}.to_v.weight": v_weight,
190
+ f"{diffusers_attention_prefix}.to_v.bias": v_bias,
191
+ }
192
+ )
193
+
194
+ # <original>.c_proj -> <diffusers>.to_out.0
195
+ diffusers_checkpoint.update(
196
+ {
197
+ f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
198
+ f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
199
+ }
200
+ )
201
+
202
+ return diffusers_checkpoint
203
+
204
+
205
+ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
206
+ diffusers_checkpoint = {
207
+ # <original>.c_fc -> <diffusers>.net.0.proj
208
+ f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
209
+ f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
210
+ # <original>.c_proj -> <diffusers>.net.2
211
+ f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
212
+ f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
213
+ }
214
+
215
+ return diffusers_checkpoint
216
+
217
+
218
+ # done prior
219
+
220
+ # unet
221
+
222
+ # We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can
223
+ # update then.
224
+
225
+ UNET_CONFIG = {
226
+ "act_fn": "silu",
227
+ "addition_embed_type": "text_image",
228
+ "addition_embed_type_num_heads": 64,
229
+ "attention_head_dim": 64,
230
+ "block_out_channels": [384, 768, 1152, 1536],
231
+ "center_input_sample": False,
232
+ "class_embed_type": None,
233
+ "class_embeddings_concat": False,
234
+ "conv_in_kernel": 3,
235
+ "conv_out_kernel": 3,
236
+ "cross_attention_dim": 768,
237
+ "cross_attention_norm": None,
238
+ "down_block_types": [
239
+ "ResnetDownsampleBlock2D",
240
+ "SimpleCrossAttnDownBlock2D",
241
+ "SimpleCrossAttnDownBlock2D",
242
+ "SimpleCrossAttnDownBlock2D",
243
+ ],
244
+ "downsample_padding": 1,
245
+ "dual_cross_attention": False,
246
+ "encoder_hid_dim": 1024,
247
+ "encoder_hid_dim_type": "text_image_proj",
248
+ "flip_sin_to_cos": True,
249
+ "freq_shift": 0,
250
+ "in_channels": 4,
251
+ "layers_per_block": 3,
252
+ "mid_block_only_cross_attention": None,
253
+ "mid_block_scale_factor": 1,
254
+ "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
255
+ "norm_eps": 1e-05,
256
+ "norm_num_groups": 32,
257
+ "num_class_embeds": None,
258
+ "only_cross_attention": False,
259
+ "out_channels": 8,
260
+ "projection_class_embeddings_input_dim": None,
261
+ "resnet_out_scale_factor": 1.0,
262
+ "resnet_skip_time_act": False,
263
+ "resnet_time_scale_shift": "scale_shift",
264
+ "sample_size": 64,
265
+ "time_cond_proj_dim": None,
266
+ "time_embedding_act_fn": None,
267
+ "time_embedding_dim": None,
268
+ "time_embedding_type": "positional",
269
+ "timestep_post_act": None,
270
+ "up_block_types": [
271
+ "SimpleCrossAttnUpBlock2D",
272
+ "SimpleCrossAttnUpBlock2D",
273
+ "SimpleCrossAttnUpBlock2D",
274
+ "ResnetUpsampleBlock2D",
275
+ ],
276
+ "upcast_attention": False,
277
+ "use_linear_projection": False,
278
+ }
279
+
280
+
281
+ def unet_model_from_original_config():
282
+ model = UNet2DConditionModel(**UNET_CONFIG)
283
+
284
+ return model
285
+
286
+
287
+ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
288
+ diffusers_checkpoint = {}
289
+
290
+ num_head_channels = UNET_CONFIG["attention_head_dim"]
291
+
292
+ diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
293
+ diffusers_checkpoint.update(unet_conv_in(checkpoint))
294
+ diffusers_checkpoint.update(unet_add_embedding(checkpoint))
295
+ diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
296
+
297
+ # <original>.input_blocks -> <diffusers>.down_blocks
298
+
299
+ original_down_block_idx = 1
300
+
301
+ for diffusers_down_block_idx in range(len(model.down_blocks)):
302
+ checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
303
+ model,
304
+ checkpoint,
305
+ diffusers_down_block_idx=diffusers_down_block_idx,
306
+ original_down_block_idx=original_down_block_idx,
307
+ num_head_channels=num_head_channels,
308
+ )
309
+
310
+ original_down_block_idx += num_original_down_blocks
311
+
312
+ diffusers_checkpoint.update(checkpoint_update)
313
+
314
+ # done <original>.input_blocks -> <diffusers>.down_blocks
315
+
316
+ diffusers_checkpoint.update(
317
+ unet_midblock_to_diffusers_checkpoint(
318
+ model,
319
+ checkpoint,
320
+ num_head_channels=num_head_channels,
321
+ )
322
+ )
323
+
324
+ # <original>.output_blocks -> <diffusers>.up_blocks
325
+
326
+ original_up_block_idx = 0
327
+
328
+ for diffusers_up_block_idx in range(len(model.up_blocks)):
329
+ checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
330
+ model,
331
+ checkpoint,
332
+ diffusers_up_block_idx=diffusers_up_block_idx,
333
+ original_up_block_idx=original_up_block_idx,
334
+ num_head_channels=num_head_channels,
335
+ )
336
+
337
+ original_up_block_idx += num_original_up_blocks
338
+
339
+ diffusers_checkpoint.update(checkpoint_update)
340
+
341
+ # done <original>.output_blocks -> <diffusers>.up_blocks
342
+
343
+ diffusers_checkpoint.update(unet_conv_norm_out(checkpoint))
344
+ diffusers_checkpoint.update(unet_conv_out(checkpoint))
345
+
346
+ return diffusers_checkpoint
347
+
348
+
349
+ # done unet
350
+
351
+ # inpaint unet
352
+
353
+ # We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can
354
+ # update then.
355
+
356
+ INPAINT_UNET_CONFIG = {
357
+ "act_fn": "silu",
358
+ "addition_embed_type": "text_image",
359
+ "addition_embed_type_num_heads": 64,
360
+ "attention_head_dim": 64,
361
+ "block_out_channels": [384, 768, 1152, 1536],
362
+ "center_input_sample": False,
363
+ "class_embed_type": None,
364
+ "class_embeddings_concat": None,
365
+ "conv_in_kernel": 3,
366
+ "conv_out_kernel": 3,
367
+ "cross_attention_dim": 768,
368
+ "cross_attention_norm": None,
369
+ "down_block_types": [
370
+ "ResnetDownsampleBlock2D",
371
+ "SimpleCrossAttnDownBlock2D",
372
+ "SimpleCrossAttnDownBlock2D",
373
+ "SimpleCrossAttnDownBlock2D",
374
+ ],
375
+ "downsample_padding": 1,
376
+ "dual_cross_attention": False,
377
+ "encoder_hid_dim": 1024,
378
+ "encoder_hid_dim_type": "text_image_proj",
379
+ "flip_sin_to_cos": True,
380
+ "freq_shift": 0,
381
+ "in_channels": 9,
382
+ "layers_per_block": 3,
383
+ "mid_block_only_cross_attention": None,
384
+ "mid_block_scale_factor": 1,
385
+ "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
386
+ "norm_eps": 1e-05,
387
+ "norm_num_groups": 32,
388
+ "num_class_embeds": None,
389
+ "only_cross_attention": False,
390
+ "out_channels": 8,
391
+ "projection_class_embeddings_input_dim": None,
392
+ "resnet_out_scale_factor": 1.0,
393
+ "resnet_skip_time_act": False,
394
+ "resnet_time_scale_shift": "scale_shift",
395
+ "sample_size": 64,
396
+ "time_cond_proj_dim": None,
397
+ "time_embedding_act_fn": None,
398
+ "time_embedding_dim": None,
399
+ "time_embedding_type": "positional",
400
+ "timestep_post_act": None,
401
+ "up_block_types": [
402
+ "SimpleCrossAttnUpBlock2D",
403
+ "SimpleCrossAttnUpBlock2D",
404
+ "SimpleCrossAttnUpBlock2D",
405
+ "ResnetUpsampleBlock2D",
406
+ ],
407
+ "upcast_attention": False,
408
+ "use_linear_projection": False,
409
+ }
410
+
411
+
412
+ def inpaint_unet_model_from_original_config():
413
+ model = UNet2DConditionModel(**INPAINT_UNET_CONFIG)
414
+
415
+ return model
416
+
417
+
418
+ def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
419
+ diffusers_checkpoint = {}
420
+
421
+ num_head_channels = INPAINT_UNET_CONFIG["attention_head_dim"]
422
+
423
+ diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
424
+ diffusers_checkpoint.update(unet_conv_in(checkpoint))
425
+ diffusers_checkpoint.update(unet_add_embedding(checkpoint))
426
+ diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
427
+
428
+ # <original>.input_blocks -> <diffusers>.down_blocks
429
+
430
+ original_down_block_idx = 1
431
+
432
+ for diffusers_down_block_idx in range(len(model.down_blocks)):
433
+ checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
434
+ model,
435
+ checkpoint,
436
+ diffusers_down_block_idx=diffusers_down_block_idx,
437
+ original_down_block_idx=original_down_block_idx,
438
+ num_head_channels=num_head_channels,
439
+ )
440
+
441
+ original_down_block_idx += num_original_down_blocks
442
+
443
+ diffusers_checkpoint.update(checkpoint_update)
444
+
445
+ # done <original>.input_blocks -> <diffusers>.down_blocks
446
+
447
+ diffusers_checkpoint.update(
448
+ unet_midblock_to_diffusers_checkpoint(
449
+ model,
450
+ checkpoint,
451
+ num_head_channels=num_head_channels,
452
+ )
453
+ )
454
+
455
+ # <original>.output_blocks -> <diffusers>.up_blocks
456
+
457
+ original_up_block_idx = 0
458
+
459
+ for diffusers_up_block_idx in range(len(model.up_blocks)):
460
+ checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
461
+ model,
462
+ checkpoint,
463
+ diffusers_up_block_idx=diffusers_up_block_idx,
464
+ original_up_block_idx=original_up_block_idx,
465
+ num_head_channels=num_head_channels,
466
+ )
467
+
468
+ original_up_block_idx += num_original_up_blocks
469
+
470
+ diffusers_checkpoint.update(checkpoint_update)
471
+
472
+ # done <original>.output_blocks -> <diffusers>.up_blocks
473
+
474
+ diffusers_checkpoint.update(unet_conv_norm_out(checkpoint))
475
+ diffusers_checkpoint.update(unet_conv_out(checkpoint))
476
+
477
+ return diffusers_checkpoint
478
+
479
+
480
+ # done inpaint unet
481
+
482
+
483
+ # unet utils
484
+
485
+
486
+ # <original>.time_embed -> <diffusers>.time_embedding
487
+ def unet_time_embeddings(checkpoint):
488
+ diffusers_checkpoint = {}
489
+
490
+ diffusers_checkpoint.update(
491
+ {
492
+ "time_embedding.linear_1.weight": checkpoint["time_embed.0.weight"],
493
+ "time_embedding.linear_1.bias": checkpoint["time_embed.0.bias"],
494
+ "time_embedding.linear_2.weight": checkpoint["time_embed.2.weight"],
495
+ "time_embedding.linear_2.bias": checkpoint["time_embed.2.bias"],
496
+ }
497
+ )
498
+
499
+ return diffusers_checkpoint
500
+
501
+
502
+ # <original>.input_blocks.0 -> <diffusers>.conv_in
503
+ def unet_conv_in(checkpoint):
504
+ diffusers_checkpoint = {}
505
+
506
+ diffusers_checkpoint.update(
507
+ {
508
+ "conv_in.weight": checkpoint["input_blocks.0.0.weight"],
509
+ "conv_in.bias": checkpoint["input_blocks.0.0.bias"],
510
+ }
511
+ )
512
+
513
+ return diffusers_checkpoint
514
+
515
+
516
+ def unet_add_embedding(checkpoint):
517
+ diffusers_checkpoint = {}
518
+
519
+ diffusers_checkpoint.update(
520
+ {
521
+ "add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"],
522
+ "add_embedding.text_norm.bias": checkpoint["ln_model_n.bias"],
523
+ "add_embedding.text_proj.weight": checkpoint["proj_n.weight"],
524
+ "add_embedding.text_proj.bias": checkpoint["proj_n.bias"],
525
+ "add_embedding.image_proj.weight": checkpoint["img_layer.weight"],
526
+ "add_embedding.image_proj.bias": checkpoint["img_layer.bias"],
527
+ }
528
+ )
529
+
530
+ return diffusers_checkpoint
531
+
532
+
533
+ def unet_encoder_hid_proj(checkpoint):
534
+ diffusers_checkpoint = {}
535
+
536
+ diffusers_checkpoint.update(
537
+ {
538
+ "encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"],
539
+ "encoder_hid_proj.image_embeds.bias": checkpoint["clip_to_seq.bias"],
540
+ "encoder_hid_proj.text_proj.weight": checkpoint["to_model_dim_n.weight"],
541
+ "encoder_hid_proj.text_proj.bias": checkpoint["to_model_dim_n.bias"],
542
+ }
543
+ )
544
+
545
+ return diffusers_checkpoint
546
+
547
+
548
+ # <original>.out.0 -> <diffusers>.conv_norm_out
549
+ def unet_conv_norm_out(checkpoint):
550
+ diffusers_checkpoint = {}
551
+
552
+ diffusers_checkpoint.update(
553
+ {
554
+ "conv_norm_out.weight": checkpoint["out.0.weight"],
555
+ "conv_norm_out.bias": checkpoint["out.0.bias"],
556
+ }
557
+ )
558
+
559
+ return diffusers_checkpoint
560
+
561
+
562
+ # <original>.out.2 -> <diffusers>.conv_out
563
+ def unet_conv_out(checkpoint):
564
+ diffusers_checkpoint = {}
565
+
566
+ diffusers_checkpoint.update(
567
+ {
568
+ "conv_out.weight": checkpoint["out.2.weight"],
569
+ "conv_out.bias": checkpoint["out.2.bias"],
570
+ }
571
+ )
572
+
573
+ return diffusers_checkpoint
574
+
575
+
576
+ # <original>.input_blocks -> <diffusers>.down_blocks
577
+ def unet_downblock_to_diffusers_checkpoint(
578
+ model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, num_head_channels
579
+ ):
580
+ diffusers_checkpoint = {}
581
+
582
+ diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets"
583
+ original_down_block_prefix = "input_blocks"
584
+
585
+ down_block = model.down_blocks[diffusers_down_block_idx]
586
+
587
+ num_resnets = len(down_block.resnets)
588
+
589
+ if down_block.downsamplers is None:
590
+ downsampler = False
591
+ else:
592
+ assert len(down_block.downsamplers) == 1
593
+ downsampler = True
594
+ # The downsample block is also a resnet
595
+ num_resnets += 1
596
+
597
+ for resnet_idx_inc in range(num_resnets):
598
+ full_resnet_prefix = f"{original_down_block_prefix}.{original_down_block_idx + resnet_idx_inc}.0"
599
+
600
+ if downsampler and resnet_idx_inc == num_resnets - 1:
601
+ # this is a downsample block
602
+ full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0"
603
+ else:
604
+ # this is a regular resnet block
605
+ full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"
606
+
607
+ diffusers_checkpoint.update(
608
+ resnet_to_diffusers_checkpoint(
609
+ checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
610
+ )
611
+ )
612
+
613
+ if hasattr(down_block, "attentions"):
614
+ num_attentions = len(down_block.attentions)
615
+ diffusers_attention_prefix = f"down_blocks.{diffusers_down_block_idx}.attentions"
616
+
617
+ for attention_idx_inc in range(num_attentions):
618
+ full_attention_prefix = f"{original_down_block_prefix}.{original_down_block_idx + attention_idx_inc}.1"
619
+ full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"
620
+
621
+ diffusers_checkpoint.update(
622
+ attention_to_diffusers_checkpoint(
623
+ checkpoint,
624
+ attention_prefix=full_attention_prefix,
625
+ diffusers_attention_prefix=full_diffusers_attention_prefix,
626
+ num_head_channels=num_head_channels,
627
+ )
628
+ )
629
+
630
+ num_original_down_blocks = num_resnets
631
+
632
+ return diffusers_checkpoint, num_original_down_blocks
633
+
634
+
635
+ # <original>.middle_block -> <diffusers>.mid_block
636
+ def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, num_head_channels):
637
+ diffusers_checkpoint = {}
638
+
639
+ # block 0
640
+
641
+ original_block_idx = 0
642
+
643
+ diffusers_checkpoint.update(
644
+ resnet_to_diffusers_checkpoint(
645
+ checkpoint,
646
+ diffusers_resnet_prefix="mid_block.resnets.0",
647
+ resnet_prefix=f"middle_block.{original_block_idx}",
648
+ )
649
+ )
650
+
651
+ original_block_idx += 1
652
+
653
+ # optional block 1
654
+
655
+ if hasattr(model.mid_block, "attentions") and model.mid_block.attentions[0] is not None:
656
+ diffusers_checkpoint.update(
657
+ attention_to_diffusers_checkpoint(
658
+ checkpoint,
659
+ diffusers_attention_prefix="mid_block.attentions.0",
660
+ attention_prefix=f"middle_block.{original_block_idx}",
661
+ num_head_channels=num_head_channels,
662
+ )
663
+ )
664
+ original_block_idx += 1
665
+
666
+ # block 1 or block 2
667
+
668
+ diffusers_checkpoint.update(
669
+ resnet_to_diffusers_checkpoint(
670
+ checkpoint,
671
+ diffusers_resnet_prefix="mid_block.resnets.1",
672
+ resnet_prefix=f"middle_block.{original_block_idx}",
673
+ )
674
+ )
675
+
676
+ return diffusers_checkpoint
677
+
678
+
679
+ # <original>.output_blocks -> <diffusers>.up_blocks
680
+ def unet_upblock_to_diffusers_checkpoint(
681
+ model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, num_head_channels
682
+ ):
683
+ diffusers_checkpoint = {}
684
+
685
+ diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.resnets"
686
+ original_up_block_prefix = "output_blocks"
687
+
688
+ up_block = model.up_blocks[diffusers_up_block_idx]
689
+
690
+ num_resnets = len(up_block.resnets)
691
+
692
+ if up_block.upsamplers is None:
693
+ upsampler = False
694
+ else:
695
+ assert len(up_block.upsamplers) == 1
696
+ upsampler = True
697
+ # The upsample block is also a resnet
698
+ num_resnets += 1
699
+
700
+ has_attentions = hasattr(up_block, "attentions")
701
+
702
+ for resnet_idx_inc in range(num_resnets):
703
+ if upsampler and resnet_idx_inc == num_resnets - 1:
704
+ # this is an upsample block
705
+ if has_attentions:
706
+ # There is a middle attention block that we skip
707
+ original_resnet_block_idx = 2
708
+ else:
709
+ original_resnet_block_idx = 1
710
+
711
+ # we add the `minus 1` because the last two resnets are stuck together in the same output block
712
+ full_resnet_prefix = (
713
+ f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc - 1}.{original_resnet_block_idx}"
714
+ )
715
+
716
+ full_diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.upsamplers.0"
717
+ else:
718
+ # this is a regular resnet block
719
+ full_resnet_prefix = f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc}.0"
720
+ full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"
721
+
722
+ diffusers_checkpoint.update(
723
+ resnet_to_diffusers_checkpoint(
724
+ checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
725
+ )
726
+ )
727
+
728
+ if has_attentions:
729
+ num_attentions = len(up_block.attentions)
730
+ diffusers_attention_prefix = f"up_blocks.{diffusers_up_block_idx}.attentions"
731
+
732
+ for attention_idx_inc in range(num_attentions):
733
+ full_attention_prefix = f"{original_up_block_prefix}.{original_up_block_idx + attention_idx_inc}.1"
734
+ full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"
735
+
736
+ diffusers_checkpoint.update(
737
+ attention_to_diffusers_checkpoint(
738
+ checkpoint,
739
+ attention_prefix=full_attention_prefix,
740
+ diffusers_attention_prefix=full_diffusers_attention_prefix,
741
+ num_head_channels=num_head_channels,
742
+ )
743
+ )
744
+
745
+ num_original_down_blocks = num_resnets - 1 if upsampler else num_resnets
746
+
747
+ return diffusers_checkpoint, num_original_down_blocks
748
+
749
+
750
+ def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
751
+ diffusers_checkpoint = {
752
+ f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"],
753
+ f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"],
754
+ f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"],
755
+ f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"],
756
+ f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"],
757
+ f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"],
758
+ f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"],
759
+ f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"],
760
+ f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"],
761
+ f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"],
762
+ }
763
+
764
+ skip_connection_prefix = f"{resnet_prefix}.skip_connection"
765
+
766
+ if f"{skip_connection_prefix}.weight" in checkpoint:
767
+ diffusers_checkpoint.update(
768
+ {
769
+ f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{skip_connection_prefix}.weight"],
770
+ f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{skip_connection_prefix}.bias"],
771
+ }
772
+ )
773
+
774
+ return diffusers_checkpoint
775
+
776
+
777
+ def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels):
778
+ diffusers_checkpoint = {}
779
+
780
+ # <original>.norm -> <diffusers>.group_norm
781
+ diffusers_checkpoint.update(
782
+ {
783
+ f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
784
+ f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
785
+ }
786
+ )
787
+
788
+ # <original>.qkv -> <diffusers>.{query, key, value}
789
+ [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
790
+ weight=checkpoint[f"{attention_prefix}.qkv.weight"][:, :, 0],
791
+ bias=checkpoint[f"{attention_prefix}.qkv.bias"],
792
+ split=3,
793
+ chunk_size=num_head_channels,
794
+ )
795
+
796
+ diffusers_checkpoint.update(
797
+ {
798
+ f"{diffusers_attention_prefix}.to_q.weight": q_weight,
799
+ f"{diffusers_attention_prefix}.to_q.bias": q_bias,
800
+ f"{diffusers_attention_prefix}.to_k.weight": k_weight,
801
+ f"{diffusers_attention_prefix}.to_k.bias": k_bias,
802
+ f"{diffusers_attention_prefix}.to_v.weight": v_weight,
803
+ f"{diffusers_attention_prefix}.to_v.bias": v_bias,
804
+ }
805
+ )
806
+
807
+ # <original>.encoder_kv -> <diffusers>.{context_key, context_value}
808
+ [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
809
+ weight=checkpoint[f"{attention_prefix}.encoder_kv.weight"][:, :, 0],
810
+ bias=checkpoint[f"{attention_prefix}.encoder_kv.bias"],
811
+ split=2,
812
+ chunk_size=num_head_channels,
813
+ )
814
+
815
+ diffusers_checkpoint.update(
816
+ {
817
+ f"{diffusers_attention_prefix}.add_k_proj.weight": encoder_k_weight,
818
+ f"{diffusers_attention_prefix}.add_k_proj.bias": encoder_k_bias,
819
+ f"{diffusers_attention_prefix}.add_v_proj.weight": encoder_v_weight,
820
+ f"{diffusers_attention_prefix}.add_v_proj.bias": encoder_v_bias,
821
+ }
822
+ )
823
+
824
+ # <original>.proj_out (1d conv) -> <diffusers>.proj_attn (linear)
825
+ diffusers_checkpoint.update(
826
+ {
827
+ f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
828
+ :, :, 0
829
+ ],
830
+ f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
831
+ }
832
+ )
833
+
834
+ return diffusers_checkpoint
835
+
836
+
837
+ # TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
838
+ def split_attentions(*, weight, bias, split, chunk_size):
839
+ weights = [None] * split
840
+ biases = [None] * split
841
+
842
+ weights_biases_idx = 0
843
+
844
+ for starting_row_index in range(0, weight.shape[0], chunk_size):
845
+ row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)
846
+
847
+ weight_rows = weight[row_indices, :]
848
+ bias_rows = bias[row_indices]
849
+
850
+ if weights[weights_biases_idx] is None:
851
+ assert weights[weights_biases_idx] is None
852
+ weights[weights_biases_idx] = weight_rows
853
+ biases[weights_biases_idx] = bias_rows
854
+ else:
855
+ assert weights[weights_biases_idx] is not None
856
+ weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
857
+ biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])
858
+
859
+ weights_biases_idx = (weights_biases_idx + 1) % split
860
+
861
+ return weights, biases
862
+
863
+
864
+ # done unet utils
865
+
866
+
867
+ def prior(*, args, checkpoint_map_location):
868
+ print("loading prior")
869
+
870
+ prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)
871
+
872
+ clip_stats_checkpoint = torch.load(args.clip_stat_path, map_location=checkpoint_map_location)
873
+
874
+ prior_model = prior_model_from_original_config()
875
+
876
+ prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(
877
+ prior_model, prior_checkpoint, clip_stats_checkpoint
878
+ )
879
+
880
+ del prior_checkpoint
881
+ del clip_stats_checkpoint
882
+
883
+ load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True)
884
+
885
+ print("done loading prior")
886
+
887
+ return prior_model
888
+
889
+
890
+ def text2img(*, args, checkpoint_map_location):
891
+ print("loading text2img")
892
+
893
+ text2img_checkpoint = torch.load(args.text2img_checkpoint_path, map_location=checkpoint_map_location)
894
+
895
+ unet_model = unet_model_from_original_config()
896
+
897
+ unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint)
898
+
899
+ del text2img_checkpoint
900
+
901
+ load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True)
902
+
903
+ print("done loading text2img")
904
+
905
+ return unet_model
906
+
907
+
908
+ def inpaint_text2img(*, args, checkpoint_map_location):
909
+ print("loading inpaint text2img")
910
+
911
+ inpaint_text2img_checkpoint = torch.load(
912
+ args.inpaint_text2img_checkpoint_path, map_location=checkpoint_map_location
913
+ )
914
+
915
+ inpaint_unet_model = inpaint_unet_model_from_original_config()
916
+
917
+ inpaint_unet_diffusers_checkpoint = inpaint_unet_original_checkpoint_to_diffusers_checkpoint(
918
+ inpaint_unet_model, inpaint_text2img_checkpoint
919
+ )
920
+
921
+ del inpaint_text2img_checkpoint
922
+
923
+ load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True)
924
+
925
+ print("done loading inpaint text2img")
926
+
927
+ return inpaint_unet_model
928
+
929
+
930
+ # movq
931
+
932
+ MOVQ_CONFIG = {
933
+ "in_channels": 3,
934
+ "out_channels": 3,
935
+ "latent_channels": 4,
936
+ "down_block_types": ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"),
937
+ "up_block_types": ("AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
938
+ "num_vq_embeddings": 16384,
939
+ "block_out_channels": (128, 256, 256, 512),
940
+ "vq_embed_dim": 4,
941
+ "layers_per_block": 2,
942
+ "norm_type": "spatial",
943
+ }
944
+
945
+
946
+ def movq_model_from_original_config():
947
+ movq = VQModel(**MOVQ_CONFIG)
948
+ return movq
949
+
950
+
951
+ def movq_encoder_to_diffusers_checkpoint(model, checkpoint):
952
+ diffusers_checkpoint = {}
953
+
954
+ # conv_in
955
+ diffusers_checkpoint.update(
956
+ {
957
+ "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
958
+ "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
959
+ }
960
+ )
961
+
962
+ # down_blocks
963
+ for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
964
+ diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
965
+ down_block_prefix = f"encoder.down.{down_block_idx}"
966
+
967
+ # resnets
968
+ for resnet_idx, resnet in enumerate(down_block.resnets):
969
+ diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
970
+ resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"
971
+
972
+ diffusers_checkpoint.update(
973
+ movq_resnet_to_diffusers_checkpoint(
974
+ resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
975
+ )
976
+ )
977
+
978
+ # downsample
979
+
980
+ # do not include the downsample when on the last down block
981
+ # There is no downsample on the last down block
982
+ if down_block_idx != len(model.encoder.down_blocks) - 1:
983
+ # There's a single downsample in the original checkpoint but a list of downsamples
984
+ # in the diffusers model.
985
+ diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
986
+ downsample_prefix = f"{down_block_prefix}.downsample.conv"
987
+ diffusers_checkpoint.update(
988
+ {
989
+ f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
990
+ f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
991
+ }
992
+ )
993
+
994
+ # attentions
995
+
996
+ if hasattr(down_block, "attentions"):
997
+ for attention_idx, _ in enumerate(down_block.attentions):
998
+ diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
999
+ attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
1000
+ diffusers_checkpoint.update(
1001
+ movq_attention_to_diffusers_checkpoint(
1002
+ checkpoint,
1003
+ diffusers_attention_prefix=diffusers_attention_prefix,
1004
+ attention_prefix=attention_prefix,
1005
+ )
1006
+ )
1007
+
1008
+ # mid block
1009
+
1010
+ # mid block attentions
1011
+
1012
+ # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder
1013
+ diffusers_attention_prefix = "encoder.mid_block.attentions.0"
1014
+ attention_prefix = "encoder.mid.attn_1"
1015
+ diffusers_checkpoint.update(
1016
+ movq_attention_to_diffusers_checkpoint(
1017
+ checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
1018
+ )
1019
+ )
1020
+
1021
+ # mid block resnets
1022
+
1023
+ for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
1024
+ diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"
1025
+
1026
+ # the hardcoded prefixes to `block_` are 1 and 2
1027
+ orig_resnet_idx = diffusers_resnet_idx + 1
1028
+ # There are two hardcoded resnets in the middle of the VQ-diffusion encoder
1029
+ resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"
1030
+
1031
+ diffusers_checkpoint.update(
1032
+ movq_resnet_to_diffusers_checkpoint(
1033
+ resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
1034
+ )
1035
+ )
1036
+
1037
+ diffusers_checkpoint.update(
1038
+ {
1039
+ # conv_norm_out
1040
+ "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
1041
+ "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
1042
+ # conv_out
1043
+ "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
1044
+ "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
1045
+ }
1046
+ )
1047
+
1048
+ return diffusers_checkpoint
1049
+
1050
+
1051
+ def movq_decoder_to_diffusers_checkpoint(model, checkpoint):
1052
+ diffusers_checkpoint = {}
1053
+
1054
+ # conv in
1055
+ diffusers_checkpoint.update(
1056
+ {
1057
+ "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
1058
+ "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
1059
+ }
1060
+ )
1061
+
1062
+ # up_blocks
1063
+
1064
+ for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
1065
+ # up_blocks are stored in reverse order in the VQ-diffusion checkpoint
1066
+ orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx
1067
+
1068
+ diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
1069
+ up_block_prefix = f"decoder.up.{orig_up_block_idx}"
1070
+
1071
+ # resnets
1072
+ for resnet_idx, resnet in enumerate(up_block.resnets):
1073
+ diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
1074
+ resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"
1075
+
1076
+ diffusers_checkpoint.update(
1077
+ movq_resnet_to_diffusers_checkpoint_spatial_norm(
1078
+ resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
1079
+ )
1080
+ )
1081
+
1082
+ # upsample
1083
+
1084
+ # there is no up sample on the last up block
1085
+ if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
1086
+ # There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples
1087
+ # in the diffusers model.
1088
+ diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
1089
+ downsample_prefix = f"{up_block_prefix}.upsample.conv"
1090
+ diffusers_checkpoint.update(
1091
+ {
1092
+ f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
1093
+ f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
1094
+ }
1095
+ )
1096
+
1097
+ # attentions
1098
+
1099
+ if hasattr(up_block, "attentions"):
1100
+ for attention_idx, _ in enumerate(up_block.attentions):
1101
+ diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
1102
+ attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
1103
+ diffusers_checkpoint.update(
1104
+ movq_attention_to_diffusers_checkpoint_spatial_norm(
1105
+ checkpoint,
1106
+ diffusers_attention_prefix=diffusers_attention_prefix,
1107
+ attention_prefix=attention_prefix,
1108
+ )
1109
+ )
1110
+
1111
+ # mid block
1112
+
1113
+ # mid block attentions
1114
+
1115
+ # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder
1116
+ diffusers_attention_prefix = "decoder.mid_block.attentions.0"
1117
+ attention_prefix = "decoder.mid.attn_1"
1118
+ diffusers_checkpoint.update(
1119
+ movq_attention_to_diffusers_checkpoint_spatial_norm(
1120
+ checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
1121
+ )
1122
+ )
1123
+
1124
+ # mid block resnets
1125
+
1126
+ for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
1127
+ diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"
1128
+
1129
+ # the hardcoded prefixes to `block_` are 1 and 2
1130
+ orig_resnet_idx = diffusers_resnet_idx + 1
1131
+ # There are two hardcoded resnets in the middle of the VQ-diffusion decoder
1132
+ resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"
1133
+
1134
+ diffusers_checkpoint.update(
1135
+ movq_resnet_to_diffusers_checkpoint_spatial_norm(
1136
+ resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
1137
+ )
1138
+ )
1139
+
1140
+ diffusers_checkpoint.update(
1141
+ {
1142
+ # conv_norm_out
1143
+ "decoder.conv_norm_out.norm_layer.weight": checkpoint["decoder.norm_out.norm_layer.weight"],
1144
+ "decoder.conv_norm_out.norm_layer.bias": checkpoint["decoder.norm_out.norm_layer.bias"],
1145
+ "decoder.conv_norm_out.conv_y.weight": checkpoint["decoder.norm_out.conv_y.weight"],
1146
+ "decoder.conv_norm_out.conv_y.bias": checkpoint["decoder.norm_out.conv_y.bias"],
1147
+ "decoder.conv_norm_out.conv_b.weight": checkpoint["decoder.norm_out.conv_b.weight"],
1148
+ "decoder.conv_norm_out.conv_b.bias": checkpoint["decoder.norm_out.conv_b.bias"],
1149
+ # conv_out
1150
+ "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
1151
+ "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
1152
+ }
1153
+ )
1154
+
1155
+ return diffusers_checkpoint
1156
+
1157
+
1158
+ def movq_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
1159
+ rv = {
1160
+ # norm1
1161
+ f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
1162
+ f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
1163
+ # conv1
1164
+ f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
1165
+ f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
1166
+ # norm2
1167
+ f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
1168
+ f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
1169
+ # conv2
1170
+ f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
1171
+ f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
1172
+ }
1173
+
1174
+ if resnet.conv_shortcut is not None:
1175
+ rv.update(
1176
+ {
1177
+ f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
1178
+ f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
1179
+ }
1180
+ )
1181
+
1182
+ return rv
1183
+
1184
+
1185
+ def movq_resnet_to_diffusers_checkpoint_spatial_norm(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
1186
+ rv = {
1187
+ # norm1
1188
+ f"{diffusers_resnet_prefix}.norm1.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm1.norm_layer.weight"],
1189
+ f"{diffusers_resnet_prefix}.norm1.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm1.norm_layer.bias"],
1190
+ f"{diffusers_resnet_prefix}.norm1.conv_y.weight": checkpoint[f"{resnet_prefix}.norm1.conv_y.weight"],
1191
+ f"{diffusers_resnet_prefix}.norm1.conv_y.bias": checkpoint[f"{resnet_prefix}.norm1.conv_y.bias"],
1192
+ f"{diffusers_resnet_prefix}.norm1.conv_b.weight": checkpoint[f"{resnet_prefix}.norm1.conv_b.weight"],
1193
+ f"{diffusers_resnet_prefix}.norm1.conv_b.bias": checkpoint[f"{resnet_prefix}.norm1.conv_b.bias"],
1194
+ # conv1
1195
+ f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
1196
+ f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
1197
+ # norm2
1198
+ f"{diffusers_resnet_prefix}.norm2.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm2.norm_layer.weight"],
1199
+ f"{diffusers_resnet_prefix}.norm2.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm2.norm_layer.bias"],
1200
+ f"{diffusers_resnet_prefix}.norm2.conv_y.weight": checkpoint[f"{resnet_prefix}.norm2.conv_y.weight"],
1201
+ f"{diffusers_resnet_prefix}.norm2.conv_y.bias": checkpoint[f"{resnet_prefix}.norm2.conv_y.bias"],
1202
+ f"{diffusers_resnet_prefix}.norm2.conv_b.weight": checkpoint[f"{resnet_prefix}.norm2.conv_b.weight"],
1203
+ f"{diffusers_resnet_prefix}.norm2.conv_b.bias": checkpoint[f"{resnet_prefix}.norm2.conv_b.bias"],
1204
+ # conv2
1205
+ f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
1206
+ f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
1207
+ }
1208
+
1209
+ if resnet.conv_shortcut is not None:
1210
+ rv.update(
1211
+ {
1212
+ f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
1213
+ f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
1214
+ }
1215
+ )
1216
+
1217
+ return rv
1218
+
1219
+
1220
+ def movq_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
1221
+ return {
1222
+ # norm
1223
+ f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
1224
+ f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
1225
+ # query
1226
+ f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
1227
+ f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.q.bias"],
1228
+ # key
1229
+ f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
1230
+ f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.k.bias"],
1231
+ # value
1232
+ f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
1233
+ f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.v.bias"],
1234
+ # proj_attn
1235
+ f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][:, :, 0, 0],
1236
+ f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
1237
+ }
1238
+
1239
+
1240
+ def movq_attention_to_diffusers_checkpoint_spatial_norm(checkpoint, *, diffusers_attention_prefix, attention_prefix):
1241
+ return {
1242
+ # norm
1243
+ f"{diffusers_attention_prefix}.spatial_norm.norm_layer.weight": checkpoint[
1244
+ f"{attention_prefix}.norm.norm_layer.weight"
1245
+ ],
1246
+ f"{diffusers_attention_prefix}.spatial_norm.norm_layer.bias": checkpoint[
1247
+ f"{attention_prefix}.norm.norm_layer.bias"
1248
+ ],
1249
+ f"{diffusers_attention_prefix}.spatial_norm.conv_y.weight": checkpoint[
1250
+ f"{attention_prefix}.norm.conv_y.weight"
1251
+ ],
1252
+ f"{diffusers_attention_prefix}.spatial_norm.conv_y.bias": checkpoint[f"{attention_prefix}.norm.conv_y.bias"],
1253
+ f"{diffusers_attention_prefix}.spatial_norm.conv_b.weight": checkpoint[
1254
+ f"{attention_prefix}.norm.conv_b.weight"
1255
+ ],
1256
+ f"{diffusers_attention_prefix}.spatial_norm.conv_b.bias": checkpoint[f"{attention_prefix}.norm.conv_b.bias"],
1257
+ # query
1258
+ f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
1259
+ f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.q.bias"],
1260
+ # key
1261
+ f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
1262
+ f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.k.bias"],
1263
+ # value
1264
+ f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
1265
+ f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.v.bias"],
1266
+ # proj_attn
1267
+ f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][:, :, 0, 0],
1268
+ f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
1269
+ }
1270
+
1271
+
1272
+ def movq_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
1273
+ diffusers_checkpoint = {}
1274
+ diffusers_checkpoint.update(movq_encoder_to_diffusers_checkpoint(model, checkpoint))
1275
+
1276
+ # quant_conv
1277
+
1278
+ diffusers_checkpoint.update(
1279
+ {
1280
+ "quant_conv.weight": checkpoint["quant_conv.weight"],
1281
+ "quant_conv.bias": checkpoint["quant_conv.bias"],
1282
+ }
1283
+ )
1284
+
1285
+ # quantize
1286
+ diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding.weight"]})
1287
+
1288
+ # post_quant_conv
1289
+ diffusers_checkpoint.update(
1290
+ {
1291
+ "post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
1292
+ "post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
1293
+ }
1294
+ )
1295
+
1296
+ # decoder
1297
+ diffusers_checkpoint.update(movq_decoder_to_diffusers_checkpoint(model, checkpoint))
1298
+
1299
+ return diffusers_checkpoint
1300
+
1301
+
1302
+ def movq(*, args, checkpoint_map_location):
1303
+ print("loading movq")
1304
+
1305
+ movq_checkpoint = torch.load(args.movq_checkpoint_path, map_location=checkpoint_map_location)
1306
+
1307
+ movq_model = movq_model_from_original_config()
1308
+
1309
+ movq_diffusers_checkpoint = movq_original_checkpoint_to_diffusers_checkpoint(movq_model, movq_checkpoint)
1310
+
1311
+ del movq_checkpoint
1312
+
1313
+ load_checkpoint_to_model(movq_diffusers_checkpoint, movq_model, strict=True)
1314
+
1315
+ print("done loading movq")
1316
+
1317
+ return movq_model
1318
+
1319
+
1320
+ def load_checkpoint_to_model(checkpoint, model, strict=False):
1321
+ with tempfile.NamedTemporaryFile(delete=False) as file:
1322
+ torch.save(checkpoint, file.name)
1323
+ del checkpoint
1324
+ if strict:
1325
+ model.load_state_dict(torch.load(file.name), strict=True)
1326
+ else:
1327
+ load_checkpoint_and_dispatch(model, file.name, device_map="auto")
1328
+ os.remove(file.name)
1329
+
1330
+
1331
+ if __name__ == "__main__":
1332
+ parser = argparse.ArgumentParser()
1333
+
1334
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
1335
+
1336
+ parser.add_argument(
1337
+ "--prior_checkpoint_path",
1338
+ default=None,
1339
+ type=str,
1340
+ required=False,
1341
+ help="Path to the prior checkpoint to convert.",
1342
+ )
1343
+ parser.add_argument(
1344
+ "--clip_stat_path",
1345
+ default=None,
1346
+ type=str,
1347
+ required=False,
1348
+ help="Path to the clip stats checkpoint to convert.",
1349
+ )
1350
+ parser.add_argument(
1351
+ "--text2img_checkpoint_path",
1352
+ default=None,
1353
+ type=str,
1354
+ required=False,
1355
+ help="Path to the text2img checkpoint to convert.",
1356
+ )
1357
+ parser.add_argument(
1358
+ "--movq_checkpoint_path",
1359
+ default=None,
1360
+ type=str,
1361
+ required=False,
1362
+ help="Path to the text2img checkpoint to convert.",
1363
+ )
1364
+ parser.add_argument(
1365
+ "--inpaint_text2img_checkpoint_path",
1366
+ default=None,
1367
+ type=str,
1368
+ required=False,
1369
+ help="Path to the inpaint text2img checkpoint to convert.",
1370
+ )
1371
+ parser.add_argument(
1372
+ "--checkpoint_load_device",
1373
+ default="cpu",
1374
+ type=str,
1375
+ required=False,
1376
+ help="The device passed to `map_location` when loading checkpoints.",
1377
+ )
1378
+
1379
+ parser.add_argument(
1380
+ "--debug",
1381
+ default=None,
1382
+ type=str,
1383
+ required=False,
1384
+ help="Only run a specific stage of the convert script. Used for debugging",
1385
+ )
1386
+
1387
+ args = parser.parse_args()
1388
+
1389
+ print(f"loading checkpoints to {args.checkpoint_load_device}")
1390
+
1391
+ checkpoint_map_location = torch.device(args.checkpoint_load_device)
1392
+
1393
+ if args.debug is not None:
1394
+ print(f"debug: only executing {args.debug}")
1395
+
1396
+ if args.debug is None:
1397
+ print("to-do")
1398
+ elif args.debug == "prior":
1399
+ prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
1400
+ prior_model.save_pretrained(args.dump_path)
1401
+ elif args.debug == "text2img":
1402
+ unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
1403
+ unet_model.save_pretrained(f"{args.dump_path}/unet")
1404
+ elif args.debug == "inpaint_text2img":
1405
+ inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location)
1406
+ inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet")
1407
+ elif args.debug == "decoder":
1408
+ decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)
1409
+ decoder.save_pretrained(f"{args.dump_path}/decoder")
1410
+ else:
1411
+ raise ValueError(f"unknown debug value : {args.debug}")
diffusers/scripts/convert_ldm_original_checkpoint_to_diffusers.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Conversion script for the LDM checkpoints."""
16
+
17
+ import argparse
18
+ import json
19
+
20
+ import torch
21
+
22
+ from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel
23
+
24
+
25
+ def shave_segments(path, n_shave_prefix_segments=1):
26
+ """
27
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
28
+ """
29
+ if n_shave_prefix_segments >= 0:
30
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
31
+ else:
32
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
33
+
34
+
35
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
36
+ """
37
+ Updates paths inside resnets to the new naming scheme (local renaming)
38
+ """
39
+ mapping = []
40
+ for old_item in old_list:
41
+ new_item = old_item.replace("in_layers.0", "norm1")
42
+ new_item = new_item.replace("in_layers.2", "conv1")
43
+
44
+ new_item = new_item.replace("out_layers.0", "norm2")
45
+ new_item = new_item.replace("out_layers.3", "conv2")
46
+
47
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
48
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
49
+
50
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
51
+
52
+ mapping.append({"old": old_item, "new": new_item})
53
+
54
+ return mapping
55
+
56
+
57
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
58
+ """
59
+ Updates paths inside attentions to the new naming scheme (local renaming)
60
+ """
61
+ mapping = []
62
+ for old_item in old_list:
63
+ new_item = old_item
64
+
65
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
66
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
67
+
68
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
69
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
70
+
71
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
72
+
73
+ mapping.append({"old": old_item, "new": new_item})
74
+
75
+ return mapping
76
+
77
+
78
+ def assign_to_checkpoint(
79
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
80
+ ):
81
+ """
82
+ This does the final conversion step: take locally converted weights and apply a global renaming
83
+ to them. It splits attention layers, and takes into account additional replacements
84
+ that may arise.
85
+
86
+ Assigns the weights to the new checkpoint.
87
+ """
88
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
89
+
90
+ # Splits the attention layers into three variables.
91
+ if attention_paths_to_split is not None:
92
+ for path, path_map in attention_paths_to_split.items():
93
+ old_tensor = old_checkpoint[path]
94
+ channels = old_tensor.shape[0] // 3
95
+
96
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
97
+
98
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
99
+
100
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
101
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
102
+
103
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
104
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
105
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
106
+
107
+ for path in paths:
108
+ new_path = path["new"]
109
+
110
+ # These have already been assigned
111
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
112
+ continue
113
+
114
+ # Global renaming happens here
115
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
116
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
117
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
118
+
119
+ if additional_replacements is not None:
120
+ for replacement in additional_replacements:
121
+ new_path = new_path.replace(replacement["old"], replacement["new"])
122
+
123
+ # proj_attn.weight has to be converted from conv 1D to linear
124
+ if "proj_attn.weight" in new_path:
125
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
126
+ else:
127
+ checkpoint[new_path] = old_checkpoint[path["old"]]
128
+
129
+
130
+ def convert_ldm_checkpoint(checkpoint, config):
131
+ """
132
+ Takes a state dict and a config, and returns a converted checkpoint.
133
+ """
134
+ new_checkpoint = {}
135
+
136
+ new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
137
+ new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
138
+ new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
139
+ new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
140
+
141
+ new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
142
+ new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
143
+
144
+ new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
145
+ new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
146
+ new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
147
+ new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
148
+
149
+ # Retrieves the keys for the input blocks only
150
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
151
+ input_blocks = {
152
+ layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
153
+ for layer_id in range(num_input_blocks)
154
+ }
155
+
156
+ # Retrieves the keys for the middle blocks only
157
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer})
158
+ middle_blocks = {
159
+ layer_id: [key for key in checkpoint if f"middle_block.{layer_id}" in key]
160
+ for layer_id in range(num_middle_blocks)
161
+ }
162
+
163
+ # Retrieves the keys for the output blocks only
164
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer})
165
+ output_blocks = {
166
+ layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}" in key]
167
+ for layer_id in range(num_output_blocks)
168
+ }
169
+
170
+ for i in range(1, num_input_blocks):
171
+ block_id = (i - 1) // (config["num_res_blocks"] + 1)
172
+ layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1)
173
+
174
+ resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key]
175
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
176
+
177
+ if f"input_blocks.{i}.0.op.weight" in checkpoint:
178
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
179
+ f"input_blocks.{i}.0.op.weight"
180
+ ]
181
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
182
+ f"input_blocks.{i}.0.op.bias"
183
+ ]
184
+ continue
185
+
186
+ paths = renew_resnet_paths(resnets)
187
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
188
+ resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
189
+ assign_to_checkpoint(
190
+ paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
191
+ )
192
+
193
+ if len(attentions):
194
+ paths = renew_attention_paths(attentions)
195
+ meta_path = {
196
+ "old": f"input_blocks.{i}.1",
197
+ "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
198
+ }
199
+ to_split = {
200
+ f"input_blocks.{i}.1.qkv.bias": {
201
+ "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
202
+ "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
203
+ "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
204
+ },
205
+ f"input_blocks.{i}.1.qkv.weight": {
206
+ "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
207
+ "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
208
+ "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
209
+ },
210
+ }
211
+ assign_to_checkpoint(
212
+ paths,
213
+ new_checkpoint,
214
+ checkpoint,
215
+ additional_replacements=[meta_path],
216
+ attention_paths_to_split=to_split,
217
+ config=config,
218
+ )
219
+
220
+ resnet_0 = middle_blocks[0]
221
+ attentions = middle_blocks[1]
222
+ resnet_1 = middle_blocks[2]
223
+
224
+ resnet_0_paths = renew_resnet_paths(resnet_0)
225
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config)
226
+
227
+ resnet_1_paths = renew_resnet_paths(resnet_1)
228
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config)
229
+
230
+ attentions_paths = renew_attention_paths(attentions)
231
+ to_split = {
232
+ "middle_block.1.qkv.bias": {
233
+ "key": "mid_block.attentions.0.key.bias",
234
+ "query": "mid_block.attentions.0.query.bias",
235
+ "value": "mid_block.attentions.0.value.bias",
236
+ },
237
+ "middle_block.1.qkv.weight": {
238
+ "key": "mid_block.attentions.0.key.weight",
239
+ "query": "mid_block.attentions.0.query.weight",
240
+ "value": "mid_block.attentions.0.value.weight",
241
+ },
242
+ }
243
+ assign_to_checkpoint(
244
+ attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config
245
+ )
246
+
247
+ for i in range(num_output_blocks):
248
+ block_id = i // (config["num_res_blocks"] + 1)
249
+ layer_in_block_id = i % (config["num_res_blocks"] + 1)
250
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
251
+ output_block_list = {}
252
+
253
+ for layer in output_block_layers:
254
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
255
+ if layer_id in output_block_list:
256
+ output_block_list[layer_id].append(layer_name)
257
+ else:
258
+ output_block_list[layer_id] = [layer_name]
259
+
260
+ if len(output_block_list) > 1:
261
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
262
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
263
+
264
+ resnet_0_paths = renew_resnet_paths(resnets)
265
+ paths = renew_resnet_paths(resnets)
266
+
267
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
268
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
269
+
270
+ if ["conv.weight", "conv.bias"] in output_block_list.values():
271
+ index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
272
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
273
+ f"output_blocks.{i}.{index}.conv.weight"
274
+ ]
275
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
276
+ f"output_blocks.{i}.{index}.conv.bias"
277
+ ]
278
+
279
+ # Clear attentions as they have been attributed above.
280
+ if len(attentions) == 2:
281
+ attentions = []
282
+
283
+ if len(attentions):
284
+ paths = renew_attention_paths(attentions)
285
+ meta_path = {
286
+ "old": f"output_blocks.{i}.1",
287
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
288
+ }
289
+ to_split = {
290
+ f"output_blocks.{i}.1.qkv.bias": {
291
+ "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
292
+ "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
293
+ "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
294
+ },
295
+ f"output_blocks.{i}.1.qkv.weight": {
296
+ "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
297
+ "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
298
+ "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
299
+ },
300
+ }
301
+ assign_to_checkpoint(
302
+ paths,
303
+ new_checkpoint,
304
+ checkpoint,
305
+ additional_replacements=[meta_path],
306
+ attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None,
307
+ config=config,
308
+ )
309
+ else:
310
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
311
+ for path in resnet_0_paths:
312
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
313
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
314
+
315
+ new_checkpoint[new_path] = checkpoint[old_path]
316
+
317
+ return new_checkpoint
318
+
319
+
320
+ if __name__ == "__main__":
321
+ parser = argparse.ArgumentParser()
322
+
323
+ parser.add_argument(
324
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
325
+ )
326
+
327
+ parser.add_argument(
328
+ "--config_file",
329
+ default=None,
330
+ type=str,
331
+ required=True,
332
+ help="The config json file corresponding to the architecture.",
333
+ )
334
+
335
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
336
+
337
+ args = parser.parse_args()
338
+
339
+ checkpoint = torch.load(args.checkpoint_path)
340
+
341
+ with open(args.config_file) as f:
342
+ config = json.loads(f.read())
343
+
344
+ converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)
345
+
346
+ if "ldm" in config:
347
+ del config["ldm"]
348
+
349
+ model = UNet2DModel(**config)
350
+ model.load_state_dict(converted_checkpoint)
351
+
352
+ try:
353
+ scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
354
+ vqvae = VQModel.from_pretrained("/".join(args.checkpoint_path.split("/")[:-1]))
355
+
356
+ pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae)
357
+ pipe.save_pretrained(args.dump_path)
358
+ except: # noqa: E722
359
+ model.save_pretrained(args.dump_path)
diffusers/scripts/convert_ltx_to_diffusers.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ from typing import Any, Dict
4
+
5
+ import torch
6
+ from accelerate import init_empty_weights
7
+ from safetensors.torch import load_file
8
+ from transformers import T5EncoderModel, T5Tokenizer
9
+
10
+ from diffusers import (
11
+ AutoencoderKLLTXVideo,
12
+ FlowMatchEulerDiscreteScheduler,
13
+ LTXConditionPipeline,
14
+ LTXLatentUpsamplePipeline,
15
+ LTXPipeline,
16
+ LTXVideoTransformer3DModel,
17
+ )
18
+ from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
19
+
20
+
21
+ def remove_keys_(key: str, state_dict: Dict[str, Any]):
22
+ state_dict.pop(key)
23
+
24
+
25
+ TOKENIZER_MAX_LENGTH = 128
26
+
27
+ TRANSFORMER_KEYS_RENAME_DICT = {
28
+ "patchify_proj": "proj_in",
29
+ "adaln_single": "time_embed",
30
+ "q_norm": "norm_q",
31
+ "k_norm": "norm_k",
32
+ }
33
+
34
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
35
+ "vae": remove_keys_,
36
+ }
37
+
38
+ VAE_KEYS_RENAME_DICT = {
39
+ # decoder
40
+ "up_blocks.0": "mid_block",
41
+ "up_blocks.1": "up_blocks.0",
42
+ "up_blocks.2": "up_blocks.1.upsamplers.0",
43
+ "up_blocks.3": "up_blocks.1",
44
+ "up_blocks.4": "up_blocks.2.conv_in",
45
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
46
+ "up_blocks.6": "up_blocks.2",
47
+ "up_blocks.7": "up_blocks.3.conv_in",
48
+ "up_blocks.8": "up_blocks.3.upsamplers.0",
49
+ "up_blocks.9": "up_blocks.3",
50
+ # encoder
51
+ "down_blocks.0": "down_blocks.0",
52
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
53
+ "down_blocks.2": "down_blocks.0.conv_out",
54
+ "down_blocks.3": "down_blocks.1",
55
+ "down_blocks.4": "down_blocks.1.downsamplers.0",
56
+ "down_blocks.5": "down_blocks.1.conv_out",
57
+ "down_blocks.6": "down_blocks.2",
58
+ "down_blocks.7": "down_blocks.2.downsamplers.0",
59
+ "down_blocks.8": "down_blocks.3",
60
+ "down_blocks.9": "mid_block",
61
+ # common
62
+ "conv_shortcut": "conv_shortcut.conv",
63
+ "res_blocks": "resnets",
64
+ "norm3.norm": "norm3",
65
+ "per_channel_statistics.mean-of-means": "latents_mean",
66
+ "per_channel_statistics.std-of-means": "latents_std",
67
+ }
68
+
69
+ VAE_091_RENAME_DICT = {
70
+ # decoder
71
+ "up_blocks.0": "mid_block",
72
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
73
+ "up_blocks.2": "up_blocks.0",
74
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
75
+ "up_blocks.4": "up_blocks.1",
76
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
77
+ "up_blocks.6": "up_blocks.2",
78
+ "up_blocks.7": "up_blocks.3.upsamplers.0",
79
+ "up_blocks.8": "up_blocks.3",
80
+ # common
81
+ "last_time_embedder": "time_embedder",
82
+ "last_scale_shift_table": "scale_shift_table",
83
+ }
84
+
85
+ VAE_095_RENAME_DICT = {
86
+ # decoder
87
+ "up_blocks.0": "mid_block",
88
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
89
+ "up_blocks.2": "up_blocks.0",
90
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
91
+ "up_blocks.4": "up_blocks.1",
92
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
93
+ "up_blocks.6": "up_blocks.2",
94
+ "up_blocks.7": "up_blocks.3.upsamplers.0",
95
+ "up_blocks.8": "up_blocks.3",
96
+ # encoder
97
+ "down_blocks.0": "down_blocks.0",
98
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
99
+ "down_blocks.2": "down_blocks.1",
100
+ "down_blocks.3": "down_blocks.1.downsamplers.0",
101
+ "down_blocks.4": "down_blocks.2",
102
+ "down_blocks.5": "down_blocks.2.downsamplers.0",
103
+ "down_blocks.6": "down_blocks.3",
104
+ "down_blocks.7": "down_blocks.3.downsamplers.0",
105
+ "down_blocks.8": "mid_block",
106
+ # common
107
+ "last_time_embedder": "time_embedder",
108
+ "last_scale_shift_table": "scale_shift_table",
109
+ }
110
+
111
+ VAE_SPECIAL_KEYS_REMAP = {
112
+ "per_channel_statistics.channel": remove_keys_,
113
+ "per_channel_statistics.mean-of-means": remove_keys_,
114
+ "per_channel_statistics.mean-of-stds": remove_keys_,
115
+ "model.diffusion_model": remove_keys_,
116
+ }
117
+
118
+
119
+ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
120
+ state_dict = saved_dict
121
+ if "model" in saved_dict.keys():
122
+ state_dict = state_dict["model"]
123
+ if "module" in saved_dict.keys():
124
+ state_dict = state_dict["module"]
125
+ if "state_dict" in saved_dict.keys():
126
+ state_dict = state_dict["state_dict"]
127
+ return state_dict
128
+
129
+
130
+ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
131
+ state_dict[new_key] = state_dict.pop(old_key)
132
+
133
+
134
+ def convert_transformer(ckpt_path: str, config, dtype: torch.dtype):
135
+ PREFIX_KEY = "model.diffusion_model."
136
+
137
+ original_state_dict = get_state_dict(load_file(ckpt_path))
138
+ with init_empty_weights():
139
+ transformer = LTXVideoTransformer3DModel(**config)
140
+
141
+ for key in list(original_state_dict.keys()):
142
+ new_key = key[:]
143
+ if new_key.startswith(PREFIX_KEY):
144
+ new_key = key[len(PREFIX_KEY) :]
145
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
146
+ new_key = new_key.replace(replace_key, rename_key)
147
+ update_state_dict_inplace(original_state_dict, key, new_key)
148
+
149
+ for key in list(original_state_dict.keys()):
150
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
151
+ if special_key not in key:
152
+ continue
153
+ handler_fn_inplace(key, original_state_dict)
154
+
155
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
156
+ return transformer
157
+
158
+
159
+ def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
160
+ PREFIX_KEY = "vae."
161
+
162
+ original_state_dict = get_state_dict(load_file(ckpt_path))
163
+ with init_empty_weights():
164
+ vae = AutoencoderKLLTXVideo(**config)
165
+
166
+ for key in list(original_state_dict.keys()):
167
+ new_key = key[:]
168
+ if new_key.startswith(PREFIX_KEY):
169
+ new_key = key[len(PREFIX_KEY) :]
170
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
171
+ new_key = new_key.replace(replace_key, rename_key)
172
+ update_state_dict_inplace(original_state_dict, key, new_key)
173
+
174
+ for key in list(original_state_dict.keys()):
175
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
176
+ if special_key not in key:
177
+ continue
178
+ handler_fn_inplace(key, original_state_dict)
179
+
180
+ vae.load_state_dict(original_state_dict, strict=True, assign=True)
181
+ return vae
182
+
183
+
184
+ def convert_spatial_latent_upsampler(ckpt_path: str, config, dtype: torch.dtype):
185
+ original_state_dict = get_state_dict(load_file(ckpt_path))
186
+
187
+ with init_empty_weights():
188
+ latent_upsampler = LTXLatentUpsamplerModel(**config)
189
+
190
+ latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
191
+ latent_upsampler.to(dtype)
192
+ return latent_upsampler
193
+
194
+
195
+ def get_transformer_config(version: str) -> Dict[str, Any]:
196
+ if version == "0.9.7":
197
+ config = {
198
+ "in_channels": 128,
199
+ "out_channels": 128,
200
+ "patch_size": 1,
201
+ "patch_size_t": 1,
202
+ "num_attention_heads": 32,
203
+ "attention_head_dim": 128,
204
+ "cross_attention_dim": 4096,
205
+ "num_layers": 48,
206
+ "activation_fn": "gelu-approximate",
207
+ "qk_norm": "rms_norm_across_heads",
208
+ "norm_elementwise_affine": False,
209
+ "norm_eps": 1e-6,
210
+ "caption_channels": 4096,
211
+ "attention_bias": True,
212
+ "attention_out_bias": True,
213
+ }
214
+ else:
215
+ config = {
216
+ "in_channels": 128,
217
+ "out_channels": 128,
218
+ "patch_size": 1,
219
+ "patch_size_t": 1,
220
+ "num_attention_heads": 32,
221
+ "attention_head_dim": 64,
222
+ "cross_attention_dim": 2048,
223
+ "num_layers": 28,
224
+ "activation_fn": "gelu-approximate",
225
+ "qk_norm": "rms_norm_across_heads",
226
+ "norm_elementwise_affine": False,
227
+ "norm_eps": 1e-6,
228
+ "caption_channels": 4096,
229
+ "attention_bias": True,
230
+ "attention_out_bias": True,
231
+ }
232
+ return config
233
+
234
+
235
+ def get_vae_config(version: str) -> Dict[str, Any]:
236
+ if version in ["0.9.0"]:
237
+ config = {
238
+ "in_channels": 3,
239
+ "out_channels": 3,
240
+ "latent_channels": 128,
241
+ "block_out_channels": (128, 256, 512, 512),
242
+ "down_block_types": (
243
+ "LTXVideoDownBlock3D",
244
+ "LTXVideoDownBlock3D",
245
+ "LTXVideoDownBlock3D",
246
+ "LTXVideoDownBlock3D",
247
+ ),
248
+ "decoder_block_out_channels": (128, 256, 512, 512),
249
+ "layers_per_block": (4, 3, 3, 3, 4),
250
+ "decoder_layers_per_block": (4, 3, 3, 3, 4),
251
+ "spatio_temporal_scaling": (True, True, True, False),
252
+ "decoder_spatio_temporal_scaling": (True, True, True, False),
253
+ "decoder_inject_noise": (False, False, False, False, False),
254
+ "downsample_type": ("conv", "conv", "conv", "conv"),
255
+ "upsample_residual": (False, False, False, False),
256
+ "upsample_factor": (1, 1, 1, 1),
257
+ "patch_size": 4,
258
+ "patch_size_t": 1,
259
+ "resnet_norm_eps": 1e-6,
260
+ "scaling_factor": 1.0,
261
+ "encoder_causal": True,
262
+ "decoder_causal": False,
263
+ "timestep_conditioning": False,
264
+ }
265
+ elif version in ["0.9.1"]:
266
+ config = {
267
+ "in_channels": 3,
268
+ "out_channels": 3,
269
+ "latent_channels": 128,
270
+ "block_out_channels": (128, 256, 512, 512),
271
+ "down_block_types": (
272
+ "LTXVideoDownBlock3D",
273
+ "LTXVideoDownBlock3D",
274
+ "LTXVideoDownBlock3D",
275
+ "LTXVideoDownBlock3D",
276
+ ),
277
+ "decoder_block_out_channels": (256, 512, 1024),
278
+ "layers_per_block": (4, 3, 3, 3, 4),
279
+ "decoder_layers_per_block": (5, 6, 7, 8),
280
+ "spatio_temporal_scaling": (True, True, True, False),
281
+ "decoder_spatio_temporal_scaling": (True, True, True),
282
+ "decoder_inject_noise": (True, True, True, False),
283
+ "downsample_type": ("conv", "conv", "conv", "conv"),
284
+ "upsample_residual": (True, True, True),
285
+ "upsample_factor": (2, 2, 2),
286
+ "timestep_conditioning": True,
287
+ "patch_size": 4,
288
+ "patch_size_t": 1,
289
+ "resnet_norm_eps": 1e-6,
290
+ "scaling_factor": 1.0,
291
+ "encoder_causal": True,
292
+ "decoder_causal": False,
293
+ }
294
+ VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
295
+ elif version in ["0.9.5"]:
296
+ config = {
297
+ "in_channels": 3,
298
+ "out_channels": 3,
299
+ "latent_channels": 128,
300
+ "block_out_channels": (128, 256, 512, 1024, 2048),
301
+ "down_block_types": (
302
+ "LTXVideo095DownBlock3D",
303
+ "LTXVideo095DownBlock3D",
304
+ "LTXVideo095DownBlock3D",
305
+ "LTXVideo095DownBlock3D",
306
+ ),
307
+ "decoder_block_out_channels": (256, 512, 1024),
308
+ "layers_per_block": (4, 6, 6, 2, 2),
309
+ "decoder_layers_per_block": (5, 5, 5, 5),
310
+ "spatio_temporal_scaling": (True, True, True, True),
311
+ "decoder_spatio_temporal_scaling": (True, True, True),
312
+ "decoder_inject_noise": (False, False, False, False),
313
+ "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
314
+ "upsample_residual": (True, True, True),
315
+ "upsample_factor": (2, 2, 2),
316
+ "timestep_conditioning": True,
317
+ "patch_size": 4,
318
+ "patch_size_t": 1,
319
+ "resnet_norm_eps": 1e-6,
320
+ "scaling_factor": 1.0,
321
+ "encoder_causal": True,
322
+ "decoder_causal": False,
323
+ "spatial_compression_ratio": 32,
324
+ "temporal_compression_ratio": 8,
325
+ }
326
+ VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
327
+ elif version in ["0.9.7"]:
328
+ config = {
329
+ "in_channels": 3,
330
+ "out_channels": 3,
331
+ "latent_channels": 128,
332
+ "block_out_channels": (128, 256, 512, 1024, 2048),
333
+ "down_block_types": (
334
+ "LTXVideo095DownBlock3D",
335
+ "LTXVideo095DownBlock3D",
336
+ "LTXVideo095DownBlock3D",
337
+ "LTXVideo095DownBlock3D",
338
+ ),
339
+ "decoder_block_out_channels": (256, 512, 1024),
340
+ "layers_per_block": (4, 6, 6, 2, 2),
341
+ "decoder_layers_per_block": (5, 5, 5, 5),
342
+ "spatio_temporal_scaling": (True, True, True, True),
343
+ "decoder_spatio_temporal_scaling": (True, True, True),
344
+ "decoder_inject_noise": (False, False, False, False),
345
+ "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
346
+ "upsample_residual": (True, True, True),
347
+ "upsample_factor": (2, 2, 2),
348
+ "timestep_conditioning": True,
349
+ "patch_size": 4,
350
+ "patch_size_t": 1,
351
+ "resnet_norm_eps": 1e-6,
352
+ "scaling_factor": 1.0,
353
+ "encoder_causal": True,
354
+ "decoder_causal": False,
355
+ "spatial_compression_ratio": 32,
356
+ "temporal_compression_ratio": 8,
357
+ }
358
+ VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
359
+ return config
360
+
361
+
362
+ def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]:
363
+ if version == "0.9.7":
364
+ config = {
365
+ "in_channels": 128,
366
+ "mid_channels": 512,
367
+ "num_blocks_per_stage": 4,
368
+ "dims": 3,
369
+ "spatial_upsample": True,
370
+ "temporal_upsample": False,
371
+ }
372
+ else:
373
+ raise ValueError(f"Unsupported version: {version}")
374
+ return config
375
+
376
+
377
+ def get_args():
378
+ parser = argparse.ArgumentParser()
379
+ parser.add_argument(
380
+ "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
381
+ )
382
+ parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
383
+ parser.add_argument(
384
+ "--spatial_latent_upsampler_path",
385
+ type=str,
386
+ default=None,
387
+ help="Path to original spatial latent upsampler checkpoint",
388
+ )
389
+ parser.add_argument(
390
+ "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
391
+ )
392
+ parser.add_argument(
393
+ "--typecast_text_encoder",
394
+ action="store_true",
395
+ default=False,
396
+ help="Whether or not to apply fp16/bf16 precision to text_encoder",
397
+ )
398
+ parser.add_argument("--save_pipeline", action="store_true")
399
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
400
+ parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
401
+ parser.add_argument(
402
+ "--version",
403
+ type=str,
404
+ default="0.9.0",
405
+ choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"],
406
+ help="Version of the LTX model",
407
+ )
408
+ return parser.parse_args()
409
+
410
+
411
+ DTYPE_MAPPING = {
412
+ "fp32": torch.float32,
413
+ "fp16": torch.float16,
414
+ "bf16": torch.bfloat16,
415
+ }
416
+
417
+ VARIANT_MAPPING = {
418
+ "fp32": None,
419
+ "fp16": "fp16",
420
+ "bf16": "bf16",
421
+ }
422
+
423
+
424
+ if __name__ == "__main__":
425
+ args = get_args()
426
+
427
+ transformer = None
428
+ dtype = DTYPE_MAPPING[args.dtype]
429
+ variant = VARIANT_MAPPING[args.dtype]
430
+ output_path = Path(args.output_path)
431
+
432
+ if args.transformer_ckpt_path is not None:
433
+ config = get_transformer_config(args.version)
434
+ transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, config, dtype)
435
+ if not args.save_pipeline:
436
+ transformer.save_pretrained(
437
+ output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
438
+ )
439
+
440
+ if args.vae_ckpt_path is not None:
441
+ config = get_vae_config(args.version)
442
+ vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
443
+ if not args.save_pipeline:
444
+ vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
445
+
446
+ if args.spatial_latent_upsampler_path is not None:
447
+ config = get_spatial_latent_upsampler_config(args.version)
448
+ latent_upsampler: LTXLatentUpsamplerModel = convert_spatial_latent_upsampler(
449
+ args.spatial_latent_upsampler_path, config, dtype
450
+ )
451
+ if not args.save_pipeline:
452
+ latent_upsampler.save_pretrained(
453
+ output_path / "latent_upsampler", safe_serialization=True, max_shard_size="5GB", variant=variant
454
+ )
455
+
456
+ if args.save_pipeline:
457
+ text_encoder_id = "google/t5-v1_1-xxl"
458
+ tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
459
+ text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
460
+
461
+ if args.typecast_text_encoder:
462
+ text_encoder = text_encoder.to(dtype=dtype)
463
+
464
+ # Apparently, the conversion does not work anymore without this :shrug:
465
+ for param in text_encoder.parameters():
466
+ param.data = param.data.contiguous()
467
+
468
+ if args.version in ["0.9.5", "0.9.7"]:
469
+ scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
470
+ else:
471
+ scheduler = FlowMatchEulerDiscreteScheduler(
472
+ use_dynamic_shifting=True,
473
+ base_shift=0.95,
474
+ max_shift=2.05,
475
+ base_image_seq_len=1024,
476
+ max_image_seq_len=4096,
477
+ shift_terminal=0.1,
478
+ )
479
+
480
+ if args.version in ["0.9.0", "0.9.1", "0.9.5"]:
481
+ pipe = LTXPipeline(
482
+ scheduler=scheduler,
483
+ vae=vae,
484
+ text_encoder=text_encoder,
485
+ tokenizer=tokenizer,
486
+ transformer=transformer,
487
+ )
488
+ pipe.save_pretrained(
489
+ output_path.as_posix(), safe_serialization=True, variant=variant, max_shard_size="5GB"
490
+ )
491
+ elif args.version in ["0.9.7"]:
492
+ pipe = LTXConditionPipeline(
493
+ scheduler=scheduler,
494
+ vae=vae,
495
+ text_encoder=text_encoder,
496
+ tokenizer=tokenizer,
497
+ transformer=transformer,
498
+ )
499
+ pipe_upsample = LTXLatentUpsamplePipeline(
500
+ vae=vae,
501
+ latent_upsampler=latent_upsampler,
502
+ )
503
+ pipe.save_pretrained(
504
+ (output_path / "ltx_pipeline").as_posix(),
505
+ safe_serialization=True,
506
+ variant=variant,
507
+ max_shard_size="5GB",
508
+ )
509
+ pipe_upsample.save_pretrained(
510
+ (output_path / "ltx_upsample_pipeline").as_posix(),
511
+ safe_serialization=True,
512
+ variant=variant,
513
+ max_shard_size="5GB",
514
+ )
515
+ else:
516
+ raise ValueError(f"Unsupported version: {args.version}")
diffusers/scripts/convert_lumina_to_diffusers.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from safetensors.torch import load_file
6
+ from transformers import AutoModel, AutoTokenizer
7
+
8
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline
9
+
10
+
11
+ def main(args):
12
+ # checkpoint from https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT or https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I
13
+ all_sd = load_file(args.origin_ckpt_path, device="cpu")
14
+ converted_state_dict = {}
15
+ # pad token
16
+ converted_state_dict["pad_token"] = all_sd["pad_token"]
17
+
18
+ # patch embed
19
+ converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"]
20
+ converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"]
21
+
22
+ # time and caption embed
23
+ converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"]
24
+ converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"]
25
+ converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"]
26
+ converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"]
27
+ converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"]
28
+ converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"]
29
+ converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"]
30
+ converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"]
31
+
32
+ for i in range(24):
33
+ # adaln
34
+ converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"]
35
+ converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"]
36
+ converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"]
37
+
38
+ # qkv
39
+ converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
40
+ converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"]
41
+ converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"]
42
+
43
+ # cap
44
+ converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
45
+ converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"]
46
+ converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"]
47
+
48
+ # output
49
+ converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"]
50
+
51
+ # attention
52
+ # qk norm
53
+ converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
54
+ converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
55
+
56
+ converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"]
57
+ converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"]
58
+
59
+ converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
60
+ converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
61
+
62
+ converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"]
63
+ converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"]
64
+
65
+ # attention norm
66
+ converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"]
67
+ converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"]
68
+ converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"]
69
+
70
+ # feed forward
71
+ converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"]
72
+ converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"]
73
+ converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"]
74
+
75
+ # feed forward norm
76
+ converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"]
77
+ converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"]
78
+
79
+ # final layer
80
+ converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"]
81
+ converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"]
82
+
83
+ converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"]
84
+ converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"]
85
+
86
+ # Lumina-Next-SFT 2B
87
+ transformer = LuminaNextDiT2DModel(
88
+ sample_size=128,
89
+ patch_size=2,
90
+ in_channels=4,
91
+ hidden_size=2304,
92
+ num_layers=24,
93
+ num_attention_heads=32,
94
+ num_kv_heads=8,
95
+ multiple_of=256,
96
+ ffn_dim_multiplier=None,
97
+ norm_eps=1e-5,
98
+ learn_sigma=True,
99
+ qk_norm=True,
100
+ cross_attention_dim=2048,
101
+ scaling_factor=1.0,
102
+ )
103
+ transformer.load_state_dict(converted_state_dict, strict=True)
104
+
105
+ num_model_params = sum(p.numel() for p in transformer.parameters())
106
+ print(f"Total number of transformer parameters: {num_model_params}")
107
+
108
+ if args.only_transformer:
109
+ transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
110
+ else:
111
+ scheduler = FlowMatchEulerDiscreteScheduler()
112
+
113
+ vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32)
114
+
115
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
116
+ text_encoder = AutoModel.from_pretrained("google/gemma-2b")
117
+
118
+ pipeline = LuminaPipeline(
119
+ tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
120
+ )
121
+ pipeline.save_pretrained(args.dump_path)
122
+
123
+
124
+ if __name__ == "__main__":
125
+ parser = argparse.ArgumentParser()
126
+
127
+ parser.add_argument(
128
+ "--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
129
+ )
130
+ parser.add_argument(
131
+ "--image_size",
132
+ default=1024,
133
+ type=int,
134
+ choices=[256, 512, 1024],
135
+ required=False,
136
+ help="Image size of pretrained model, either 512 or 1024.",
137
+ )
138
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
139
+ parser.add_argument("--only_transformer", default=True, type=bool, required=True)
140
+
141
+ args = parser.parse_args()
142
+ main(args)
diffusers/scripts/convert_mochi_to_diffusers.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from contextlib import nullcontext
3
+
4
+ import torch
5
+ from accelerate import init_empty_weights
6
+ from safetensors.torch import load_file
7
+ from transformers import T5EncoderModel, T5Tokenizer
8
+
9
+ from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
10
+ from diffusers.utils.import_utils import is_accelerate_available
11
+
12
+
13
+ CTX = init_empty_weights if is_accelerate_available() else nullcontext
14
+
15
+ TOKENIZER_MAX_LENGTH = 256
16
+
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
19
+ parser.add_argument("--vae_encoder_checkpoint_path", default=None, type=str)
20
+ parser.add_argument("--vae_decoder_checkpoint_path", default=None, type=str)
21
+ parser.add_argument("--output_path", required=True, type=str)
22
+ parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
23
+ parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
24
+ parser.add_argument("--dtype", type=str, default=None)
25
+
26
+ args = parser.parse_args()
27
+
28
+
29
+ # This is specific to `AdaLayerNormContinuous`:
30
+ # Diffusers implementation split the linear projection into the scale, shift while Mochi split it into shift, scale
31
+ def swap_scale_shift(weight, dim):
32
+ shift, scale = weight.chunk(2, dim=0)
33
+ new_weight = torch.cat([scale, shift], dim=0)
34
+ return new_weight
35
+
36
+
37
+ def swap_proj_gate(weight):
38
+ proj, gate = weight.chunk(2, dim=0)
39
+ new_weight = torch.cat([gate, proj], dim=0)
40
+ return new_weight
41
+
42
+
43
+ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
44
+ original_state_dict = load_file(ckpt_path, device="cpu")
45
+ new_state_dict = {}
46
+
47
+ # Convert patch_embed
48
+ new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
49
+ new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")
50
+
51
+ # Convert time_embed
52
+ new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight")
53
+ new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias")
54
+ new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight")
55
+ new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias")
56
+ new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight")
57
+ new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias")
58
+ new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight")
59
+ new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias")
60
+ new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight")
61
+ new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias")
62
+ new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight")
63
+ new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias")
64
+
65
+ # Convert transformer blocks
66
+ num_layers = 48
67
+ for i in range(num_layers):
68
+ block_prefix = f"transformer_blocks.{i}."
69
+ old_prefix = f"blocks.{i}."
70
+
71
+ # norm1
72
+ new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight")
73
+ new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias")
74
+ if i < num_layers - 1:
75
+ new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop(
76
+ old_prefix + "mod_y.weight"
77
+ )
78
+ new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop(
79
+ old_prefix + "mod_y.bias"
80
+ )
81
+ else:
82
+ new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
83
+ old_prefix + "mod_y.weight"
84
+ )
85
+ new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
86
+ old_prefix + "mod_y.bias"
87
+ )
88
+
89
+ # Visual attention
90
+ qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
91
+ q, k, v = qkv_weight.chunk(3, dim=0)
92
+
93
+ new_state_dict[block_prefix + "attn1.to_q.weight"] = q
94
+ new_state_dict[block_prefix + "attn1.to_k.weight"] = k
95
+ new_state_dict[block_prefix + "attn1.to_v.weight"] = v
96
+ new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop(
97
+ old_prefix + "attn.q_norm_x.weight"
98
+ )
99
+ new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop(
100
+ old_prefix + "attn.k_norm_x.weight"
101
+ )
102
+ new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
103
+ old_prefix + "attn.proj_x.weight"
104
+ )
105
+ new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias")
106
+
107
+ # Context attention
108
+ qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight")
109
+ q, k, v = qkv_weight.chunk(3, dim=0)
110
+
111
+ new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
112
+ new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
113
+ new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
114
+ new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop(
115
+ old_prefix + "attn.q_norm_y.weight"
116
+ )
117
+ new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop(
118
+ old_prefix + "attn.k_norm_y.weight"
119
+ )
120
+ if i < num_layers - 1:
121
+ new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop(
122
+ old_prefix + "attn.proj_y.weight"
123
+ )
124
+ new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop(
125
+ old_prefix + "attn.proj_y.bias"
126
+ )
127
+
128
+ # MLP
129
+ new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
130
+ original_state_dict.pop(old_prefix + "mlp_x.w1.weight")
131
+ )
132
+ new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight")
133
+ if i < num_layers - 1:
134
+ new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
135
+ original_state_dict.pop(old_prefix + "mlp_y.w1.weight")
136
+ )
137
+ new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop(
138
+ old_prefix + "mlp_y.w2.weight"
139
+ )
140
+
141
+ # Output layers
142
+ new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
143
+ original_state_dict.pop("final_layer.mod.weight"), dim=0
144
+ )
145
+ new_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.bias"), dim=0)
146
+ new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
147
+ new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
148
+
149
+ new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies")
150
+
151
+ print("Remaining Keys:", original_state_dict.keys())
152
+
153
+ return new_state_dict
154
+
155
+
156
+ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_path):
157
+ encoder_state_dict = load_file(encoder_ckpt_path, device="cpu")
158
+ decoder_state_dict = load_file(decoder_ckpt_path, device="cpu")
159
+ new_state_dict = {}
160
+
161
+ # ==== Decoder =====
162
+ prefix = "decoder."
163
+
164
+ # Convert conv_in
165
+ new_state_dict[f"{prefix}conv_in.weight"] = decoder_state_dict.pop("blocks.0.0.weight")
166
+ new_state_dict[f"{prefix}conv_in.bias"] = decoder_state_dict.pop("blocks.0.0.bias")
167
+
168
+ # Convert block_in (MochiMidBlock3D)
169
+ for i in range(3): # layers_per_block[-1] = 3
170
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
171
+ f"blocks.0.{i + 1}.stack.0.weight"
172
+ )
173
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
174
+ f"blocks.0.{i + 1}.stack.0.bias"
175
+ )
176
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
177
+ f"blocks.0.{i + 1}.stack.2.weight"
178
+ )
179
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
180
+ f"blocks.0.{i + 1}.stack.2.bias"
181
+ )
182
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
183
+ f"blocks.0.{i + 1}.stack.3.weight"
184
+ )
185
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
186
+ f"blocks.0.{i + 1}.stack.3.bias"
187
+ )
188
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
189
+ f"blocks.0.{i + 1}.stack.5.weight"
190
+ )
191
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
192
+ f"blocks.0.{i + 1}.stack.5.bias"
193
+ )
194
+
195
+ # Convert up_blocks (MochiUpBlock3D)
196
+ down_block_layers = [6, 4, 3] # layers_per_block[-2], layers_per_block[-3], layers_per_block[-4]
197
+ for block in range(3):
198
+ for i in range(down_block_layers[block]):
199
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
200
+ f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
201
+ )
202
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
203
+ f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
204
+ )
205
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
206
+ f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
207
+ )
208
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
209
+ f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
210
+ )
211
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
212
+ f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
213
+ )
214
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
215
+ f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
216
+ )
217
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
218
+ f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
219
+ )
220
+ new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
221
+ f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
222
+ )
223
+ new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
224
+ f"blocks.{block + 1}.proj.weight"
225
+ )
226
+ new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
227
+ f"blocks.{block + 1}.proj.bias"
228
+ )
229
+
230
+ # Convert block_out (MochiMidBlock3D)
231
+ for i in range(3): # layers_per_block[0] = 3
232
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
233
+ f"blocks.4.{i}.stack.0.weight"
234
+ )
235
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
236
+ f"blocks.4.{i}.stack.0.bias"
237
+ )
238
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
239
+ f"blocks.4.{i}.stack.2.weight"
240
+ )
241
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
242
+ f"blocks.4.{i}.stack.2.bias"
243
+ )
244
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
245
+ f"blocks.4.{i}.stack.3.weight"
246
+ )
247
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
248
+ f"blocks.4.{i}.stack.3.bias"
249
+ )
250
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
251
+ f"blocks.4.{i}.stack.5.weight"
252
+ )
253
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
254
+ f"blocks.4.{i}.stack.5.bias"
255
+ )
256
+
257
+ # Convert proj_out (Conv1x1 ~= nn.Linear)
258
+ new_state_dict[f"{prefix}proj_out.weight"] = decoder_state_dict.pop("output_proj.weight")
259
+ new_state_dict[f"{prefix}proj_out.bias"] = decoder_state_dict.pop("output_proj.bias")
260
+
261
+ print("Remaining Decoder Keys:", decoder_state_dict.keys())
262
+
263
+ # ==== Encoder =====
264
+ prefix = "encoder."
265
+
266
+ new_state_dict[f"{prefix}proj_in.weight"] = encoder_state_dict.pop("layers.0.weight")
267
+ new_state_dict[f"{prefix}proj_in.bias"] = encoder_state_dict.pop("layers.0.bias")
268
+
269
+ # Convert block_in (MochiMidBlock3D)
270
+ for i in range(3): # layers_per_block[0] = 3
271
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
272
+ f"layers.{i + 1}.stack.0.weight"
273
+ )
274
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
275
+ f"layers.{i + 1}.stack.0.bias"
276
+ )
277
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
278
+ f"layers.{i + 1}.stack.2.weight"
279
+ )
280
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
281
+ f"layers.{i + 1}.stack.2.bias"
282
+ )
283
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
284
+ f"layers.{i + 1}.stack.3.weight"
285
+ )
286
+ new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
287
+ f"layers.{i + 1}.stack.3.bias"
288
+ )
289
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
290
+ f"layers.{i + 1}.stack.5.weight"
291
+ )
292
+ new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
293
+ f"layers.{i + 1}.stack.5.bias"
294
+ )
295
+
296
+ # Convert down_blocks (MochiDownBlock3D)
297
+ down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
298
+ for block in range(3):
299
+ new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
300
+ f"layers.{block + 4}.layers.0.weight"
301
+ )
302
+ new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
303
+ f"layers.{block + 4}.layers.0.bias"
304
+ )
305
+
306
+ for i in range(down_block_layers[block]):
307
+ # Convert resnets
308
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
309
+ encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
310
+ )
311
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
312
+ f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
313
+ )
314
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
315
+ f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
316
+ )
317
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
318
+ f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
319
+ )
320
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
321
+ encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
322
+ )
323
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
324
+ f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
325
+ )
326
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
327
+ f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
328
+ )
329
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
330
+ f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
331
+ )
332
+
333
+ # Convert attentions
334
+ qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
335
+ q, k, v = qkv_weight.chunk(3, dim=0)
336
+
337
+ new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
338
+ new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
339
+ new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
340
+ new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
341
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
342
+ )
343
+ new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
344
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
345
+ )
346
+ new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
347
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
348
+ )
349
+ new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
350
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
351
+ )
352
+
353
+ # Convert block_out (MochiMidBlock3D)
354
+ for i in range(3): # layers_per_block[-1] = 3
355
+ # Convert resnets
356
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
357
+ f"layers.{i + 7}.stack.0.weight"
358
+ )
359
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
360
+ f"layers.{i + 7}.stack.0.bias"
361
+ )
362
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
363
+ f"layers.{i + 7}.stack.2.weight"
364
+ )
365
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
366
+ f"layers.{i + 7}.stack.2.bias"
367
+ )
368
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
369
+ f"layers.{i + 7}.stack.3.weight"
370
+ )
371
+ new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
372
+ f"layers.{i + 7}.stack.3.bias"
373
+ )
374
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
375
+ f"layers.{i + 7}.stack.5.weight"
376
+ )
377
+ new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
378
+ f"layers.{i + 7}.stack.5.bias"
379
+ )
380
+
381
+ # Convert attentions
382
+ qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
383
+ q, k, v = qkv_weight.chunk(3, dim=0)
384
+
385
+ new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
386
+ new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
387
+ new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
388
+ new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
389
+ f"layers.{i + 7}.attn_block.attn.out.weight"
390
+ )
391
+ new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
392
+ f"layers.{i + 7}.attn_block.attn.out.bias"
393
+ )
394
+ new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
395
+ f"layers.{i + 7}.attn_block.norm.weight"
396
+ )
397
+ new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
398
+ f"layers.{i + 7}.attn_block.norm.bias"
399
+ )
400
+
401
+ # Convert output layers
402
+ new_state_dict[f"{prefix}norm_out.norm_layer.weight"] = encoder_state_dict.pop("output_norm.weight")
403
+ new_state_dict[f"{prefix}norm_out.norm_layer.bias"] = encoder_state_dict.pop("output_norm.bias")
404
+ new_state_dict[f"{prefix}proj_out.weight"] = encoder_state_dict.pop("output_proj.weight")
405
+
406
+ print("Remaining Encoder Keys:", encoder_state_dict.keys())
407
+
408
+ return new_state_dict
409
+
410
+
411
+ def main(args):
412
+ if args.dtype is None:
413
+ dtype = None
414
+ if args.dtype == "fp16":
415
+ dtype = torch.float16
416
+ elif args.dtype == "bf16":
417
+ dtype = torch.bfloat16
418
+ elif args.dtype == "fp32":
419
+ dtype = torch.float32
420
+ else:
421
+ raise ValueError(f"Unsupported dtype: {args.dtype}")
422
+
423
+ transformer = None
424
+ vae = None
425
+
426
+ if args.transformer_checkpoint_path is not None:
427
+ converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
428
+ args.transformer_checkpoint_path
429
+ )
430
+ transformer = MochiTransformer3DModel()
431
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
432
+ if dtype is not None:
433
+ transformer = transformer.to(dtype=dtype)
434
+
435
+ if args.vae_encoder_checkpoint_path is not None and args.vae_decoder_checkpoint_path is not None:
436
+ vae = AutoencoderKLMochi(latent_channels=12, out_channels=3)
437
+ converted_vae_state_dict = convert_mochi_vae_state_dict_to_diffusers(
438
+ args.vae_encoder_checkpoint_path, args.vae_decoder_checkpoint_path
439
+ )
440
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
441
+ if dtype is not None:
442
+ vae = vae.to(dtype=dtype)
443
+
444
+ text_encoder_id = "google/t5-v1_1-xxl"
445
+ tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
446
+ text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
447
+
448
+ # Apparently, the conversion does not work anymore without this :shrug:
449
+ for param in text_encoder.parameters():
450
+ param.data = param.data.contiguous()
451
+
452
+ pipe = MochiPipeline(
453
+ scheduler=FlowMatchEulerDiscreteScheduler(invert_sigmas=True),
454
+ vae=vae,
455
+ text_encoder=text_encoder,
456
+ tokenizer=tokenizer,
457
+ transformer=transformer,
458
+ )
459
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
460
+
461
+
462
+ if __name__ == "__main__":
463
+ main(args)
diffusers/scripts/convert_models_diffuser_to_diffusers.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import torch
5
+
6
+ from diffusers import UNet1DModel
7
+
8
+
9
+ os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
10
+ os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)
11
+
12
+ os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)
13
+
14
+
15
+ def unet(hor):
16
+ if hor == 128:
17
+ down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
18
+ block_out_channels = (32, 128, 256)
19
+ up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
20
+
21
+ elif hor == 32:
22
+ down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
23
+ block_out_channels = (32, 64, 128, 256)
24
+ up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
25
+ model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
26
+ state_dict = model.state_dict()
27
+ config = {
28
+ "down_block_types": down_block_types,
29
+ "block_out_channels": block_out_channels,
30
+ "up_block_types": up_block_types,
31
+ "layers_per_block": 1,
32
+ "use_timestep_embedding": True,
33
+ "out_block_type": "OutConv1DBlock",
34
+ "norm_num_groups": 8,
35
+ "downsample_each_block": False,
36
+ "in_channels": 14,
37
+ "out_channels": 14,
38
+ "extra_in_channels": 0,
39
+ "time_embedding_type": "positional",
40
+ "flip_sin_to_cos": False,
41
+ "freq_shift": 1,
42
+ "sample_size": 65536,
43
+ "mid_block_type": "MidResTemporalBlock1D",
44
+ "act_fn": "mish",
45
+ }
46
+ hf_value_function = UNet1DModel(**config)
47
+ print(f"length of state dict: {len(state_dict.keys())}")
48
+ print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
49
+ mapping = dict(zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
50
+ for k, v in mapping.items():
51
+ state_dict[v] = state_dict.pop(k)
52
+ hf_value_function.load_state_dict(state_dict)
53
+
54
+ torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
55
+ with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
56
+ json.dump(config, f)
57
+
58
+
59
+ def value_function():
60
+ config = {
61
+ "in_channels": 14,
62
+ "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
63
+ "up_block_types": (),
64
+ "out_block_type": "ValueFunction",
65
+ "mid_block_type": "ValueFunctionMidBlock1D",
66
+ "block_out_channels": (32, 64, 128, 256),
67
+ "layers_per_block": 1,
68
+ "downsample_each_block": True,
69
+ "sample_size": 65536,
70
+ "out_channels": 14,
71
+ "extra_in_channels": 0,
72
+ "time_embedding_type": "positional",
73
+ "use_timestep_embedding": True,
74
+ "flip_sin_to_cos": False,
75
+ "freq_shift": 1,
76
+ "norm_num_groups": 8,
77
+ "act_fn": "mish",
78
+ }
79
+
80
+ model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
81
+ state_dict = model
82
+ hf_value_function = UNet1DModel(**config)
83
+ print(f"length of state dict: {len(state_dict.keys())}")
84
+ print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
85
+
86
+ mapping = dict(zip(state_dict.keys(), hf_value_function.state_dict().keys()))
87
+ for k, v in mapping.items():
88
+ state_dict[v] = state_dict.pop(k)
89
+
90
+ hf_value_function.load_state_dict(state_dict)
91
+
92
+ torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
93
+ with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
94
+ json.dump(config, f)
95
+
96
+
97
+ if __name__ == "__main__":
98
+ unet(32)
99
+ # unet(128)
100
+ value_function()
diffusers/scripts/convert_ms_text_to_video_to_diffusers.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Conversion script for the LDM checkpoints."""
16
+
17
+ import argparse
18
+
19
+ import torch
20
+
21
+ from diffusers import UNet3DConditionModel
22
+
23
+
24
+ def assign_to_checkpoint(
25
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
26
+ ):
27
+ """
28
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
29
+ attention layers, and takes into account additional replacements that may arise.
30
+
31
+ Assigns the weights to the new checkpoint.
32
+ """
33
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
34
+
35
+ # Splits the attention layers into three variables.
36
+ if attention_paths_to_split is not None:
37
+ for path, path_map in attention_paths_to_split.items():
38
+ old_tensor = old_checkpoint[path]
39
+ channels = old_tensor.shape[0] // 3
40
+
41
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
42
+
43
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
44
+
45
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
46
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
47
+
48
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
49
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
50
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
51
+
52
+ for path in paths:
53
+ new_path = path["new"]
54
+
55
+ # These have already been assigned
56
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
57
+ continue
58
+
59
+ if additional_replacements is not None:
60
+ for replacement in additional_replacements:
61
+ new_path = new_path.replace(replacement["old"], replacement["new"])
62
+
63
+ # proj_attn.weight has to be converted from conv 1D to linear
64
+ weight = old_checkpoint[path["old"]]
65
+ names = ["proj_attn.weight"]
66
+ names_2 = ["proj_out.weight", "proj_in.weight"]
67
+ if any(k in new_path for k in names):
68
+ checkpoint[new_path] = weight[:, :, 0]
69
+ elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path:
70
+ checkpoint[new_path] = weight[:, :, 0]
71
+ else:
72
+ checkpoint[new_path] = weight
73
+
74
+
75
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
76
+ """
77
+ Updates paths inside attentions to the new naming scheme (local renaming)
78
+ """
79
+ mapping = []
80
+ for old_item in old_list:
81
+ new_item = old_item
82
+
83
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
84
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
85
+
86
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
87
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
88
+
89
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
90
+
91
+ mapping.append({"old": old_item, "new": new_item})
92
+
93
+ return mapping
94
+
95
+
96
+ def shave_segments(path, n_shave_prefix_segments=1):
97
+ """
98
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
99
+ """
100
+ if n_shave_prefix_segments >= 0:
101
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
102
+ else:
103
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
104
+
105
+
106
+ def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0):
107
+ """
108
+ Updates paths inside resnets to the new naming scheme (local renaming)
109
+ """
110
+ mapping = []
111
+ for old_item in old_list:
112
+ mapping.append({"old": old_item, "new": old_item})
113
+
114
+ return mapping
115
+
116
+
117
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
118
+ """
119
+ Updates paths inside resnets to the new naming scheme (local renaming)
120
+ """
121
+ mapping = []
122
+ for old_item in old_list:
123
+ new_item = old_item.replace("in_layers.0", "norm1")
124
+ new_item = new_item.replace("in_layers.2", "conv1")
125
+
126
+ new_item = new_item.replace("out_layers.0", "norm2")
127
+ new_item = new_item.replace("out_layers.3", "conv2")
128
+
129
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
130
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
131
+
132
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
133
+
134
+ if "temopral_conv" not in old_item:
135
+ mapping.append({"old": old_item, "new": new_item})
136
+
137
+ return mapping
138
+
139
+
140
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
141
+ """
142
+ Takes a state dict and a config, and returns a converted checkpoint.
143
+ """
144
+
145
+ # extract state_dict for UNet
146
+ unet_state_dict = {}
147
+ keys = list(checkpoint.keys())
148
+
149
+ unet_key = "model.diffusion_model."
150
+
151
+ # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
152
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
153
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
154
+ print(
155
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
156
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
157
+ )
158
+ for key in keys:
159
+ if key.startswith("model.diffusion_model"):
160
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
161
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
162
+ else:
163
+ if sum(k.startswith("model_ema") for k in keys) > 100:
164
+ print(
165
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
166
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
167
+ )
168
+
169
+ for key in keys:
170
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
171
+
172
+ new_checkpoint = {}
173
+
174
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
175
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
176
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
177
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
178
+
179
+ if config["class_embed_type"] is None:
180
+ # No parameters to port
181
+ ...
182
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
183
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
184
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
185
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
186
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
187
+ else:
188
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
189
+
190
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
191
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
192
+
193
+ first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")]
194
+ paths = renew_attention_paths(first_temp_attention)
195
+ meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"}
196
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
197
+
198
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
199
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
200
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
201
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
202
+
203
+ # Retrieves the keys for the input blocks only
204
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
205
+ input_blocks = {
206
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
207
+ for layer_id in range(num_input_blocks)
208
+ }
209
+
210
+ # Retrieves the keys for the middle blocks only
211
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
212
+ middle_blocks = {
213
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
214
+ for layer_id in range(num_middle_blocks)
215
+ }
216
+
217
+ # Retrieves the keys for the output blocks only
218
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
219
+ output_blocks = {
220
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
221
+ for layer_id in range(num_output_blocks)
222
+ }
223
+
224
+ for i in range(1, num_input_blocks):
225
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
226
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
227
+
228
+ resnets = [
229
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
230
+ ]
231
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
232
+ temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key]
233
+
234
+ if f"input_blocks.{i}.op.weight" in unet_state_dict:
235
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
236
+ f"input_blocks.{i}.op.weight"
237
+ )
238
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
239
+ f"input_blocks.{i}.op.bias"
240
+ )
241
+
242
+ paths = renew_resnet_paths(resnets)
243
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
244
+ assign_to_checkpoint(
245
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
246
+ )
247
+
248
+ temporal_convs = [key for key in resnets if "temopral_conv" in key]
249
+ paths = renew_temp_conv_paths(temporal_convs)
250
+ meta_path = {
251
+ "old": f"input_blocks.{i}.0.temopral_conv",
252
+ "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}",
253
+ }
254
+ assign_to_checkpoint(
255
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
256
+ )
257
+
258
+ if len(attentions):
259
+ paths = renew_attention_paths(attentions)
260
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
261
+ assign_to_checkpoint(
262
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
263
+ )
264
+
265
+ if len(temp_attentions):
266
+ paths = renew_attention_paths(temp_attentions)
267
+ meta_path = {
268
+ "old": f"input_blocks.{i}.2",
269
+ "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
270
+ }
271
+ assign_to_checkpoint(
272
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
273
+ )
274
+
275
+ resnet_0 = middle_blocks[0]
276
+ temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key]
277
+ attentions = middle_blocks[1]
278
+ temp_attentions = middle_blocks[2]
279
+ resnet_1 = middle_blocks[3]
280
+ temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key]
281
+
282
+ resnet_0_paths = renew_resnet_paths(resnet_0)
283
+ meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"}
284
+ assign_to_checkpoint(
285
+ resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
286
+ )
287
+
288
+ temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0)
289
+ meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"}
290
+ assign_to_checkpoint(
291
+ temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
292
+ )
293
+
294
+ resnet_1_paths = renew_resnet_paths(resnet_1)
295
+ meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"}
296
+ assign_to_checkpoint(
297
+ resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
298
+ )
299
+
300
+ temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1)
301
+ meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"}
302
+ assign_to_checkpoint(
303
+ temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
304
+ )
305
+
306
+ attentions_paths = renew_attention_paths(attentions)
307
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
308
+ assign_to_checkpoint(
309
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
310
+ )
311
+
312
+ temp_attentions_paths = renew_attention_paths(temp_attentions)
313
+ meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"}
314
+ assign_to_checkpoint(
315
+ temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
316
+ )
317
+
318
+ for i in range(num_output_blocks):
319
+ block_id = i // (config["layers_per_block"] + 1)
320
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
321
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
322
+ output_block_list = {}
323
+
324
+ for layer in output_block_layers:
325
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
326
+ if layer_id in output_block_list:
327
+ output_block_list[layer_id].append(layer_name)
328
+ else:
329
+ output_block_list[layer_id] = [layer_name]
330
+
331
+ if len(output_block_list) > 1:
332
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
333
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
334
+ temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
335
+
336
+ resnet_0_paths = renew_resnet_paths(resnets)
337
+ paths = renew_resnet_paths(resnets)
338
+
339
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
340
+ assign_to_checkpoint(
341
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
342
+ )
343
+
344
+ temporal_convs = [key for key in resnets if "temopral_conv" in key]
345
+ paths = renew_temp_conv_paths(temporal_convs)
346
+ meta_path = {
347
+ "old": f"output_blocks.{i}.0.temopral_conv",
348
+ "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}",
349
+ }
350
+ assign_to_checkpoint(
351
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
352
+ )
353
+
354
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
355
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
356
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
357
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
358
+ f"output_blocks.{i}.{index}.conv.weight"
359
+ ]
360
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
361
+ f"output_blocks.{i}.{index}.conv.bias"
362
+ ]
363
+
364
+ # Clear attentions as they have been attributed above.
365
+ if len(attentions) == 2:
366
+ attentions = []
367
+
368
+ if len(attentions):
369
+ paths = renew_attention_paths(attentions)
370
+ meta_path = {
371
+ "old": f"output_blocks.{i}.1",
372
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
373
+ }
374
+ assign_to_checkpoint(
375
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
376
+ )
377
+
378
+ if len(temp_attentions):
379
+ paths = renew_attention_paths(temp_attentions)
380
+ meta_path = {
381
+ "old": f"output_blocks.{i}.2",
382
+ "new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
383
+ }
384
+ assign_to_checkpoint(
385
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
386
+ )
387
+ else:
388
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
389
+ for path in resnet_0_paths:
390
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
391
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
392
+ new_checkpoint[new_path] = unet_state_dict[old_path]
393
+
394
+ temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l]
395
+ for path in temopral_conv_paths:
396
+ pruned_path = path.split("temopral_conv.")[-1]
397
+ old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path])
398
+ new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path])
399
+ new_checkpoint[new_path] = unet_state_dict[old_path]
400
+
401
+ return new_checkpoint
402
+
403
+
404
+ if __name__ == "__main__":
405
+ parser = argparse.ArgumentParser()
406
+
407
+ parser.add_argument(
408
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
409
+ )
410
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
411
+ args = parser.parse_args()
412
+
413
+ unet_checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
414
+ unet = UNet3DConditionModel()
415
+
416
+ converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config)
417
+
418
+ diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys())
419
+ diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys())
420
+
421
+ assert len(diff_0) == len(diff_1) == 0, "Converted weights don't match"
422
+
423
+ # load state_dict
424
+ unet.load_state_dict(converted_ckpt)
425
+
426
+ unet.save_pretrained(args.dump_path)
427
+
428
+ # -- finish converting the unet --
diffusers/scripts/convert_music_spectrogram_to_diffusers.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import os
4
+
5
+ import jax as jnp
6
+ import numpy as onp
7
+ import torch
8
+ import torch.nn as nn
9
+ from music_spectrogram_diffusion import inference
10
+ from t5x import checkpoints
11
+
12
+ from diffusers import DDPMScheduler, OnnxRuntimeModel, SpectrogramDiffusionPipeline
13
+ from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder
14
+
15
+
16
+ MODEL = "base_with_context"
17
+
18
+
19
+ def load_notes_encoder(weights, model):
20
+ model.token_embedder.weight = nn.Parameter(torch.Tensor(weights["token_embedder"]["embedding"]))
21
+ model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
22
+ for lyr_num, lyr in enumerate(model.encoders):
23
+ ly_weight = weights[f"layers_{lyr_num}"]
24
+ lyr.layer[0].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_attention_layer_norm"]["scale"]))
25
+
26
+ attention_weights = ly_weight["attention"]
27
+ lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
28
+ lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
29
+ lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
30
+ lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
31
+
32
+ lyr.layer[1].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
33
+
34
+ lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
35
+ lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
36
+ lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
37
+
38
+ model.layer_norm.weight = nn.Parameter(torch.Tensor(weights["encoder_norm"]["scale"]))
39
+ return model
40
+
41
+
42
+ def load_continuous_encoder(weights, model):
43
+ model.input_proj.weight = nn.Parameter(torch.Tensor(weights["input_proj"]["kernel"].T))
44
+
45
+ model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
46
+
47
+ for lyr_num, lyr in enumerate(model.encoders):
48
+ ly_weight = weights[f"layers_{lyr_num}"]
49
+ attention_weights = ly_weight["attention"]
50
+
51
+ lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
52
+ lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
53
+ lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
54
+ lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
55
+ lyr.layer[0].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_attention_layer_norm"]["scale"]))
56
+
57
+ lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
58
+ lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
59
+ lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
60
+ lyr.layer[1].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
61
+
62
+ model.layer_norm.weight = nn.Parameter(torch.Tensor(weights["encoder_norm"]["scale"]))
63
+
64
+ return model
65
+
66
+
67
+ def load_decoder(weights, model):
68
+ model.conditioning_emb[0].weight = nn.Parameter(torch.Tensor(weights["time_emb_dense0"]["kernel"].T))
69
+ model.conditioning_emb[2].weight = nn.Parameter(torch.Tensor(weights["time_emb_dense1"]["kernel"].T))
70
+
71
+ model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
72
+
73
+ model.continuous_inputs_projection.weight = nn.Parameter(
74
+ torch.Tensor(weights["continuous_inputs_projection"]["kernel"].T)
75
+ )
76
+
77
+ for lyr_num, lyr in enumerate(model.decoders):
78
+ ly_weight = weights[f"layers_{lyr_num}"]
79
+ lyr.layer[0].layer_norm.weight = nn.Parameter(
80
+ torch.Tensor(ly_weight["pre_self_attention_layer_norm"]["scale"])
81
+ )
82
+
83
+ lyr.layer[0].FiLMLayer.scale_bias.weight = nn.Parameter(
84
+ torch.Tensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T)
85
+ )
86
+
87
+ attention_weights = ly_weight["self_attention"]
88
+ lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
89
+ lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
90
+ lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
91
+ lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
92
+
93
+ attention_weights = ly_weight["MultiHeadDotProductAttention_0"]
94
+ lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
95
+ lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
96
+ lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
97
+ lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
98
+ lyr.layer[1].layer_norm.weight = nn.Parameter(
99
+ torch.Tensor(ly_weight["pre_cross_attention_layer_norm"]["scale"])
100
+ )
101
+
102
+ lyr.layer[2].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
103
+ lyr.layer[2].film.scale_bias.weight = nn.Parameter(
104
+ torch.Tensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T)
105
+ )
106
+ lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
107
+ lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
108
+ lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
109
+
110
+ model.decoder_norm.weight = nn.Parameter(torch.Tensor(weights["decoder_norm"]["scale"]))
111
+
112
+ model.spec_out.weight = nn.Parameter(torch.Tensor(weights["spec_out_dense"]["kernel"].T))
113
+
114
+ return model
115
+
116
+
117
+ def main(args):
118
+ t5_checkpoint = checkpoints.load_t5x_checkpoint(args.checkpoint_path)
119
+ t5_checkpoint = jnp.tree_util.tree_map(onp.array, t5_checkpoint)
120
+
121
+ gin_overrides = [
122
+ "from __gin__ import dynamic_registration",
123
+ "from music_spectrogram_diffusion.models.diffusion import diffusion_utils",
124
+ "diffusion_utils.ClassifierFreeGuidanceConfig.eval_condition_weight = 2.0",
125
+ "diffusion_utils.DiffusionConfig.classifier_free_guidance = @diffusion_utils.ClassifierFreeGuidanceConfig()",
126
+ ]
127
+
128
+ gin_file = os.path.join(args.checkpoint_path, "..", "config.gin")
129
+ gin_config = inference.parse_training_gin_file(gin_file, gin_overrides)
130
+ synth_model = inference.InferenceModel(args.checkpoint_path, gin_config)
131
+
132
+ scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", variance_type="fixed_large")
133
+
134
+ notes_encoder = SpectrogramNotesEncoder(
135
+ max_length=synth_model.sequence_length["inputs"],
136
+ vocab_size=synth_model.model.module.config.vocab_size,
137
+ d_model=synth_model.model.module.config.emb_dim,
138
+ dropout_rate=synth_model.model.module.config.dropout_rate,
139
+ num_layers=synth_model.model.module.config.num_encoder_layers,
140
+ num_heads=synth_model.model.module.config.num_heads,
141
+ d_kv=synth_model.model.module.config.head_dim,
142
+ d_ff=synth_model.model.module.config.mlp_dim,
143
+ feed_forward_proj="gated-gelu",
144
+ )
145
+
146
+ continuous_encoder = SpectrogramContEncoder(
147
+ input_dims=synth_model.audio_codec.n_dims,
148
+ targets_context_length=synth_model.sequence_length["targets_context"],
149
+ d_model=synth_model.model.module.config.emb_dim,
150
+ dropout_rate=synth_model.model.module.config.dropout_rate,
151
+ num_layers=synth_model.model.module.config.num_encoder_layers,
152
+ num_heads=synth_model.model.module.config.num_heads,
153
+ d_kv=synth_model.model.module.config.head_dim,
154
+ d_ff=synth_model.model.module.config.mlp_dim,
155
+ feed_forward_proj="gated-gelu",
156
+ )
157
+
158
+ decoder = T5FilmDecoder(
159
+ input_dims=synth_model.audio_codec.n_dims,
160
+ targets_length=synth_model.sequence_length["targets_context"],
161
+ max_decoder_noise_time=synth_model.model.module.config.max_decoder_noise_time,
162
+ d_model=synth_model.model.module.config.emb_dim,
163
+ num_layers=synth_model.model.module.config.num_decoder_layers,
164
+ num_heads=synth_model.model.module.config.num_heads,
165
+ d_kv=synth_model.model.module.config.head_dim,
166
+ d_ff=synth_model.model.module.config.mlp_dim,
167
+ dropout_rate=synth_model.model.module.config.dropout_rate,
168
+ )
169
+
170
+ notes_encoder = load_notes_encoder(t5_checkpoint["target"]["token_encoder"], notes_encoder)
171
+ continuous_encoder = load_continuous_encoder(t5_checkpoint["target"]["continuous_encoder"], continuous_encoder)
172
+ decoder = load_decoder(t5_checkpoint["target"]["decoder"], decoder)
173
+
174
+ melgan = OnnxRuntimeModel.from_pretrained("kashif/soundstream_mel_decoder")
175
+
176
+ pipe = SpectrogramDiffusionPipeline(
177
+ notes_encoder=notes_encoder,
178
+ continuous_encoder=continuous_encoder,
179
+ decoder=decoder,
180
+ scheduler=scheduler,
181
+ melgan=melgan,
182
+ )
183
+ if args.save:
184
+ pipe.save_pretrained(args.output_path)
185
+
186
+
187
+ if __name__ == "__main__":
188
+ parser = argparse.ArgumentParser()
189
+
190
+ parser.add_argument("--output_path", default=None, type=str, required=True, help="Path to the converted model.")
191
+ parser.add_argument(
192
+ "--save", default=True, type=bool, required=False, help="Whether to save the converted model or not."
193
+ )
194
+ parser.add_argument(
195
+ "--checkpoint_path",
196
+ default=f"{MODEL}/checkpoint_500000",
197
+ type=str,
198
+ required=False,
199
+ help="Path to the original jax model checkpoint.",
200
+ )
201
+ args = parser.parse_args()
202
+
203
+ main(args)
diffusers/scripts/convert_original_audioldm_to_diffusers.py ADDED
@@ -0,0 +1,1042 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Conversion script for the AudioLDM checkpoints."""
16
+
17
+ import argparse
18
+ import re
19
+
20
+ import torch
21
+ import yaml
22
+ from transformers import (
23
+ AutoTokenizer,
24
+ ClapTextConfig,
25
+ ClapTextModelWithProjection,
26
+ SpeechT5HifiGan,
27
+ SpeechT5HifiGanConfig,
28
+ )
29
+
30
+ from diffusers import (
31
+ AudioLDMPipeline,
32
+ AutoencoderKL,
33
+ DDIMScheduler,
34
+ DPMSolverMultistepScheduler,
35
+ EulerAncestralDiscreteScheduler,
36
+ EulerDiscreteScheduler,
37
+ HeunDiscreteScheduler,
38
+ LMSDiscreteScheduler,
39
+ PNDMScheduler,
40
+ UNet2DConditionModel,
41
+ )
42
+
43
+
44
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
45
+ def shave_segments(path, n_shave_prefix_segments=1):
46
+ """
47
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
48
+ """
49
+ if n_shave_prefix_segments >= 0:
50
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
51
+ else:
52
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
53
+
54
+
55
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths
56
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
57
+ """
58
+ Updates paths inside resnets to the new naming scheme (local renaming)
59
+ """
60
+ mapping = []
61
+ for old_item in old_list:
62
+ new_item = old_item.replace("in_layers.0", "norm1")
63
+ new_item = new_item.replace("in_layers.2", "conv1")
64
+
65
+ new_item = new_item.replace("out_layers.0", "norm2")
66
+ new_item = new_item.replace("out_layers.3", "conv2")
67
+
68
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
69
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
70
+
71
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
72
+
73
+ mapping.append({"old": old_item, "new": new_item})
74
+
75
+ return mapping
76
+
77
+
78
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths
79
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
80
+ """
81
+ Updates paths inside resnets to the new naming scheme (local renaming)
82
+ """
83
+ mapping = []
84
+ for old_item in old_list:
85
+ new_item = old_item
86
+
87
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
88
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
89
+
90
+ mapping.append({"old": old_item, "new": new_item})
91
+
92
+ return mapping
93
+
94
+
95
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths
96
+ def renew_attention_paths(old_list):
97
+ """
98
+ Updates paths inside attentions to the new naming scheme (local renaming)
99
+ """
100
+ mapping = []
101
+ for old_item in old_list:
102
+ new_item = old_item
103
+
104
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
105
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
106
+
107
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
108
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
109
+
110
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
111
+
112
+ mapping.append({"old": old_item, "new": new_item})
113
+
114
+ return mapping
115
+
116
+
117
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths
118
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
119
+ """
120
+ Updates paths inside attentions to the new naming scheme (local renaming)
121
+ """
122
+ mapping = []
123
+ for old_item in old_list:
124
+ new_item = old_item
125
+
126
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
127
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
128
+
129
+ new_item = new_item.replace("q.weight", "query.weight")
130
+ new_item = new_item.replace("q.bias", "query.bias")
131
+
132
+ new_item = new_item.replace("k.weight", "key.weight")
133
+ new_item = new_item.replace("k.bias", "key.bias")
134
+
135
+ new_item = new_item.replace("v.weight", "value.weight")
136
+ new_item = new_item.replace("v.bias", "value.bias")
137
+
138
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
139
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
140
+
141
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
142
+
143
+ mapping.append({"old": old_item, "new": new_item})
144
+
145
+ return mapping
146
+
147
+
148
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
149
+ def assign_to_checkpoint(
150
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
151
+ ):
152
+ """
153
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
154
+ attention layers, and takes into account additional replacements that may arise.
155
+
156
+ Assigns the weights to the new checkpoint.
157
+ """
158
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
159
+
160
+ # Splits the attention layers into three variables.
161
+ if attention_paths_to_split is not None:
162
+ for path, path_map in attention_paths_to_split.items():
163
+ old_tensor = old_checkpoint[path]
164
+ channels = old_tensor.shape[0] // 3
165
+
166
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
167
+
168
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
169
+
170
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
171
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
172
+
173
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
174
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
175
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
176
+
177
+ for path in paths:
178
+ new_path = path["new"]
179
+
180
+ # These have already been assigned
181
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
182
+ continue
183
+
184
+ # Global renaming happens here
185
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
186
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
187
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
188
+
189
+ if additional_replacements is not None:
190
+ for replacement in additional_replacements:
191
+ new_path = new_path.replace(replacement["old"], replacement["new"])
192
+
193
+ # proj_attn.weight has to be converted from conv 1D to linear
194
+ if "proj_attn.weight" in new_path:
195
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
196
+ else:
197
+ checkpoint[new_path] = old_checkpoint[path["old"]]
198
+
199
+
200
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
201
+ def conv_attn_to_linear(checkpoint):
202
+ keys = list(checkpoint.keys())
203
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
204
+ for key in keys:
205
+ if ".".join(key.split(".")[-2:]) in attn_keys:
206
+ if checkpoint[key].ndim > 2:
207
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
208
+ elif "proj_attn.weight" in key:
209
+ if checkpoint[key].ndim > 2:
210
+ checkpoint[key] = checkpoint[key][:, :, 0]
211
+
212
+
213
+ def create_unet_diffusers_config(original_config, image_size: int):
214
+ """
215
+ Creates a UNet config for diffusers based on the config of the original AudioLDM model.
216
+ """
217
+ unet_params = original_config["model"]["params"]["unet_config"]["params"]
218
+ vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
219
+
220
+ block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
221
+
222
+ down_block_types = []
223
+ resolution = 1
224
+ for i in range(len(block_out_channels)):
225
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
226
+ down_block_types.append(block_type)
227
+ if i != len(block_out_channels) - 1:
228
+ resolution *= 2
229
+
230
+ up_block_types = []
231
+ for i in range(len(block_out_channels)):
232
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
233
+ up_block_types.append(block_type)
234
+ resolution //= 2
235
+
236
+ vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
237
+
238
+ cross_attention_dim = (
239
+ unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels
240
+ )
241
+
242
+ class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None
243
+ projection_class_embeddings_input_dim = (
244
+ unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None
245
+ )
246
+ class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None
247
+
248
+ config = {
249
+ "sample_size": image_size // vae_scale_factor,
250
+ "in_channels": unet_params["in_channels"],
251
+ "out_channels": unet_params["out_channels"],
252
+ "down_block_types": tuple(down_block_types),
253
+ "up_block_types": tuple(up_block_types),
254
+ "block_out_channels": tuple(block_out_channels),
255
+ "layers_per_block": unet_params["num_res_blocks"],
256
+ "cross_attention_dim": cross_attention_dim,
257
+ "class_embed_type": class_embed_type,
258
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
259
+ "class_embeddings_concat": class_embeddings_concat,
260
+ }
261
+
262
+ return config
263
+
264
+
265
+ # Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config
266
+ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
267
+ """
268
+ Creates a VAE config for diffusers based on the config of the original AudioLDM model. Compared to the original
269
+ Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
270
+ """
271
+ vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
272
+ _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
273
+
274
+ block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
275
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
276
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
277
+
278
+ scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
279
+
280
+ config = {
281
+ "sample_size": image_size,
282
+ "in_channels": vae_params["in_channels"],
283
+ "out_channels": vae_params["out_ch"],
284
+ "down_block_types": tuple(down_block_types),
285
+ "up_block_types": tuple(up_block_types),
286
+ "block_out_channels": tuple(block_out_channels),
287
+ "latent_channels": vae_params["z_channels"],
288
+ "layers_per_block": vae_params["num_res_blocks"],
289
+ "scaling_factor": float(scaling_factor),
290
+ }
291
+ return config
292
+
293
+
294
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
295
+ def create_diffusers_schedular(original_config):
296
+ schedular = DDIMScheduler(
297
+ num_train_timesteps=original_config["model"]["params"]["timesteps"],
298
+ beta_start=original_config["model"]["params"]["linear_start"],
299
+ beta_end=original_config["model"]["params"]["linear_end"],
300
+ beta_schedule="scaled_linear",
301
+ )
302
+ return schedular
303
+
304
+
305
+ # Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_unet_checkpoint
306
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
307
+ """
308
+ Takes a state dict and a config, and returns a converted checkpoint. Compared to the original Stable Diffusion
309
+ conversion, this function additionally converts the learnt film embedding linear layer.
310
+ """
311
+
312
+ # extract state_dict for UNet
313
+ unet_state_dict = {}
314
+ keys = list(checkpoint.keys())
315
+
316
+ unet_key = "model.diffusion_model."
317
+ # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
318
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
319
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
320
+ print(
321
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
322
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
323
+ )
324
+ for key in keys:
325
+ if key.startswith("model.diffusion_model"):
326
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
327
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
328
+ else:
329
+ if sum(k.startswith("model_ema") for k in keys) > 100:
330
+ print(
331
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
332
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
333
+ )
334
+
335
+ for key in keys:
336
+ if key.startswith(unet_key):
337
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
338
+
339
+ new_checkpoint = {}
340
+
341
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
342
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
343
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
344
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
345
+
346
+ new_checkpoint["class_embedding.weight"] = unet_state_dict["film_emb.weight"]
347
+ new_checkpoint["class_embedding.bias"] = unet_state_dict["film_emb.bias"]
348
+
349
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
350
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
351
+
352
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
353
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
354
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
355
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
356
+
357
+ # Retrieves the keys for the input blocks only
358
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
359
+ input_blocks = {
360
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
361
+ for layer_id in range(num_input_blocks)
362
+ }
363
+
364
+ # Retrieves the keys for the middle blocks only
365
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
366
+ middle_blocks = {
367
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
368
+ for layer_id in range(num_middle_blocks)
369
+ }
370
+
371
+ # Retrieves the keys for the output blocks only
372
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
373
+ output_blocks = {
374
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
375
+ for layer_id in range(num_output_blocks)
376
+ }
377
+
378
+ for i in range(1, num_input_blocks):
379
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
380
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
381
+
382
+ resnets = [
383
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
384
+ ]
385
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
386
+
387
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
388
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
389
+ f"input_blocks.{i}.0.op.weight"
390
+ )
391
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
392
+ f"input_blocks.{i}.0.op.bias"
393
+ )
394
+
395
+ paths = renew_resnet_paths(resnets)
396
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
397
+ assign_to_checkpoint(
398
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
399
+ )
400
+
401
+ if len(attentions):
402
+ paths = renew_attention_paths(attentions)
403
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
404
+ assign_to_checkpoint(
405
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
406
+ )
407
+
408
+ resnet_0 = middle_blocks[0]
409
+ attentions = middle_blocks[1]
410
+ resnet_1 = middle_blocks[2]
411
+
412
+ resnet_0_paths = renew_resnet_paths(resnet_0)
413
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
414
+
415
+ resnet_1_paths = renew_resnet_paths(resnet_1)
416
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
417
+
418
+ attentions_paths = renew_attention_paths(attentions)
419
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
420
+ assign_to_checkpoint(
421
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
422
+ )
423
+
424
+ for i in range(num_output_blocks):
425
+ block_id = i // (config["layers_per_block"] + 1)
426
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
427
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
428
+ output_block_list = {}
429
+
430
+ for layer in output_block_layers:
431
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
432
+ if layer_id in output_block_list:
433
+ output_block_list[layer_id].append(layer_name)
434
+ else:
435
+ output_block_list[layer_id] = [layer_name]
436
+
437
+ if len(output_block_list) > 1:
438
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
439
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
440
+
441
+ resnet_0_paths = renew_resnet_paths(resnets)
442
+ paths = renew_resnet_paths(resnets)
443
+
444
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
445
+ assign_to_checkpoint(
446
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
447
+ )
448
+
449
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
450
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
451
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
452
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
453
+ f"output_blocks.{i}.{index}.conv.weight"
454
+ ]
455
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
456
+ f"output_blocks.{i}.{index}.conv.bias"
457
+ ]
458
+
459
+ # Clear attentions as they have been attributed above.
460
+ if len(attentions) == 2:
461
+ attentions = []
462
+
463
+ if len(attentions):
464
+ paths = renew_attention_paths(attentions)
465
+ meta_path = {
466
+ "old": f"output_blocks.{i}.1",
467
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
468
+ }
469
+ assign_to_checkpoint(
470
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
471
+ )
472
+ else:
473
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
474
+ for path in resnet_0_paths:
475
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
476
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
477
+
478
+ new_checkpoint[new_path] = unet_state_dict[old_path]
479
+
480
+ return new_checkpoint
481
+
482
+
483
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
484
+ def convert_ldm_vae_checkpoint(checkpoint, config):
485
+ # extract state dict for VAE
486
+ vae_state_dict = {}
487
+ vae_key = "first_stage_model."
488
+ keys = list(checkpoint.keys())
489
+ for key in keys:
490
+ if key.startswith(vae_key):
491
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
492
+
493
+ new_checkpoint = {}
494
+
495
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
496
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
497
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
498
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
499
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
500
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
501
+
502
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
503
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
504
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
505
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
506
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
507
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
508
+
509
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
510
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
511
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
512
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
513
+
514
+ # Retrieves the keys for the encoder down blocks only
515
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
516
+ down_blocks = {
517
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
518
+ }
519
+
520
+ # Retrieves the keys for the decoder up blocks only
521
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
522
+ up_blocks = {
523
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
524
+ }
525
+
526
+ for i in range(num_down_blocks):
527
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
528
+
529
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
530
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
531
+ f"encoder.down.{i}.downsample.conv.weight"
532
+ )
533
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
534
+ f"encoder.down.{i}.downsample.conv.bias"
535
+ )
536
+
537
+ paths = renew_vae_resnet_paths(resnets)
538
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
539
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
540
+
541
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
542
+ num_mid_res_blocks = 2
543
+ for i in range(1, num_mid_res_blocks + 1):
544
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
545
+
546
+ paths = renew_vae_resnet_paths(resnets)
547
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
548
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
549
+
550
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
551
+ paths = renew_vae_attention_paths(mid_attentions)
552
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
553
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
554
+ conv_attn_to_linear(new_checkpoint)
555
+
556
+ for i in range(num_up_blocks):
557
+ block_id = num_up_blocks - 1 - i
558
+ resnets = [
559
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
560
+ ]
561
+
562
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
563
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
564
+ f"decoder.up.{block_id}.upsample.conv.weight"
565
+ ]
566
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
567
+ f"decoder.up.{block_id}.upsample.conv.bias"
568
+ ]
569
+
570
+ paths = renew_vae_resnet_paths(resnets)
571
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
572
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
573
+
574
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
575
+ num_mid_res_blocks = 2
576
+ for i in range(1, num_mid_res_blocks + 1):
577
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
578
+
579
+ paths = renew_vae_resnet_paths(resnets)
580
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
581
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
582
+
583
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
584
+ paths = renew_vae_attention_paths(mid_attentions)
585
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
586
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
587
+ conv_attn_to_linear(new_checkpoint)
588
+ return new_checkpoint
589
+
590
+
591
+ CLAP_KEYS_TO_MODIFY_MAPPING = {
592
+ "text_branch": "text_model",
593
+ "attn": "attention.self",
594
+ "self.proj": "output.dense",
595
+ "attention.self_mask": "attn_mask",
596
+ "mlp.fc1": "intermediate.dense",
597
+ "mlp.fc2": "output.dense",
598
+ "norm1": "layernorm_before",
599
+ "norm2": "layernorm_after",
600
+ "bn0": "batch_norm",
601
+ }
602
+
603
+ CLAP_KEYS_TO_IGNORE = ["text_transform"]
604
+
605
+ CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"]
606
+
607
+
608
+ def convert_open_clap_checkpoint(checkpoint):
609
+ """
610
+ Takes a state dict and returns a converted CLAP checkpoint.
611
+ """
612
+ # extract state dict for CLAP text embedding model, discarding the audio component
613
+ model_state_dict = {}
614
+ model_key = "cond_stage_model.model.text_"
615
+ keys = list(checkpoint.keys())
616
+ for key in keys:
617
+ if key.startswith(model_key):
618
+ model_state_dict[key.replace(model_key, "text_")] = checkpoint.get(key)
619
+
620
+ new_checkpoint = {}
621
+
622
+ sequential_layers_pattern = r".*sequential.(\d+).*"
623
+ text_projection_pattern = r".*_projection.(\d+).*"
624
+
625
+ for key, value in model_state_dict.items():
626
+ # check if key should be ignored in mapping
627
+ if key.split(".")[0] in CLAP_KEYS_TO_IGNORE:
628
+ continue
629
+
630
+ # check if any key needs to be modified
631
+ for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items():
632
+ if key_to_modify in key:
633
+ key = key.replace(key_to_modify, new_key)
634
+
635
+ if re.match(sequential_layers_pattern, key):
636
+ # replace sequential layers with list
637
+ sequential_layer = re.match(sequential_layers_pattern, key).group(1)
638
+
639
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
640
+ elif re.match(text_projection_pattern, key):
641
+ projecton_layer = int(re.match(text_projection_pattern, key).group(1))
642
+
643
+ # Because in CLAP they use `nn.Sequential`...
644
+ transformers_projection_layer = 1 if projecton_layer == 0 else 2
645
+
646
+ key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")
647
+
648
+ if "audio" and "qkv" in key:
649
+ # split qkv into query key and value
650
+ mixed_qkv = value
651
+ qkv_dim = mixed_qkv.size(0) // 3
652
+
653
+ query_layer = mixed_qkv[:qkv_dim]
654
+ key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
655
+ value_layer = mixed_qkv[qkv_dim * 2 :]
656
+
657
+ new_checkpoint[key.replace("qkv", "query")] = query_layer
658
+ new_checkpoint[key.replace("qkv", "key")] = key_layer
659
+ new_checkpoint[key.replace("qkv", "value")] = value_layer
660
+ else:
661
+ new_checkpoint[key] = value
662
+
663
+ return new_checkpoint
664
+
665
+
666
+ def create_transformers_vocoder_config(original_config):
667
+ """
668
+ Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
669
+ """
670
+ vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]
671
+
672
+ config = {
673
+ "model_in_dim": vocoder_params["num_mels"],
674
+ "sampling_rate": vocoder_params["sampling_rate"],
675
+ "upsample_initial_channel": vocoder_params["upsample_initial_channel"],
676
+ "upsample_rates": list(vocoder_params["upsample_rates"]),
677
+ "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
678
+ "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
679
+ "resblock_dilation_sizes": [
680
+ list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"]
681
+ ],
682
+ "normalize_before": False,
683
+ }
684
+
685
+ return config
686
+
687
+
688
+ def convert_hifigan_checkpoint(checkpoint, config):
689
+ """
690
+ Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint.
691
+ """
692
+ # extract state dict for vocoder
693
+ vocoder_state_dict = {}
694
+ vocoder_key = "first_stage_model.vocoder."
695
+ keys = list(checkpoint.keys())
696
+ for key in keys:
697
+ if key.startswith(vocoder_key):
698
+ vocoder_state_dict[key.replace(vocoder_key, "")] = checkpoint.get(key)
699
+
700
+ # fix upsampler keys, everything else is correct already
701
+ for i in range(len(config.upsample_rates)):
702
+ vocoder_state_dict[f"upsampler.{i}.weight"] = vocoder_state_dict.pop(f"ups.{i}.weight")
703
+ vocoder_state_dict[f"upsampler.{i}.bias"] = vocoder_state_dict.pop(f"ups.{i}.bias")
704
+
705
+ if not config.normalize_before:
706
+ # if we don't set normalize_before then these variables are unused, so we set them to their initialised values
707
+ vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim)
708
+ vocoder_state_dict["scale"] = torch.ones(config.model_in_dim)
709
+
710
+ return vocoder_state_dict
711
+
712
+
713
+ # Adapted from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/84a0384742a22bd80c44e903e241f0623e874f1d/audioldm/utils.py#L72-L73
714
+ DEFAULT_CONFIG = {
715
+ "model": {
716
+ "params": {
717
+ "linear_start": 0.0015,
718
+ "linear_end": 0.0195,
719
+ "timesteps": 1000,
720
+ "channels": 8,
721
+ "scale_by_std": True,
722
+ "unet_config": {
723
+ "target": "audioldm.latent_diffusion.openaimodel.UNetModel",
724
+ "params": {
725
+ "extra_film_condition_dim": 512,
726
+ "extra_film_use_concat": True,
727
+ "in_channels": 8,
728
+ "out_channels": 8,
729
+ "model_channels": 128,
730
+ "attention_resolutions": [8, 4, 2],
731
+ "num_res_blocks": 2,
732
+ "channel_mult": [1, 2, 3, 5],
733
+ "num_head_channels": 32,
734
+ },
735
+ },
736
+ "first_stage_config": {
737
+ "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
738
+ "params": {
739
+ "embed_dim": 8,
740
+ "ddconfig": {
741
+ "z_channels": 8,
742
+ "resolution": 256,
743
+ "in_channels": 1,
744
+ "out_ch": 1,
745
+ "ch": 128,
746
+ "ch_mult": [1, 2, 4],
747
+ "num_res_blocks": 2,
748
+ },
749
+ },
750
+ },
751
+ "vocoder_config": {
752
+ "target": "audioldm.first_stage_model.vocoder",
753
+ "params": {
754
+ "upsample_rates": [5, 4, 2, 2, 2],
755
+ "upsample_kernel_sizes": [16, 16, 8, 4, 4],
756
+ "upsample_initial_channel": 1024,
757
+ "resblock_kernel_sizes": [3, 7, 11],
758
+ "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
759
+ "num_mels": 64,
760
+ "sampling_rate": 16000,
761
+ },
762
+ },
763
+ },
764
+ },
765
+ }
766
+
767
+
768
+ def load_pipeline_from_original_audioldm_ckpt(
769
+ checkpoint_path: str,
770
+ original_config_file: str = None,
771
+ image_size: int = 512,
772
+ prediction_type: str = None,
773
+ extract_ema: bool = False,
774
+ scheduler_type: str = "ddim",
775
+ num_in_channels: int = None,
776
+ model_channels: int = None,
777
+ num_head_channels: int = None,
778
+ device: str = None,
779
+ from_safetensors: bool = False,
780
+ ) -> AudioLDMPipeline:
781
+ """
782
+ Load an AudioLDM pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file.
783
+
784
+ Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
785
+ global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
786
+ recommended that you override the default values and/or supply an `original_config_file` wherever possible.
787
+
788
+ Args:
789
+ checkpoint_path (`str`): Path to `.ckpt` file.
790
+ original_config_file (`str`):
791
+ Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
792
+ set to the audioldm-s-full-v2 config.
793
+ image_size (`int`, *optional*, defaults to 512):
794
+ The image size that the model was trained on.
795
+ prediction_type (`str`, *optional*):
796
+ The prediction type that the model was trained on. If `None`, will be automatically
797
+ inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
798
+ num_in_channels (`int`, *optional*, defaults to None):
799
+ The number of UNet input channels. If `None`, it will be automatically inferred from the config.
800
+ model_channels (`int`, *optional*, defaults to None):
801
+ The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override
802
+ to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.
803
+ num_head_channels (`int`, *optional*, defaults to None):
804
+ The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override
805
+ to 32 for the small and medium checkpoints, and 64 for the large.
806
+ scheduler_type (`str`, *optional*, defaults to 'pndm'):
807
+ Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
808
+ "ddim"]`.
809
+ extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
810
+ checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
811
+ `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
812
+ inference. Non-EMA weights are usually better to continue fine-tuning.
813
+ device (`str`, *optional*, defaults to `None`):
814
+ The device to use. Pass `None` to determine automatically.
815
+ from_safetensors (`str`, *optional*, defaults to `False`):
816
+ If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
817
+ return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
818
+ """
819
+
820
+ if from_safetensors:
821
+ from safetensors import safe_open
822
+
823
+ checkpoint = {}
824
+ with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
825
+ for key in f.keys():
826
+ checkpoint[key] = f.get_tensor(key)
827
+ else:
828
+ if device is None:
829
+ device = "cuda" if torch.cuda.is_available() else "cpu"
830
+ checkpoint = torch.load(checkpoint_path, map_location=device)
831
+ else:
832
+ checkpoint = torch.load(checkpoint_path, map_location=device)
833
+
834
+ if "state_dict" in checkpoint:
835
+ checkpoint = checkpoint["state_dict"]
836
+
837
+ if original_config_file is None:
838
+ original_config = DEFAULT_CONFIG
839
+ else:
840
+ original_config = yaml.safe_load(original_config_file)
841
+
842
+ if num_in_channels is not None:
843
+ original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
844
+
845
+ if model_channels is not None:
846
+ original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels
847
+
848
+ if num_head_channels is not None:
849
+ original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels
850
+
851
+ if (
852
+ "parameterization" in original_config["model"]["params"]
853
+ and original_config["model"]["params"]["parameterization"] == "v"
854
+ ):
855
+ if prediction_type is None:
856
+ prediction_type = "v_prediction"
857
+ else:
858
+ if prediction_type is None:
859
+ prediction_type = "epsilon"
860
+
861
+ if image_size is None:
862
+ image_size = 512
863
+
864
+ num_train_timesteps = original_config["model"]["params"]["timesteps"]
865
+ beta_start = original_config["model"]["params"]["linear_start"]
866
+ beta_end = original_config["model"]["params"]["linear_end"]
867
+
868
+ scheduler = DDIMScheduler(
869
+ beta_end=beta_end,
870
+ beta_schedule="scaled_linear",
871
+ beta_start=beta_start,
872
+ num_train_timesteps=num_train_timesteps,
873
+ steps_offset=1,
874
+ clip_sample=False,
875
+ set_alpha_to_one=False,
876
+ prediction_type=prediction_type,
877
+ )
878
+ # make sure scheduler works correctly with DDIM
879
+ scheduler.register_to_config(clip_sample=False)
880
+
881
+ if scheduler_type == "pndm":
882
+ config = dict(scheduler.config)
883
+ config["skip_prk_steps"] = True
884
+ scheduler = PNDMScheduler.from_config(config)
885
+ elif scheduler_type == "lms":
886
+ scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
887
+ elif scheduler_type == "heun":
888
+ scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
889
+ elif scheduler_type == "euler":
890
+ scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
891
+ elif scheduler_type == "euler-ancestral":
892
+ scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
893
+ elif scheduler_type == "dpm":
894
+ scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
895
+ elif scheduler_type == "ddim":
896
+ scheduler = scheduler
897
+ else:
898
+ raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
899
+
900
+ # Convert the UNet2DModel
901
+ unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
902
+ unet = UNet2DConditionModel(**unet_config)
903
+
904
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
905
+ checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
906
+ )
907
+
908
+ unet.load_state_dict(converted_unet_checkpoint)
909
+
910
+ # Convert the VAE model
911
+ vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size)
912
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
913
+
914
+ vae = AutoencoderKL(**vae_config)
915
+ vae.load_state_dict(converted_vae_checkpoint)
916
+
917
+ # Convert the text model
918
+ # AudioLDM uses the same configuration and tokenizer as the original CLAP model
919
+ config = ClapTextConfig.from_pretrained("laion/clap-htsat-unfused")
920
+ tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
921
+
922
+ converted_text_model = convert_open_clap_checkpoint(checkpoint)
923
+ text_model = ClapTextModelWithProjection(config)
924
+
925
+ missing_keys, unexpected_keys = text_model.load_state_dict(converted_text_model, strict=False)
926
+ # we expect not to have token_type_ids in our original state dict so let's ignore them
927
+ missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS))
928
+
929
+ if len(unexpected_keys) > 0:
930
+ raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}")
931
+
932
+ if len(missing_keys) > 0:
933
+ raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}")
934
+
935
+ # Convert the vocoder model
936
+ vocoder_config = create_transformers_vocoder_config(original_config)
937
+ vocoder_config = SpeechT5HifiGanConfig(**vocoder_config)
938
+ converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config)
939
+
940
+ vocoder = SpeechT5HifiGan(vocoder_config)
941
+ vocoder.load_state_dict(converted_vocoder_checkpoint)
942
+
943
+ # Instantiate the diffusers pipeline
944
+ pipe = AudioLDMPipeline(
945
+ vae=vae,
946
+ text_encoder=text_model,
947
+ tokenizer=tokenizer,
948
+ unet=unet,
949
+ scheduler=scheduler,
950
+ vocoder=vocoder,
951
+ )
952
+
953
+ return pipe
954
+
955
+
956
+ if __name__ == "__main__":
957
+ parser = argparse.ArgumentParser()
958
+
959
+ parser.add_argument(
960
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
961
+ )
962
+ parser.add_argument(
963
+ "--original_config_file",
964
+ default=None,
965
+ type=str,
966
+ help="The YAML config file corresponding to the original architecture.",
967
+ )
968
+ parser.add_argument(
969
+ "--num_in_channels",
970
+ default=None,
971
+ type=int,
972
+ help="The number of input channels. If `None` number of input channels will be automatically inferred.",
973
+ )
974
+ parser.add_argument(
975
+ "--model_channels",
976
+ default=None,
977
+ type=int,
978
+ help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
979
+ " to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
980
+ )
981
+ parser.add_argument(
982
+ "--num_head_channels",
983
+ default=None,
984
+ type=int,
985
+ help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
986
+ " to 32 for the small and medium checkpoints, and 64 for the large.",
987
+ )
988
+ parser.add_argument(
989
+ "--scheduler_type",
990
+ default="ddim",
991
+ type=str,
992
+ help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
993
+ )
994
+ parser.add_argument(
995
+ "--image_size",
996
+ default=None,
997
+ type=int,
998
+ help=("The image size that the model was trained on."),
999
+ )
1000
+ parser.add_argument(
1001
+ "--prediction_type",
1002
+ default=None,
1003
+ type=str,
1004
+ help=("The prediction type that the model was trained on."),
1005
+ )
1006
+ parser.add_argument(
1007
+ "--extract_ema",
1008
+ action="store_true",
1009
+ help=(
1010
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
1011
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
1012
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
1013
+ ),
1014
+ )
1015
+ parser.add_argument(
1016
+ "--from_safetensors",
1017
+ action="store_true",
1018
+ help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
1019
+ )
1020
+ parser.add_argument(
1021
+ "--to_safetensors",
1022
+ action="store_true",
1023
+ help="Whether to store pipeline in safetensors format or not.",
1024
+ )
1025
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
1026
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
1027
+ args = parser.parse_args()
1028
+
1029
+ pipe = load_pipeline_from_original_audioldm_ckpt(
1030
+ checkpoint_path=args.checkpoint_path,
1031
+ original_config_file=args.original_config_file,
1032
+ image_size=args.image_size,
1033
+ prediction_type=args.prediction_type,
1034
+ extract_ema=args.extract_ema,
1035
+ scheduler_type=args.scheduler_type,
1036
+ num_in_channels=args.num_in_channels,
1037
+ model_channels=args.model_channels,
1038
+ num_head_channels=args.num_head_channels,
1039
+ from_safetensors=args.from_safetensors,
1040
+ device=args.device,
1041
+ )
1042
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)