Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- diffusers/.github/ISSUE_TEMPLATE/config.yml +4 -0
- diffusers/.github/ISSUE_TEMPLATE/feedback.md +12 -0
- diffusers/.github/ISSUE_TEMPLATE/translate.md +29 -0
- diffusers/benchmarks/README.md +69 -0
- diffusers/benchmarks/__init__.py +0 -0
- diffusers/benchmarks/benchmarking_flux.py +98 -0
- diffusers/benchmarks/benchmarking_ltx.py +80 -0
- diffusers/benchmarks/benchmarking_sdxl.py +82 -0
- diffusers/benchmarks/benchmarking_utils.py +244 -0
- diffusers/benchmarks/benchmarking_wan.py +74 -0
- diffusers/benchmarks/populate_into_db.py +166 -0
- diffusers/benchmarks/push_results.py +76 -0
- diffusers/benchmarks/requirements.txt +6 -0
- diffusers/benchmarks/run_all.py +84 -0
- diffusers/examples/README.md +70 -0
- diffusers/examples/conftest.py +45 -0
- diffusers/examples/test_examples_utils.py +61 -0
- diffusers/scripts/__init__.py +0 -0
- diffusers/scripts/change_naming_configs_and_checkpoints.py +113 -0
- diffusers/scripts/convert_amused.py +523 -0
- diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py +62 -0
- diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py +83 -0
- diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py +184 -0
- diffusers/scripts/convert_aura_flow_to_diffusers.py +131 -0
- diffusers/scripts/convert_blipdiffusion_to_diffusers.py +344 -0
- diffusers/scripts/convert_cogview3_to_diffusers.py +242 -0
- diffusers/scripts/convert_cogview4_to_diffusers.py +254 -0
- diffusers/scripts/convert_cogview4_to_diffusers_megatron.py +384 -0
- diffusers/scripts/convert_consistency_to_diffusers.py +315 -0
- diffusers/scripts/convert_cosmos_to_diffusers.py +506 -0
- diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py +431 -0
- diffusers/scripts/convert_diffusers_to_original_sdxl.py +350 -0
- diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py +353 -0
- diffusers/scripts/convert_dit_to_diffusers.py +162 -0
- diffusers/scripts/convert_flux_to_diffusers.py +308 -0
- diffusers/scripts/convert_gligen_to_diffusers.py +581 -0
- diffusers/scripts/convert_hunyuan_video_to_diffusers.py +353 -0
- diffusers/scripts/convert_hunyuandit_to_diffusers.py +266 -0
- diffusers/scripts/convert_k_upscaler_to_diffusers.py +297 -0
- diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py +1159 -0
- diffusers/scripts/convert_kandinsky3_unet.py +98 -0
- diffusers/scripts/convert_kandinsky_to_diffusers.py +1411 -0
- diffusers/scripts/convert_ldm_original_checkpoint_to_diffusers.py +359 -0
- diffusers/scripts/convert_ltx_to_diffusers.py +516 -0
- diffusers/scripts/convert_lumina_to_diffusers.py +142 -0
- diffusers/scripts/convert_mochi_to_diffusers.py +463 -0
- diffusers/scripts/convert_models_diffuser_to_diffusers.py +100 -0
- diffusers/scripts/convert_ms_text_to_video_to_diffusers.py +428 -0
- diffusers/scripts/convert_music_spectrogram_to_diffusers.py +203 -0
- diffusers/scripts/convert_original_audioldm_to_diffusers.py +1042 -0
diffusers/.github/ISSUE_TEMPLATE/config.yml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
contact_links:
|
| 2 |
+
- name: Questions / Discussions
|
| 3 |
+
url: https://github.com/huggingface/diffusers/discussions
|
| 4 |
+
about: General usage questions and community discussions
|
diffusers/.github/ISSUE_TEMPLATE/feedback.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: "💬 Feedback about API Design"
|
| 3 |
+
about: Give feedback about the current API design
|
| 4 |
+
title: ''
|
| 5 |
+
labels: ''
|
| 6 |
+
assignees: ''
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
**What API design would you like to have changed or added to the library? Why?**
|
| 11 |
+
|
| 12 |
+
**What use case would this enable or better enable? Can you give us a code example?**
|
diffusers/.github/ISSUE_TEMPLATE/translate.md
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: 🌐 Translating a New Language?
|
| 3 |
+
about: Start a new translation effort in your language
|
| 4 |
+
title: '[<languageCode>] Translating docs to <languageName>'
|
| 5 |
+
labels: WIP
|
| 6 |
+
assignees: ''
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
<!--
|
| 11 |
+
Note: Please search to see if an issue already exists for the language you are trying to translate.
|
| 12 |
+
-->
|
| 13 |
+
|
| 14 |
+
Hi!
|
| 15 |
+
|
| 16 |
+
Let's bring the documentation to all the <languageName>-speaking community 🌐.
|
| 17 |
+
|
| 18 |
+
Who would want to translate? Please follow the 🤗 [TRANSLATING guide](https://github.com/huggingface/diffusers/blob/main/docs/TRANSLATING.md). Here is a list of the files ready for translation. Let us know in this issue if you'd like to translate any, and we'll add your name to the list.
|
| 19 |
+
|
| 20 |
+
Some notes:
|
| 21 |
+
|
| 22 |
+
* Please translate using an informal tone (imagine you are talking with a friend about Diffusers 🤗).
|
| 23 |
+
* Please translate in a gender-neutral way.
|
| 24 |
+
* Add your translations to the folder called `<languageCode>` inside the [source folder](https://github.com/huggingface/diffusers/tree/main/docs/source).
|
| 25 |
+
* Register your translation in `<languageCode>/_toctree.yml`; please follow the order of the [English version](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml).
|
| 26 |
+
* Once you're finished, open a pull request and tag this issue by including #issue-number in the description, where issue-number is the number of this issue. Please ping @stevhliu for review.
|
| 27 |
+
* 🙋 If you'd like others to help you with the translation, you can also post in the 🤗 [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63).
|
| 28 |
+
|
| 29 |
+
Thank you so much for your help! 🤗
|
diffusers/benchmarks/README.md
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Diffusers Benchmarks
|
| 2 |
+
|
| 3 |
+
Welcome to Diffusers Benchmarks. These benchmarks are use to obtain latency and memory information of the most popular models across different scenarios such as:
|
| 4 |
+
|
| 5 |
+
* Base case i.e., when using `torch.bfloat16` and `torch.nn.functional.scaled_dot_product_attention`.
|
| 6 |
+
* Base + `torch.compile()`
|
| 7 |
+
* NF4 quantization
|
| 8 |
+
* Layerwise upcasting
|
| 9 |
+
|
| 10 |
+
Instead of full diffusion pipelines, only the forward pass of the respective model classes (such as `FluxTransformer2DModel`) is tested with the real checkpoints (such as `"black-forest-labs/FLUX.1-dev"`).
|
| 11 |
+
|
| 12 |
+
The entrypoint to running all the currently available benchmarks is in `run_all.py`. However, one can run the individual benchmarks, too, e.g., `python benchmarking_flux.py`. It should produce a CSV file containing various information about the benchmarks run.
|
| 13 |
+
|
| 14 |
+
The benchmarks are run on a weekly basis and the CI is defined in [benchmark.yml](../.github/workflows/benchmark.yml).
|
| 15 |
+
|
| 16 |
+
## Running the benchmarks manually
|
| 17 |
+
|
| 18 |
+
First set up `torch` and install `diffusers` from the root of the directory:
|
| 19 |
+
|
| 20 |
+
```py
|
| 21 |
+
pip install -e ".[quality,test]"
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
Then make sure the other dependencies are installed:
|
| 25 |
+
|
| 26 |
+
```sh
|
| 27 |
+
cd benchmarks/
|
| 28 |
+
pip install -r requirements.txt
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
We need to be authenticated to access some of the checkpoints used during benchmarking:
|
| 32 |
+
|
| 33 |
+
```sh
|
| 34 |
+
huggingface-cli login
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
We use an L40 GPU with 128GB RAM to run the benchmark CI. As such, the benchmarks are configured to run on NVIDIA GPUs. So, make sure you have access to a similar machine (or modify the benchmarking scripts accordingly).
|
| 38 |
+
|
| 39 |
+
Then you can either launch the entire benchmarking suite by running:
|
| 40 |
+
|
| 41 |
+
```sh
|
| 42 |
+
python run_all.py
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
Or, you can run the individual benchmarks.
|
| 46 |
+
|
| 47 |
+
## Customizing the benchmarks
|
| 48 |
+
|
| 49 |
+
We define "scenarios" to cover the most common ways in which these models are used. You can
|
| 50 |
+
define a new scenario, modifying an existing benchmark file:
|
| 51 |
+
|
| 52 |
+
```py
|
| 53 |
+
BenchmarkScenario(
|
| 54 |
+
name=f"{CKPT_ID}-bnb-8bit",
|
| 55 |
+
model_cls=FluxTransformer2DModel,
|
| 56 |
+
model_init_kwargs={
|
| 57 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 58 |
+
"torch_dtype": torch.bfloat16,
|
| 59 |
+
"subfolder": "transformer",
|
| 60 |
+
"quantization_config": BitsAndBytesConfig(load_in_8bit=True),
|
| 61 |
+
},
|
| 62 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 63 |
+
model_init_fn=model_init_fn,
|
| 64 |
+
)
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
You can also configure a new model-level benchmark and add it to the existing suite. To do so, just defining a valid benchmarking file like `benchmarking_flux.py` should be enough.
|
| 68 |
+
|
| 69 |
+
Happy benchmarking 🧨
|
diffusers/benchmarks/__init__.py
ADDED
|
File without changes
|
diffusers/benchmarks/benchmarking_flux.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
|
| 5 |
+
|
| 6 |
+
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
|
| 7 |
+
from diffusers.utils.testing_utils import torch_device
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
CKPT_ID = "black-forest-labs/FLUX.1-dev"
|
| 11 |
+
RESULT_FILENAME = "flux.csv"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_input_dict(**device_dtype_kwargs):
|
| 15 |
+
# resolution: 1024x1024
|
| 16 |
+
# maximum sequence length 512
|
| 17 |
+
hidden_states = torch.randn(1, 4096, 64, **device_dtype_kwargs)
|
| 18 |
+
encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
|
| 19 |
+
pooled_prompt_embeds = torch.randn(1, 768, **device_dtype_kwargs)
|
| 20 |
+
image_ids = torch.ones(512, 3, **device_dtype_kwargs)
|
| 21 |
+
text_ids = torch.ones(4096, 3, **device_dtype_kwargs)
|
| 22 |
+
timestep = torch.tensor([1.0], **device_dtype_kwargs)
|
| 23 |
+
guidance = torch.tensor([1.0], **device_dtype_kwargs)
|
| 24 |
+
|
| 25 |
+
return {
|
| 26 |
+
"hidden_states": hidden_states,
|
| 27 |
+
"encoder_hidden_states": encoder_hidden_states,
|
| 28 |
+
"img_ids": image_ids,
|
| 29 |
+
"txt_ids": text_ids,
|
| 30 |
+
"pooled_projections": pooled_prompt_embeds,
|
| 31 |
+
"timestep": timestep,
|
| 32 |
+
"guidance": guidance,
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if __name__ == "__main__":
|
| 37 |
+
scenarios = [
|
| 38 |
+
BenchmarkScenario(
|
| 39 |
+
name=f"{CKPT_ID}-bf16",
|
| 40 |
+
model_cls=FluxTransformer2DModel,
|
| 41 |
+
model_init_kwargs={
|
| 42 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 43 |
+
"torch_dtype": torch.bfloat16,
|
| 44 |
+
"subfolder": "transformer",
|
| 45 |
+
},
|
| 46 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 47 |
+
model_init_fn=model_init_fn,
|
| 48 |
+
compile_kwargs={"fullgraph": True},
|
| 49 |
+
),
|
| 50 |
+
BenchmarkScenario(
|
| 51 |
+
name=f"{CKPT_ID}-bnb-nf4",
|
| 52 |
+
model_cls=FluxTransformer2DModel,
|
| 53 |
+
model_init_kwargs={
|
| 54 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 55 |
+
"torch_dtype": torch.bfloat16,
|
| 56 |
+
"subfolder": "transformer",
|
| 57 |
+
"quantization_config": BitsAndBytesConfig(
|
| 58 |
+
load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4"
|
| 59 |
+
),
|
| 60 |
+
},
|
| 61 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 62 |
+
model_init_fn=model_init_fn,
|
| 63 |
+
),
|
| 64 |
+
BenchmarkScenario(
|
| 65 |
+
name=f"{CKPT_ID}-layerwise-upcasting",
|
| 66 |
+
model_cls=FluxTransformer2DModel,
|
| 67 |
+
model_init_kwargs={
|
| 68 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 69 |
+
"torch_dtype": torch.bfloat16,
|
| 70 |
+
"subfolder": "transformer",
|
| 71 |
+
},
|
| 72 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 73 |
+
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
|
| 74 |
+
),
|
| 75 |
+
BenchmarkScenario(
|
| 76 |
+
name=f"{CKPT_ID}-group-offload-leaf",
|
| 77 |
+
model_cls=FluxTransformer2DModel,
|
| 78 |
+
model_init_kwargs={
|
| 79 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 80 |
+
"torch_dtype": torch.bfloat16,
|
| 81 |
+
"subfolder": "transformer",
|
| 82 |
+
},
|
| 83 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 84 |
+
model_init_fn=partial(
|
| 85 |
+
model_init_fn,
|
| 86 |
+
group_offload_kwargs={
|
| 87 |
+
"onload_device": torch_device,
|
| 88 |
+
"offload_device": torch.device("cpu"),
|
| 89 |
+
"offload_type": "leaf_level",
|
| 90 |
+
"use_stream": True,
|
| 91 |
+
"non_blocking": True,
|
| 92 |
+
},
|
| 93 |
+
),
|
| 94 |
+
),
|
| 95 |
+
]
|
| 96 |
+
|
| 97 |
+
runner = BenchmarkMixin()
|
| 98 |
+
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
|
diffusers/benchmarks/benchmarking_ltx.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
|
| 5 |
+
|
| 6 |
+
from diffusers import LTXVideoTransformer3DModel
|
| 7 |
+
from diffusers.utils.testing_utils import torch_device
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
CKPT_ID = "Lightricks/LTX-Video-0.9.7-dev"
|
| 11 |
+
RESULT_FILENAME = "ltx.csv"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_input_dict(**device_dtype_kwargs):
|
| 15 |
+
# 512x704 (161 frames)
|
| 16 |
+
# `max_sequence_length`: 256
|
| 17 |
+
hidden_states = torch.randn(1, 7392, 128, **device_dtype_kwargs)
|
| 18 |
+
encoder_hidden_states = torch.randn(1, 256, 4096, **device_dtype_kwargs)
|
| 19 |
+
encoder_attention_mask = torch.ones(1, 256, **device_dtype_kwargs)
|
| 20 |
+
timestep = torch.tensor([1.0], **device_dtype_kwargs)
|
| 21 |
+
video_coords = torch.randn(1, 3, 7392, **device_dtype_kwargs)
|
| 22 |
+
|
| 23 |
+
return {
|
| 24 |
+
"hidden_states": hidden_states,
|
| 25 |
+
"encoder_hidden_states": encoder_hidden_states,
|
| 26 |
+
"encoder_attention_mask": encoder_attention_mask,
|
| 27 |
+
"timestep": timestep,
|
| 28 |
+
"video_coords": video_coords,
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == "__main__":
|
| 33 |
+
scenarios = [
|
| 34 |
+
BenchmarkScenario(
|
| 35 |
+
name=f"{CKPT_ID}-bf16",
|
| 36 |
+
model_cls=LTXVideoTransformer3DModel,
|
| 37 |
+
model_init_kwargs={
|
| 38 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 39 |
+
"torch_dtype": torch.bfloat16,
|
| 40 |
+
"subfolder": "transformer",
|
| 41 |
+
},
|
| 42 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 43 |
+
model_init_fn=model_init_fn,
|
| 44 |
+
compile_kwargs={"fullgraph": True},
|
| 45 |
+
),
|
| 46 |
+
BenchmarkScenario(
|
| 47 |
+
name=f"{CKPT_ID}-layerwise-upcasting",
|
| 48 |
+
model_cls=LTXVideoTransformer3DModel,
|
| 49 |
+
model_init_kwargs={
|
| 50 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 51 |
+
"torch_dtype": torch.bfloat16,
|
| 52 |
+
"subfolder": "transformer",
|
| 53 |
+
},
|
| 54 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 55 |
+
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
|
| 56 |
+
),
|
| 57 |
+
BenchmarkScenario(
|
| 58 |
+
name=f"{CKPT_ID}-group-offload-leaf",
|
| 59 |
+
model_cls=LTXVideoTransformer3DModel,
|
| 60 |
+
model_init_kwargs={
|
| 61 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 62 |
+
"torch_dtype": torch.bfloat16,
|
| 63 |
+
"subfolder": "transformer",
|
| 64 |
+
},
|
| 65 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 66 |
+
model_init_fn=partial(
|
| 67 |
+
model_init_fn,
|
| 68 |
+
group_offload_kwargs={
|
| 69 |
+
"onload_device": torch_device,
|
| 70 |
+
"offload_device": torch.device("cpu"),
|
| 71 |
+
"offload_type": "leaf_level",
|
| 72 |
+
"use_stream": True,
|
| 73 |
+
"non_blocking": True,
|
| 74 |
+
},
|
| 75 |
+
),
|
| 76 |
+
),
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
runner = BenchmarkMixin()
|
| 80 |
+
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
|
diffusers/benchmarks/benchmarking_sdxl.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
|
| 5 |
+
|
| 6 |
+
from diffusers import UNet2DConditionModel
|
| 7 |
+
from diffusers.utils.testing_utils import torch_device
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
CKPT_ID = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 11 |
+
RESULT_FILENAME = "sdxl.csv"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_input_dict(**device_dtype_kwargs):
|
| 15 |
+
# height: 1024
|
| 16 |
+
# width: 1024
|
| 17 |
+
# max_sequence_length: 77
|
| 18 |
+
hidden_states = torch.randn(1, 4, 128, 128, **device_dtype_kwargs)
|
| 19 |
+
encoder_hidden_states = torch.randn(1, 77, 2048, **device_dtype_kwargs)
|
| 20 |
+
timestep = torch.tensor([1.0], **device_dtype_kwargs)
|
| 21 |
+
added_cond_kwargs = {
|
| 22 |
+
"text_embeds": torch.randn(1, 1280, **device_dtype_kwargs),
|
| 23 |
+
"time_ids": torch.ones(1, 6, **device_dtype_kwargs),
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
return {
|
| 27 |
+
"sample": hidden_states,
|
| 28 |
+
"encoder_hidden_states": encoder_hidden_states,
|
| 29 |
+
"timestep": timestep,
|
| 30 |
+
"added_cond_kwargs": added_cond_kwargs,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
scenarios = [
|
| 36 |
+
BenchmarkScenario(
|
| 37 |
+
name=f"{CKPT_ID}-bf16",
|
| 38 |
+
model_cls=UNet2DConditionModel,
|
| 39 |
+
model_init_kwargs={
|
| 40 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 41 |
+
"torch_dtype": torch.bfloat16,
|
| 42 |
+
"subfolder": "unet",
|
| 43 |
+
},
|
| 44 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 45 |
+
model_init_fn=model_init_fn,
|
| 46 |
+
compile_kwargs={"fullgraph": True},
|
| 47 |
+
),
|
| 48 |
+
BenchmarkScenario(
|
| 49 |
+
name=f"{CKPT_ID}-layerwise-upcasting",
|
| 50 |
+
model_cls=UNet2DConditionModel,
|
| 51 |
+
model_init_kwargs={
|
| 52 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 53 |
+
"torch_dtype": torch.bfloat16,
|
| 54 |
+
"subfolder": "unet",
|
| 55 |
+
},
|
| 56 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 57 |
+
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
|
| 58 |
+
),
|
| 59 |
+
BenchmarkScenario(
|
| 60 |
+
name=f"{CKPT_ID}-group-offload-leaf",
|
| 61 |
+
model_cls=UNet2DConditionModel,
|
| 62 |
+
model_init_kwargs={
|
| 63 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 64 |
+
"torch_dtype": torch.bfloat16,
|
| 65 |
+
"subfolder": "unet",
|
| 66 |
+
},
|
| 67 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 68 |
+
model_init_fn=partial(
|
| 69 |
+
model_init_fn,
|
| 70 |
+
group_offload_kwargs={
|
| 71 |
+
"onload_device": torch_device,
|
| 72 |
+
"offload_device": torch.device("cpu"),
|
| 73 |
+
"offload_type": "leaf_level",
|
| 74 |
+
"use_stream": True,
|
| 75 |
+
"non_blocking": True,
|
| 76 |
+
},
|
| 77 |
+
),
|
| 78 |
+
),
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
runner = BenchmarkMixin()
|
| 82 |
+
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
|
diffusers/benchmarks/benchmarking_utils.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import inspect
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import queue
|
| 6 |
+
import threading
|
| 7 |
+
from contextlib import nullcontext
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Any, Callable, Dict, Optional, Union
|
| 10 |
+
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import torch
|
| 13 |
+
import torch.utils.benchmark as benchmark
|
| 14 |
+
|
| 15 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 16 |
+
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
NUM_WARMUP_ROUNDS = 5
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def benchmark_fn(f, *args, **kwargs):
|
| 26 |
+
t0 = benchmark.Timer(
|
| 27 |
+
stmt="f(*args, **kwargs)",
|
| 28 |
+
globals={"args": args, "kwargs": kwargs, "f": f},
|
| 29 |
+
num_threads=1,
|
| 30 |
+
)
|
| 31 |
+
return float(f"{(t0.blocked_autorange().mean):.3f}")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def flush():
|
| 35 |
+
gc.collect()
|
| 36 |
+
torch.cuda.empty_cache()
|
| 37 |
+
torch.cuda.reset_max_memory_allocated()
|
| 38 |
+
torch.cuda.reset_peak_memory_stats()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
|
| 42 |
+
def calculate_flops(model, input_dict):
|
| 43 |
+
try:
|
| 44 |
+
from torchprofile import profile_macs
|
| 45 |
+
except ModuleNotFoundError:
|
| 46 |
+
raise
|
| 47 |
+
|
| 48 |
+
# This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs.
|
| 49 |
+
sig = inspect.signature(model.forward)
|
| 50 |
+
param_names = [
|
| 51 |
+
p.name
|
| 52 |
+
for p in sig.parameters.values()
|
| 53 |
+
if p.kind
|
| 54 |
+
in (
|
| 55 |
+
inspect.Parameter.POSITIONAL_ONLY,
|
| 56 |
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
| 57 |
+
)
|
| 58 |
+
and p.name != "self"
|
| 59 |
+
]
|
| 60 |
+
bound = sig.bind_partial(**input_dict)
|
| 61 |
+
bound.apply_defaults()
|
| 62 |
+
args = tuple(bound.arguments[name] for name in param_names)
|
| 63 |
+
|
| 64 |
+
model.eval()
|
| 65 |
+
with torch.no_grad():
|
| 66 |
+
macs = profile_macs(model, args)
|
| 67 |
+
flops = 2 * macs # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)
|
| 68 |
+
return flops
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def calculate_params(model):
|
| 72 |
+
return sum(p.numel() for p in model.parameters())
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Users can define their own in case this doesn't suffice. For most cases,
|
| 76 |
+
# it should be sufficient.
|
| 77 |
+
def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):
|
| 78 |
+
model = model_cls.from_pretrained(**init_kwargs).eval()
|
| 79 |
+
if group_offload_kwargs and isinstance(group_offload_kwargs, dict):
|
| 80 |
+
model.enable_group_offload(**group_offload_kwargs)
|
| 81 |
+
else:
|
| 82 |
+
model.to(torch_device)
|
| 83 |
+
if layerwise_upcasting:
|
| 84 |
+
model.enable_layerwise_casting(
|
| 85 |
+
storage_dtype=torch.float8_e4m3fn, compute_dtype=init_kwargs.get("torch_dtype", torch.bfloat16)
|
| 86 |
+
)
|
| 87 |
+
return model
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@dataclass
|
| 91 |
+
class BenchmarkScenario:
|
| 92 |
+
name: str
|
| 93 |
+
model_cls: ModelMixin
|
| 94 |
+
model_init_kwargs: Dict[str, Any]
|
| 95 |
+
model_init_fn: Callable
|
| 96 |
+
get_model_input_dict: Callable
|
| 97 |
+
compile_kwargs: Optional[Dict[str, Any]] = None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@require_torch_gpu
|
| 101 |
+
class BenchmarkMixin:
|
| 102 |
+
def pre_benchmark(self):
|
| 103 |
+
flush()
|
| 104 |
+
torch.compiler.reset()
|
| 105 |
+
|
| 106 |
+
def post_benchmark(self, model):
|
| 107 |
+
model.cpu()
|
| 108 |
+
flush()
|
| 109 |
+
torch.compiler.reset()
|
| 110 |
+
|
| 111 |
+
@torch.no_grad()
|
| 112 |
+
def run_benchmark(self, scenario: BenchmarkScenario):
|
| 113 |
+
# 0) Basic stats
|
| 114 |
+
logger.info(f"Running scenario: {scenario.name}.")
|
| 115 |
+
try:
|
| 116 |
+
model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
|
| 117 |
+
num_params = round(calculate_params(model) / 1e9, 2)
|
| 118 |
+
try:
|
| 119 |
+
flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e9, 2)
|
| 120 |
+
except Exception as e:
|
| 121 |
+
logger.info(f"Problem in calculating FLOPs:\n{e}")
|
| 122 |
+
flops = None
|
| 123 |
+
model.cpu()
|
| 124 |
+
del model
|
| 125 |
+
except Exception as e:
|
| 126 |
+
logger.info(f"Error while initializing the model and calculating FLOPs:\n{e}")
|
| 127 |
+
return {}
|
| 128 |
+
self.pre_benchmark()
|
| 129 |
+
|
| 130 |
+
# 1) plain stats
|
| 131 |
+
results = {}
|
| 132 |
+
plain = None
|
| 133 |
+
try:
|
| 134 |
+
plain = self._run_phase(
|
| 135 |
+
model_cls=scenario.model_cls,
|
| 136 |
+
init_fn=scenario.model_init_fn,
|
| 137 |
+
init_kwargs=scenario.model_init_kwargs,
|
| 138 |
+
get_input_fn=scenario.get_model_input_dict,
|
| 139 |
+
compile_kwargs=None,
|
| 140 |
+
)
|
| 141 |
+
except Exception as e:
|
| 142 |
+
logger.info(f"Benchmark could not be run with the following error:\n{e}")
|
| 143 |
+
return results
|
| 144 |
+
|
| 145 |
+
# 2) compiled stats (if any)
|
| 146 |
+
compiled = {"time": None, "memory": None}
|
| 147 |
+
if scenario.compile_kwargs:
|
| 148 |
+
try:
|
| 149 |
+
compiled = self._run_phase(
|
| 150 |
+
model_cls=scenario.model_cls,
|
| 151 |
+
init_fn=scenario.model_init_fn,
|
| 152 |
+
init_kwargs=scenario.model_init_kwargs,
|
| 153 |
+
get_input_fn=scenario.get_model_input_dict,
|
| 154 |
+
compile_kwargs=scenario.compile_kwargs,
|
| 155 |
+
)
|
| 156 |
+
except Exception as e:
|
| 157 |
+
logger.info(f"Compilation benchmark could not be run with the following error\n: {e}")
|
| 158 |
+
if plain is None:
|
| 159 |
+
return results
|
| 160 |
+
|
| 161 |
+
# 3) merge
|
| 162 |
+
result = {
|
| 163 |
+
"scenario": scenario.name,
|
| 164 |
+
"model_cls": scenario.model_cls.__name__,
|
| 165 |
+
"num_params_B": num_params,
|
| 166 |
+
"flops_G": flops,
|
| 167 |
+
"time_plain_s": plain["time"],
|
| 168 |
+
"mem_plain_GB": plain["memory"],
|
| 169 |
+
"time_compile_s": compiled["time"],
|
| 170 |
+
"mem_compile_GB": compiled["memory"],
|
| 171 |
+
}
|
| 172 |
+
if scenario.compile_kwargs:
|
| 173 |
+
result["fullgraph"] = scenario.compile_kwargs.get("fullgraph", False)
|
| 174 |
+
result["mode"] = scenario.compile_kwargs.get("mode", "default")
|
| 175 |
+
else:
|
| 176 |
+
result["fullgraph"], result["mode"] = None, None
|
| 177 |
+
return result
|
| 178 |
+
|
| 179 |
+
def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str):
|
| 180 |
+
if not isinstance(scenarios, list):
|
| 181 |
+
scenarios = [scenarios]
|
| 182 |
+
record_queue = queue.Queue()
|
| 183 |
+
stop_signal = object()
|
| 184 |
+
|
| 185 |
+
def _writer_thread():
|
| 186 |
+
while True:
|
| 187 |
+
item = record_queue.get()
|
| 188 |
+
if item is stop_signal:
|
| 189 |
+
break
|
| 190 |
+
df_row = pd.DataFrame([item])
|
| 191 |
+
write_header = not os.path.exists(filename)
|
| 192 |
+
df_row.to_csv(filename, mode="a", header=write_header, index=False)
|
| 193 |
+
record_queue.task_done()
|
| 194 |
+
|
| 195 |
+
record_queue.task_done()
|
| 196 |
+
|
| 197 |
+
writer = threading.Thread(target=_writer_thread, daemon=True)
|
| 198 |
+
writer.start()
|
| 199 |
+
|
| 200 |
+
for s in scenarios:
|
| 201 |
+
try:
|
| 202 |
+
record = self.run_benchmark(s)
|
| 203 |
+
if record:
|
| 204 |
+
record_queue.put(record)
|
| 205 |
+
else:
|
| 206 |
+
logger.info(f"Record empty from scenario: {s.name}.")
|
| 207 |
+
except Exception as e:
|
| 208 |
+
logger.info(f"Running scenario ({s.name}) led to error:\n{e}")
|
| 209 |
+
record_queue.put(stop_signal)
|
| 210 |
+
logger.info(f"Results serialized to {filename=}.")
|
| 211 |
+
|
| 212 |
+
def _run_phase(
|
| 213 |
+
self,
|
| 214 |
+
*,
|
| 215 |
+
model_cls: ModelMixin,
|
| 216 |
+
init_fn: Callable,
|
| 217 |
+
init_kwargs: Dict[str, Any],
|
| 218 |
+
get_input_fn: Callable,
|
| 219 |
+
compile_kwargs: Optional[Dict[str, Any]],
|
| 220 |
+
) -> Dict[str, float]:
|
| 221 |
+
# setup
|
| 222 |
+
self.pre_benchmark()
|
| 223 |
+
|
| 224 |
+
# init & (optional) compile
|
| 225 |
+
model = init_fn(model_cls, **init_kwargs)
|
| 226 |
+
if compile_kwargs:
|
| 227 |
+
model.compile(**compile_kwargs)
|
| 228 |
+
|
| 229 |
+
# build inputs
|
| 230 |
+
inp = get_input_fn()
|
| 231 |
+
|
| 232 |
+
# measure
|
| 233 |
+
run_ctx = torch._inductor.utils.fresh_inductor_cache() if compile_kwargs else nullcontext()
|
| 234 |
+
with run_ctx:
|
| 235 |
+
for _ in range(NUM_WARMUP_ROUNDS):
|
| 236 |
+
_ = model(**inp)
|
| 237 |
+
time_s = benchmark_fn(lambda m, d: m(**d), model, inp)
|
| 238 |
+
mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
| 239 |
+
mem_gb = round(mem_gb, 2)
|
| 240 |
+
|
| 241 |
+
# teardown
|
| 242 |
+
self.post_benchmark(model)
|
| 243 |
+
del model
|
| 244 |
+
return {"time": time_s, "memory": mem_gb}
|
diffusers/benchmarks/benchmarking_wan.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
|
| 5 |
+
|
| 6 |
+
from diffusers import WanTransformer3DModel
|
| 7 |
+
from diffusers.utils.testing_utils import torch_device
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
CKPT_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
|
| 11 |
+
RESULT_FILENAME = "wan.csv"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_input_dict(**device_dtype_kwargs):
|
| 15 |
+
# height: 480
|
| 16 |
+
# width: 832
|
| 17 |
+
# num_frames: 81
|
| 18 |
+
# max_sequence_length: 512
|
| 19 |
+
hidden_states = torch.randn(1, 16, 21, 60, 104, **device_dtype_kwargs)
|
| 20 |
+
encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
|
| 21 |
+
timestep = torch.tensor([1.0], **device_dtype_kwargs)
|
| 22 |
+
|
| 23 |
+
return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if __name__ == "__main__":
|
| 27 |
+
scenarios = [
|
| 28 |
+
BenchmarkScenario(
|
| 29 |
+
name=f"{CKPT_ID}-bf16",
|
| 30 |
+
model_cls=WanTransformer3DModel,
|
| 31 |
+
model_init_kwargs={
|
| 32 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 33 |
+
"torch_dtype": torch.bfloat16,
|
| 34 |
+
"subfolder": "transformer",
|
| 35 |
+
},
|
| 36 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 37 |
+
model_init_fn=model_init_fn,
|
| 38 |
+
compile_kwargs={"fullgraph": True},
|
| 39 |
+
),
|
| 40 |
+
BenchmarkScenario(
|
| 41 |
+
name=f"{CKPT_ID}-layerwise-upcasting",
|
| 42 |
+
model_cls=WanTransformer3DModel,
|
| 43 |
+
model_init_kwargs={
|
| 44 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 45 |
+
"torch_dtype": torch.bfloat16,
|
| 46 |
+
"subfolder": "transformer",
|
| 47 |
+
},
|
| 48 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 49 |
+
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
|
| 50 |
+
),
|
| 51 |
+
BenchmarkScenario(
|
| 52 |
+
name=f"{CKPT_ID}-group-offload-leaf",
|
| 53 |
+
model_cls=WanTransformer3DModel,
|
| 54 |
+
model_init_kwargs={
|
| 55 |
+
"pretrained_model_name_or_path": CKPT_ID,
|
| 56 |
+
"torch_dtype": torch.bfloat16,
|
| 57 |
+
"subfolder": "transformer",
|
| 58 |
+
},
|
| 59 |
+
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
| 60 |
+
model_init_fn=partial(
|
| 61 |
+
model_init_fn,
|
| 62 |
+
group_offload_kwargs={
|
| 63 |
+
"onload_device": torch_device,
|
| 64 |
+
"offload_device": torch.device("cpu"),
|
| 65 |
+
"offload_type": "leaf_level",
|
| 66 |
+
"use_stream": True,
|
| 67 |
+
"non_blocking": True,
|
| 68 |
+
},
|
| 69 |
+
),
|
| 70 |
+
),
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
runner = BenchmarkMixin()
|
| 74 |
+
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
|
diffusers/benchmarks/populate_into_db.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import gpustat
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import psycopg2
|
| 8 |
+
import psycopg2.extras
|
| 9 |
+
from psycopg2.extensions import register_adapter
|
| 10 |
+
from psycopg2.extras import Json
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
register_adapter(dict, Json)
|
| 14 |
+
|
| 15 |
+
FINAL_CSV_FILENAME = "collated_results.csv"
|
| 16 |
+
# https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27
|
| 17 |
+
BENCHMARKS_TABLE_NAME = "benchmarks"
|
| 18 |
+
MEASUREMENTS_TABLE_NAME = "model_measurements"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _init_benchmark(conn, branch, commit_id, commit_msg):
|
| 22 |
+
gpu_stats = gpustat.GPUStatCollection.new_query()
|
| 23 |
+
metadata = {"gpu_name": gpu_stats[0]["name"]}
|
| 24 |
+
repository = "huggingface/diffusers"
|
| 25 |
+
with conn.cursor() as cur:
|
| 26 |
+
cur.execute(
|
| 27 |
+
f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
|
| 28 |
+
(repository, branch, commit_id, commit_msg, metadata),
|
| 29 |
+
)
|
| 30 |
+
benchmark_id = cur.fetchone()[0]
|
| 31 |
+
print(f"Initialised benchmark #{benchmark_id}")
|
| 32 |
+
return benchmark_id
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def parse_args():
|
| 36 |
+
parser = argparse.ArgumentParser()
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"branch",
|
| 39 |
+
type=str,
|
| 40 |
+
help="The branch name on which the benchmarking is performed.",
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"commit_id",
|
| 45 |
+
type=str,
|
| 46 |
+
help="The commit hash on which the benchmarking is performed.",
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"commit_msg",
|
| 51 |
+
type=str,
|
| 52 |
+
help="The commit message associated with the commit, truncated to 70 characters.",
|
| 53 |
+
)
|
| 54 |
+
args = parser.parse_args()
|
| 55 |
+
return args
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
args = parse_args()
|
| 60 |
+
try:
|
| 61 |
+
conn = psycopg2.connect(
|
| 62 |
+
host=os.getenv("PGHOST"),
|
| 63 |
+
database=os.getenv("PGDATABASE"),
|
| 64 |
+
user=os.getenv("PGUSER"),
|
| 65 |
+
password=os.getenv("PGPASSWORD"),
|
| 66 |
+
)
|
| 67 |
+
print("DB connection established successfully.")
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print(f"Problem during DB init: {e}")
|
| 70 |
+
sys.exit(1)
|
| 71 |
+
|
| 72 |
+
try:
|
| 73 |
+
benchmark_id = _init_benchmark(
|
| 74 |
+
conn=conn,
|
| 75 |
+
branch=args.branch,
|
| 76 |
+
commit_id=args.commit_id,
|
| 77 |
+
commit_msg=args.commit_msg,
|
| 78 |
+
)
|
| 79 |
+
except Exception as e:
|
| 80 |
+
print(f"Problem during initializing benchmark: {e}")
|
| 81 |
+
sys.exit(1)
|
| 82 |
+
|
| 83 |
+
cur = conn.cursor()
|
| 84 |
+
|
| 85 |
+
df = pd.read_csv(FINAL_CSV_FILENAME)
|
| 86 |
+
|
| 87 |
+
# Helper to cast values (or None) given a dtype
|
| 88 |
+
def _cast_value(val, dtype: str):
|
| 89 |
+
if pd.isna(val):
|
| 90 |
+
return None
|
| 91 |
+
|
| 92 |
+
if dtype == "text":
|
| 93 |
+
return str(val).strip()
|
| 94 |
+
|
| 95 |
+
if dtype == "float":
|
| 96 |
+
try:
|
| 97 |
+
return float(val)
|
| 98 |
+
except ValueError:
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
if dtype == "bool":
|
| 102 |
+
s = str(val).strip().lower()
|
| 103 |
+
if s in ("true", "t", "yes", "1"):
|
| 104 |
+
return True
|
| 105 |
+
if s in ("false", "f", "no", "0"):
|
| 106 |
+
return False
|
| 107 |
+
if val in (1, 1.0):
|
| 108 |
+
return True
|
| 109 |
+
if val in (0, 0.0):
|
| 110 |
+
return False
|
| 111 |
+
return None
|
| 112 |
+
|
| 113 |
+
return val
|
| 114 |
+
|
| 115 |
+
try:
|
| 116 |
+
rows_to_insert = []
|
| 117 |
+
for _, row in df.iterrows():
|
| 118 |
+
scenario = _cast_value(row.get("scenario"), "text")
|
| 119 |
+
model_cls = _cast_value(row.get("model_cls"), "text")
|
| 120 |
+
num_params_B = _cast_value(row.get("num_params_B"), "float")
|
| 121 |
+
flops_G = _cast_value(row.get("flops_G"), "float")
|
| 122 |
+
time_plain_s = _cast_value(row.get("time_plain_s"), "float")
|
| 123 |
+
mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float")
|
| 124 |
+
time_compile_s = _cast_value(row.get("time_compile_s"), "float")
|
| 125 |
+
mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float")
|
| 126 |
+
fullgraph = _cast_value(row.get("fullgraph"), "bool")
|
| 127 |
+
mode = _cast_value(row.get("mode"), "text")
|
| 128 |
+
|
| 129 |
+
# If "github_sha" column exists in the CSV, cast it; else default to None
|
| 130 |
+
if "github_sha" in df.columns:
|
| 131 |
+
github_sha = _cast_value(row.get("github_sha"), "text")
|
| 132 |
+
else:
|
| 133 |
+
github_sha = None
|
| 134 |
+
|
| 135 |
+
measurements = {
|
| 136 |
+
"scenario": scenario,
|
| 137 |
+
"model_cls": model_cls,
|
| 138 |
+
"num_params_B": num_params_B,
|
| 139 |
+
"flops_G": flops_G,
|
| 140 |
+
"time_plain_s": time_plain_s,
|
| 141 |
+
"mem_plain_GB": mem_plain_GB,
|
| 142 |
+
"time_compile_s": time_compile_s,
|
| 143 |
+
"mem_compile_GB": mem_compile_GB,
|
| 144 |
+
"fullgraph": fullgraph,
|
| 145 |
+
"mode": mode,
|
| 146 |
+
"github_sha": github_sha,
|
| 147 |
+
}
|
| 148 |
+
rows_to_insert.append((benchmark_id, measurements))
|
| 149 |
+
|
| 150 |
+
# Batch-insert all rows
|
| 151 |
+
insert_sql = f"""
|
| 152 |
+
INSERT INTO {MEASUREMENTS_TABLE_NAME} (
|
| 153 |
+
benchmark_id,
|
| 154 |
+
measurements
|
| 155 |
+
)
|
| 156 |
+
VALUES (%s, %s);
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert)
|
| 160 |
+
conn.commit()
|
| 161 |
+
|
| 162 |
+
cur.close()
|
| 163 |
+
conn.close()
|
| 164 |
+
except Exception as e:
|
| 165 |
+
print(f"Exception: {e}")
|
| 166 |
+
sys.exit(1)
|
diffusers/benchmarks/push_results.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from huggingface_hub import hf_hub_download, upload_file
|
| 5 |
+
from huggingface_hub.utils import EntryNotFoundError
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
REPO_ID = "diffusers/benchmarks"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def has_previous_benchmark() -> str:
|
| 12 |
+
from run_all import FINAL_CSV_FILENAME
|
| 13 |
+
|
| 14 |
+
csv_path = None
|
| 15 |
+
try:
|
| 16 |
+
csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILENAME)
|
| 17 |
+
except EntryNotFoundError:
|
| 18 |
+
csv_path = None
|
| 19 |
+
return csv_path
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def filter_float(value):
|
| 23 |
+
if isinstance(value, str):
|
| 24 |
+
return float(value.split()[0])
|
| 25 |
+
return value
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def push_to_hf_dataset():
|
| 29 |
+
from run_all import FINAL_CSV_FILENAME, GITHUB_SHA
|
| 30 |
+
|
| 31 |
+
csv_path = has_previous_benchmark()
|
| 32 |
+
if csv_path is not None:
|
| 33 |
+
current_results = pd.read_csv(FINAL_CSV_FILENAME)
|
| 34 |
+
previous_results = pd.read_csv(csv_path)
|
| 35 |
+
|
| 36 |
+
numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns
|
| 37 |
+
|
| 38 |
+
for column in numeric_columns:
|
| 39 |
+
# get previous values as floats, aligned to current index
|
| 40 |
+
prev_vals = previous_results[column].map(filter_float).reindex(current_results.index)
|
| 41 |
+
|
| 42 |
+
# get current values as floats
|
| 43 |
+
curr_vals = current_results[column].astype(float)
|
| 44 |
+
|
| 45 |
+
# stringify the current values
|
| 46 |
+
curr_str = curr_vals.map(str)
|
| 47 |
+
|
| 48 |
+
# build an appendage only when prev exists and differs
|
| 49 |
+
append_str = prev_vals.where(prev_vals.notnull() & (prev_vals != curr_vals), other=pd.NA).map(
|
| 50 |
+
lambda x: f" ({x})" if pd.notnull(x) else ""
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# combine
|
| 54 |
+
current_results[column] = curr_str + append_str
|
| 55 |
+
os.remove(FINAL_CSV_FILENAME)
|
| 56 |
+
current_results.to_csv(FINAL_CSV_FILENAME, index=False)
|
| 57 |
+
|
| 58 |
+
commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"
|
| 59 |
+
upload_file(
|
| 60 |
+
repo_id=REPO_ID,
|
| 61 |
+
path_in_repo=FINAL_CSV_FILENAME,
|
| 62 |
+
path_or_fileobj=FINAL_CSV_FILENAME,
|
| 63 |
+
repo_type="dataset",
|
| 64 |
+
commit_message=commit_message,
|
| 65 |
+
)
|
| 66 |
+
upload_file(
|
| 67 |
+
repo_id="diffusers/benchmark-analyzer",
|
| 68 |
+
path_in_repo=FINAL_CSV_FILENAME,
|
| 69 |
+
path_or_fileobj=FINAL_CSV_FILENAME,
|
| 70 |
+
repo_type="space",
|
| 71 |
+
commit_message=commit_message,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
|
| 76 |
+
push_to_hf_dataset()
|
diffusers/benchmarks/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas
|
| 2 |
+
psutil
|
| 3 |
+
gpustat
|
| 4 |
+
torchprofile
|
| 5 |
+
bitsandbytes
|
| 6 |
+
psycopg2==2.9.9
|
diffusers/benchmarks/run_all.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import subprocess
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
PATTERN = "benchmarking_*.py"
|
| 13 |
+
FINAL_CSV_FILENAME = "collated_results.csv"
|
| 14 |
+
GITHUB_SHA = os.getenv("GITHUB_SHA", None)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SubprocessCallException(Exception):
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def run_command(command: list[str], return_stdout=False):
|
| 22 |
+
try:
|
| 23 |
+
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
|
| 24 |
+
if return_stdout and hasattr(output, "decode"):
|
| 25 |
+
return output.decode("utf-8")
|
| 26 |
+
except subprocess.CalledProcessError as e:
|
| 27 |
+
raise SubprocessCallException(f"Command `{' '.join(command)}` failed with:\n{e.output.decode()}") from e
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def merge_csvs(final_csv: str = "collated_results.csv"):
|
| 31 |
+
all_csvs = glob.glob("*.csv")
|
| 32 |
+
all_csvs = [f for f in all_csvs if f != final_csv]
|
| 33 |
+
if not all_csvs:
|
| 34 |
+
logger.info("No result CSVs found to merge.")
|
| 35 |
+
return
|
| 36 |
+
|
| 37 |
+
df_list = []
|
| 38 |
+
for f in all_csvs:
|
| 39 |
+
try:
|
| 40 |
+
d = pd.read_csv(f)
|
| 41 |
+
except pd.errors.EmptyDataError:
|
| 42 |
+
# If a file existed but was zero‐bytes or corrupted, skip it
|
| 43 |
+
continue
|
| 44 |
+
df_list.append(d)
|
| 45 |
+
|
| 46 |
+
if not df_list:
|
| 47 |
+
logger.info("All result CSVs were empty or invalid; nothing to merge.")
|
| 48 |
+
return
|
| 49 |
+
|
| 50 |
+
final_df = pd.concat(df_list, ignore_index=True)
|
| 51 |
+
if GITHUB_SHA is not None:
|
| 52 |
+
final_df["github_sha"] = GITHUB_SHA
|
| 53 |
+
final_df.to_csv(final_csv, index=False)
|
| 54 |
+
logger.info(f"Merged {len(all_csvs)} partial CSVs → {final_csv}.")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def run_scripts():
|
| 58 |
+
python_files = sorted(glob.glob(PATTERN))
|
| 59 |
+
python_files = [f for f in python_files if f != "benchmarking_utils.py"]
|
| 60 |
+
|
| 61 |
+
for file in python_files:
|
| 62 |
+
script_name = file.split(".py")[0].split("_")[-1] # example: benchmarking_foo.py -> foo
|
| 63 |
+
logger.info(f"\n****** Running file: {file} ******")
|
| 64 |
+
|
| 65 |
+
partial_csv = f"{script_name}.csv"
|
| 66 |
+
if os.path.exists(partial_csv):
|
| 67 |
+
logger.info(f"Found {partial_csv}. Removing for safer numbers and duplication.")
|
| 68 |
+
os.remove(partial_csv)
|
| 69 |
+
|
| 70 |
+
command = ["python", file]
|
| 71 |
+
try:
|
| 72 |
+
run_command(command)
|
| 73 |
+
logger.info(f"→ {file} finished normally.")
|
| 74 |
+
except SubprocessCallException as e:
|
| 75 |
+
logger.info(f"Error running {file}:\n{e}")
|
| 76 |
+
finally:
|
| 77 |
+
logger.info(f"→ Merging partial CSVs after {file} …")
|
| 78 |
+
merge_csvs(final_csv=FINAL_CSV_FILENAME)
|
| 79 |
+
|
| 80 |
+
logger.info(f"\nAll scripts attempted. Final collated CSV: {FINAL_CSV_FILENAME}")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
|
| 84 |
+
run_scripts()
|
diffusers/examples/README.md
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!---
|
| 2 |
+
Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
you may not use this file except in compliance with the License.
|
| 5 |
+
You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
See the License for the specific language governing permissions and
|
| 13 |
+
limitations under the License.
|
| 14 |
+
-->
|
| 15 |
+
|
| 16 |
+
# 🧨 Diffusers Examples
|
| 17 |
+
|
| 18 |
+
Diffusers examples are a collection of scripts to demonstrate how to effectively use the `diffusers` library
|
| 19 |
+
for a variety of use cases involving training or fine-tuning.
|
| 20 |
+
|
| 21 |
+
**Note**: If you are looking for **official** examples on how to use `diffusers` for inference, please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
|
| 22 |
+
|
| 23 |
+
Our examples aspire to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.
|
| 24 |
+
More specifically, this means:
|
| 25 |
+
|
| 26 |
+
- **Self-contained**: An example script shall only depend on "pip-install-able" Python packages that can be found in a `requirements.txt` file. Example scripts shall **not** depend on any local files. This means that one can simply download an example script, *e.g.* [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py), install the required dependencies, *e.g.* [requirements.txt](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/requirements.txt) and execute the example script.
|
| 27 |
+
- **Easy-to-tweak**: While we strive to present as many use cases as possible, the example scripts are just that - examples. It is expected that they won't work out-of-the box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data and the training loop to allow you to tweak and edit them as required.
|
| 28 |
+
- **Beginner-friendly**: We do not aim for providing state-of-the-art training scripts for the newest models, but rather examples that can be used as a way to better understand diffusion models and how to use them with the `diffusers` library. We often purposefully leave out certain state-of-the-art methods if we consider them too complex for beginners.
|
| 29 |
+
- **One-purpose-only**: Examples should show one task and one task only. Even if a task is from a modeling point of view very similar, *e.g.* image super-resolution and image modification tend to use the same model and training method, we want examples to showcase only one task to keep them as readable and easy-to-understand as possible.
|
| 30 |
+
|
| 31 |
+
We provide **official** examples that cover the most popular tasks of diffusion models.
|
| 32 |
+
*Official* examples are **actively** maintained by the `diffusers` maintainers and we try to rigorously follow our example philosophy as defined above.
|
| 33 |
+
If you feel like another important example should exist, we are more than happy to welcome a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) or directly a [Pull Request](https://github.com/huggingface/diffusers/compare) from you!
|
| 34 |
+
|
| 35 |
+
Training examples show how to pretrain or fine-tune diffusion models for a variety of tasks. Currently we support:
|
| 36 |
+
|
| 37 |
+
| Task | 🤗 Accelerate | 🤗 Datasets | Colab
|
| 38 |
+
|---|---|:---:|:---:|
|
| 39 |
+
| [**Unconditional Image Generation**](./unconditional_image_generation) | ✅ | ✅ | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
| 40 |
+
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
|
| 41 |
+
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
|
| 42 |
+
| [**Dreambooth**](./dreambooth) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
|
| 43 |
+
| [**ControlNet**](./controlnet) | ✅ | ✅ | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
|
| 44 |
+
| [**InstructPix2Pix**](./instruct_pix2pix) | ✅ | ✅ | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/InstructPix2Pix_using_diffusers.ipynb)
|
| 45 |
+
| [**Reinforcement Learning for Control**](./reinforcement_learning) | - | - | [Notebook1](https://github.com/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_for_control.ipynb), [Notebook2](https://github.com/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb)
|
| 46 |
+
|
| 47 |
+
## Community
|
| 48 |
+
|
| 49 |
+
In addition, we provide **community** examples, which are examples added and maintained by our community.
|
| 50 |
+
Community examples can consist of both *training* examples or *inference* pipelines.
|
| 51 |
+
For such examples, we are more lenient regarding the philosophy defined above and also cannot guarantee to provide maintenance for every issue.
|
| 52 |
+
Examples that are useful for the community, but are either not yet deemed popular or not yet following our above philosophy should go into the [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) folder. The community folder therefore includes training examples and inference pipelines.
|
| 53 |
+
**Note**: Community examples can be a [great first contribution](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) to show to the community how you like to use `diffusers` 🪄.
|
| 54 |
+
|
| 55 |
+
## Research Projects
|
| 56 |
+
|
| 57 |
+
We also provide **research_projects** examples that are maintained by the community as defined in the respective research project folders. These examples are useful and offer the extended capabilities which are complementary to the official examples. You may refer to [research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) for details.
|
| 58 |
+
|
| 59 |
+
## Important note
|
| 60 |
+
|
| 61 |
+
To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
| 62 |
+
```bash
|
| 63 |
+
git clone https://github.com/huggingface/diffusers
|
| 64 |
+
cd diffusers
|
| 65 |
+
pip install .
|
| 66 |
+
```
|
| 67 |
+
Then cd in the example folder of your choice and run
|
| 68 |
+
```bash
|
| 69 |
+
pip install -r requirements.txt
|
| 70 |
+
```
|
diffusers/examples/conftest.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# tests directory-specific settings - this file is run automatically
|
| 16 |
+
# by pytest before any tests are run
|
| 17 |
+
|
| 18 |
+
import sys
|
| 19 |
+
import warnings
|
| 20 |
+
from os.path import abspath, dirname, join
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# allow having multiple repository checkouts and not needing to remember to rerun
|
| 24 |
+
# 'pip install -e .[dev]' when switching between checkouts and running tests.
|
| 25 |
+
git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
|
| 26 |
+
sys.path.insert(1, git_repo_path)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# silence FutureWarning warnings in tests since often we can't act on them until
|
| 30 |
+
# they become normal warnings - i.e. the tests still need to test the current functionality
|
| 31 |
+
warnings.simplefilter(action="ignore", category=FutureWarning)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def pytest_addoption(parser):
|
| 35 |
+
from diffusers.utils.testing_utils import pytest_addoption_shared
|
| 36 |
+
|
| 37 |
+
pytest_addoption_shared(parser)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def pytest_terminal_summary(terminalreporter):
|
| 41 |
+
from diffusers.utils.testing_utils import pytest_terminal_summary_main
|
| 42 |
+
|
| 43 |
+
make_reports = terminalreporter.config.getoption("--make-reports")
|
| 44 |
+
if make_reports:
|
| 45 |
+
pytest_terminal_summary_main(terminalreporter, id=make_reports)
|
diffusers/examples/test_examples_utils.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 HuggingFace Inc.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import shutil
|
| 18 |
+
import subprocess
|
| 19 |
+
import tempfile
|
| 20 |
+
import unittest
|
| 21 |
+
from typing import List
|
| 22 |
+
|
| 23 |
+
from accelerate.utils import write_basic_config
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# These utils relate to ensuring the right error message is received when running scripts
|
| 27 |
+
class SubprocessCallException(Exception):
|
| 28 |
+
pass
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def run_command(command: List[str], return_stdout=False):
|
| 32 |
+
"""
|
| 33 |
+
Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
|
| 34 |
+
if an error occurred while running `command`
|
| 35 |
+
"""
|
| 36 |
+
try:
|
| 37 |
+
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
|
| 38 |
+
if return_stdout:
|
| 39 |
+
if hasattr(output, "decode"):
|
| 40 |
+
output = output.decode("utf-8")
|
| 41 |
+
return output
|
| 42 |
+
except subprocess.CalledProcessError as e:
|
| 43 |
+
raise SubprocessCallException(
|
| 44 |
+
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
|
| 45 |
+
) from e
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ExamplesTestsAccelerate(unittest.TestCase):
|
| 49 |
+
@classmethod
|
| 50 |
+
def setUpClass(cls):
|
| 51 |
+
super().setUpClass()
|
| 52 |
+
cls._tmpdir = tempfile.mkdtemp()
|
| 53 |
+
cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
|
| 54 |
+
|
| 55 |
+
write_basic_config(save_location=cls.configPath)
|
| 56 |
+
cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
|
| 57 |
+
|
| 58 |
+
@classmethod
|
| 59 |
+
def tearDownClass(cls):
|
| 60 |
+
super().tearDownClass()
|
| 61 |
+
shutil.rmtree(cls._tmpdir)
|
diffusers/scripts/__init__.py
ADDED
|
File without changes
|
diffusers/scripts/change_naming_configs_and_checkpoints.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Conversion script for the LDM checkpoints."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from transformers.file_utils import has_file
|
| 23 |
+
|
| 24 |
+
from diffusers import UNet2DConditionModel, UNet2DModel
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
do_only_config = False
|
| 28 |
+
do_only_weights = True
|
| 29 |
+
do_only_renaming = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == "__main__":
|
| 33 |
+
parser = argparse.ArgumentParser()
|
| 34 |
+
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--repo_path",
|
| 37 |
+
default=None,
|
| 38 |
+
type=str,
|
| 39 |
+
required=True,
|
| 40 |
+
help="The config json file corresponding to the architecture.",
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 44 |
+
|
| 45 |
+
args = parser.parse_args()
|
| 46 |
+
|
| 47 |
+
config_parameters_to_change = {
|
| 48 |
+
"image_size": "sample_size",
|
| 49 |
+
"num_res_blocks": "layers_per_block",
|
| 50 |
+
"block_channels": "block_out_channels",
|
| 51 |
+
"down_blocks": "down_block_types",
|
| 52 |
+
"up_blocks": "up_block_types",
|
| 53 |
+
"downscale_freq_shift": "freq_shift",
|
| 54 |
+
"resnet_num_groups": "norm_num_groups",
|
| 55 |
+
"resnet_act_fn": "act_fn",
|
| 56 |
+
"resnet_eps": "norm_eps",
|
| 57 |
+
"num_head_channels": "attention_head_dim",
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
key_parameters_to_change = {
|
| 61 |
+
"time_steps": "time_proj",
|
| 62 |
+
"mid": "mid_block",
|
| 63 |
+
"downsample_blocks": "down_blocks",
|
| 64 |
+
"upsample_blocks": "up_blocks",
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
subfolder = "" if has_file(args.repo_path, "config.json") else "unet"
|
| 68 |
+
|
| 69 |
+
with open(os.path.join(args.repo_path, subfolder, "config.json"), "r", encoding="utf-8") as reader:
|
| 70 |
+
text = reader.read()
|
| 71 |
+
config = json.loads(text)
|
| 72 |
+
|
| 73 |
+
if do_only_config:
|
| 74 |
+
for key in config_parameters_to_change.keys():
|
| 75 |
+
config.pop(key, None)
|
| 76 |
+
|
| 77 |
+
if has_file(args.repo_path, "config.json"):
|
| 78 |
+
model = UNet2DModel(**config)
|
| 79 |
+
else:
|
| 80 |
+
class_name = UNet2DConditionModel if "ldm-text2im-large-256" in args.repo_path else UNet2DModel
|
| 81 |
+
model = class_name(**config)
|
| 82 |
+
|
| 83 |
+
if do_only_config:
|
| 84 |
+
model.save_config(os.path.join(args.repo_path, subfolder))
|
| 85 |
+
|
| 86 |
+
config = dict(model.config)
|
| 87 |
+
|
| 88 |
+
if do_only_renaming:
|
| 89 |
+
for key, value in config_parameters_to_change.items():
|
| 90 |
+
if key in config:
|
| 91 |
+
config[value] = config[key]
|
| 92 |
+
del config[key]
|
| 93 |
+
|
| 94 |
+
config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]]
|
| 95 |
+
config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]]
|
| 96 |
+
|
| 97 |
+
if do_only_weights:
|
| 98 |
+
state_dict = torch.load(os.path.join(args.repo_path, subfolder, "diffusion_pytorch_model.bin"))
|
| 99 |
+
|
| 100 |
+
new_state_dict = {}
|
| 101 |
+
for param_key, param_value in state_dict.items():
|
| 102 |
+
if param_key.endswith(".op.bias") or param_key.endswith(".op.weight"):
|
| 103 |
+
continue
|
| 104 |
+
has_changed = False
|
| 105 |
+
for key, new_key in key_parameters_to_change.items():
|
| 106 |
+
if not has_changed and param_key.split(".")[0] == key:
|
| 107 |
+
new_state_dict[".".join([new_key] + param_key.split(".")[1:])] = param_value
|
| 108 |
+
has_changed = True
|
| 109 |
+
if not has_changed:
|
| 110 |
+
new_state_dict[param_key] = param_value
|
| 111 |
+
|
| 112 |
+
model.load_state_dict(new_state_dict)
|
| 113 |
+
model.save_pretrained(os.path.join(args.repo_path, subfolder))
|
diffusers/scripts/convert_amused.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import os
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from muse import MaskGiTUViT, VQGANModel
|
| 8 |
+
from muse import PipelineMuse as OldPipelineMuse
|
| 9 |
+
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
| 10 |
+
|
| 11 |
+
from diffusers import VQModel
|
| 12 |
+
from diffusers.models.attention_processor import AttnProcessor
|
| 13 |
+
from diffusers.models.unets.uvit_2d import UVit2DModel
|
| 14 |
+
from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
|
| 15 |
+
from diffusers.schedulers import AmusedScheduler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
torch.backends.cuda.enable_flash_sdp(False)
|
| 19 |
+
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
| 20 |
+
torch.backends.cuda.enable_math_sdp(True)
|
| 21 |
+
|
| 22 |
+
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
| 23 |
+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
| 24 |
+
torch.use_deterministic_algorithms(True)
|
| 25 |
+
|
| 26 |
+
# Enable CUDNN deterministic mode
|
| 27 |
+
torch.backends.cudnn.deterministic = True
|
| 28 |
+
torch.backends.cudnn.benchmark = False
|
| 29 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
| 30 |
+
|
| 31 |
+
device = "cuda"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def main():
|
| 35 |
+
args = ArgumentParser()
|
| 36 |
+
args.add_argument("--model_256", action="store_true")
|
| 37 |
+
args.add_argument("--write_to", type=str, required=False, default=None)
|
| 38 |
+
args.add_argument("--transformer_path", type=str, required=False, default=None)
|
| 39 |
+
args = args.parse_args()
|
| 40 |
+
|
| 41 |
+
transformer_path = args.transformer_path
|
| 42 |
+
subfolder = "transformer"
|
| 43 |
+
|
| 44 |
+
if transformer_path is None:
|
| 45 |
+
if args.model_256:
|
| 46 |
+
transformer_path = "openMUSE/muse-256"
|
| 47 |
+
else:
|
| 48 |
+
transformer_path = (
|
| 49 |
+
"../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/"
|
| 50 |
+
)
|
| 51 |
+
subfolder = None
|
| 52 |
+
|
| 53 |
+
old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder)
|
| 54 |
+
|
| 55 |
+
old_transformer.to(device)
|
| 56 |
+
|
| 57 |
+
old_vae = VQGANModel.from_pretrained("openMUSE/muse-512", subfolder="vae")
|
| 58 |
+
old_vae.to(device)
|
| 59 |
+
|
| 60 |
+
vqvae = make_vqvae(old_vae)
|
| 61 |
+
|
| 62 |
+
tokenizer = CLIPTokenizer.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
|
| 63 |
+
|
| 64 |
+
text_encoder = CLIPTextModelWithProjection.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
|
| 65 |
+
text_encoder.to(device)
|
| 66 |
+
|
| 67 |
+
transformer = make_transformer(old_transformer, args.model_256)
|
| 68 |
+
|
| 69 |
+
scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id)
|
| 70 |
+
|
| 71 |
+
new_pipe = AmusedPipeline(
|
| 72 |
+
vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
old_pipe = OldPipelineMuse(
|
| 76 |
+
vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer
|
| 77 |
+
)
|
| 78 |
+
old_pipe.to(device)
|
| 79 |
+
|
| 80 |
+
if args.model_256:
|
| 81 |
+
transformer_seq_len = 256
|
| 82 |
+
orig_size = (256, 256)
|
| 83 |
+
else:
|
| 84 |
+
transformer_seq_len = 1024
|
| 85 |
+
orig_size = (512, 512)
|
| 86 |
+
|
| 87 |
+
old_out = old_pipe(
|
| 88 |
+
"dog",
|
| 89 |
+
generator=torch.Generator(device).manual_seed(0),
|
| 90 |
+
transformer_seq_len=transformer_seq_len,
|
| 91 |
+
orig_size=orig_size,
|
| 92 |
+
timesteps=12,
|
| 93 |
+
)[0]
|
| 94 |
+
|
| 95 |
+
new_out = new_pipe("dog", generator=torch.Generator(device).manual_seed(0)).images[0]
|
| 96 |
+
|
| 97 |
+
old_out = np.array(old_out)
|
| 98 |
+
new_out = np.array(new_out)
|
| 99 |
+
|
| 100 |
+
diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64))
|
| 101 |
+
|
| 102 |
+
# assert diff diff.sum() == 0
|
| 103 |
+
print("skipping pipeline full equivalence check")
|
| 104 |
+
|
| 105 |
+
print(f"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}")
|
| 106 |
+
|
| 107 |
+
if args.model_256:
|
| 108 |
+
assert diff.max() <= 3
|
| 109 |
+
assert diff.sum() / diff.size < 0.7
|
| 110 |
+
else:
|
| 111 |
+
assert diff.max() <= 1
|
| 112 |
+
assert diff.sum() / diff.size < 0.4
|
| 113 |
+
|
| 114 |
+
if args.write_to is not None:
|
| 115 |
+
new_pipe.save_pretrained(args.write_to)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def make_transformer(old_transformer, model_256):
|
| 119 |
+
args = dict(old_transformer.config)
|
| 120 |
+
force_down_up_sample = args["force_down_up_sample"]
|
| 121 |
+
|
| 122 |
+
signature = inspect.signature(UVit2DModel.__init__)
|
| 123 |
+
|
| 124 |
+
args_ = {
|
| 125 |
+
"downsample": force_down_up_sample,
|
| 126 |
+
"upsample": force_down_up_sample,
|
| 127 |
+
"block_out_channels": args["block_out_channels"][0],
|
| 128 |
+
"sample_size": 16 if model_256 else 32,
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
for s in list(signature.parameters.keys()):
|
| 132 |
+
if s in ["self", "downsample", "upsample", "sample_size", "block_out_channels"]:
|
| 133 |
+
continue
|
| 134 |
+
|
| 135 |
+
args_[s] = args[s]
|
| 136 |
+
|
| 137 |
+
new_transformer = UVit2DModel(**args_)
|
| 138 |
+
new_transformer.to(device)
|
| 139 |
+
|
| 140 |
+
new_transformer.set_attn_processor(AttnProcessor())
|
| 141 |
+
|
| 142 |
+
state_dict = old_transformer.state_dict()
|
| 143 |
+
|
| 144 |
+
state_dict["cond_embed.linear_1.weight"] = state_dict.pop("cond_embed.0.weight")
|
| 145 |
+
state_dict["cond_embed.linear_2.weight"] = state_dict.pop("cond_embed.2.weight")
|
| 146 |
+
|
| 147 |
+
for i in range(22):
|
| 148 |
+
state_dict[f"transformer_layers.{i}.norm1.norm.weight"] = state_dict.pop(
|
| 149 |
+
f"transformer_layers.{i}.attn_layer_norm.weight"
|
| 150 |
+
)
|
| 151 |
+
state_dict[f"transformer_layers.{i}.norm1.linear.weight"] = state_dict.pop(
|
| 152 |
+
f"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight"
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
state_dict[f"transformer_layers.{i}.attn1.to_q.weight"] = state_dict.pop(
|
| 156 |
+
f"transformer_layers.{i}.attention.query.weight"
|
| 157 |
+
)
|
| 158 |
+
state_dict[f"transformer_layers.{i}.attn1.to_k.weight"] = state_dict.pop(
|
| 159 |
+
f"transformer_layers.{i}.attention.key.weight"
|
| 160 |
+
)
|
| 161 |
+
state_dict[f"transformer_layers.{i}.attn1.to_v.weight"] = state_dict.pop(
|
| 162 |
+
f"transformer_layers.{i}.attention.value.weight"
|
| 163 |
+
)
|
| 164 |
+
state_dict[f"transformer_layers.{i}.attn1.to_out.0.weight"] = state_dict.pop(
|
| 165 |
+
f"transformer_layers.{i}.attention.out.weight"
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
state_dict[f"transformer_layers.{i}.norm2.norm.weight"] = state_dict.pop(
|
| 169 |
+
f"transformer_layers.{i}.crossattn_layer_norm.weight"
|
| 170 |
+
)
|
| 171 |
+
state_dict[f"transformer_layers.{i}.norm2.linear.weight"] = state_dict.pop(
|
| 172 |
+
f"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight"
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
state_dict[f"transformer_layers.{i}.attn2.to_q.weight"] = state_dict.pop(
|
| 176 |
+
f"transformer_layers.{i}.crossattention.query.weight"
|
| 177 |
+
)
|
| 178 |
+
state_dict[f"transformer_layers.{i}.attn2.to_k.weight"] = state_dict.pop(
|
| 179 |
+
f"transformer_layers.{i}.crossattention.key.weight"
|
| 180 |
+
)
|
| 181 |
+
state_dict[f"transformer_layers.{i}.attn2.to_v.weight"] = state_dict.pop(
|
| 182 |
+
f"transformer_layers.{i}.crossattention.value.weight"
|
| 183 |
+
)
|
| 184 |
+
state_dict[f"transformer_layers.{i}.attn2.to_out.0.weight"] = state_dict.pop(
|
| 185 |
+
f"transformer_layers.{i}.crossattention.out.weight"
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
state_dict[f"transformer_layers.{i}.norm3.norm.weight"] = state_dict.pop(
|
| 189 |
+
f"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight"
|
| 190 |
+
)
|
| 191 |
+
state_dict[f"transformer_layers.{i}.norm3.linear.weight"] = state_dict.pop(
|
| 192 |
+
f"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight"
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
wi_0_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_0.weight")
|
| 196 |
+
wi_1_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_1.weight")
|
| 197 |
+
proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0)
|
| 198 |
+
state_dict[f"transformer_layers.{i}.ff.net.0.proj.weight"] = proj_weight
|
| 199 |
+
|
| 200 |
+
state_dict[f"transformer_layers.{i}.ff.net.2.weight"] = state_dict.pop(f"transformer_layers.{i}.ffn.wo.weight")
|
| 201 |
+
|
| 202 |
+
if force_down_up_sample:
|
| 203 |
+
state_dict["down_block.downsample.norm.weight"] = state_dict.pop("down_blocks.0.downsample.0.norm.weight")
|
| 204 |
+
state_dict["down_block.downsample.conv.weight"] = state_dict.pop("down_blocks.0.downsample.1.weight")
|
| 205 |
+
|
| 206 |
+
state_dict["up_block.upsample.norm.weight"] = state_dict.pop("up_blocks.0.upsample.0.norm.weight")
|
| 207 |
+
state_dict["up_block.upsample.conv.weight"] = state_dict.pop("up_blocks.0.upsample.1.weight")
|
| 208 |
+
|
| 209 |
+
state_dict["mlm_layer.layer_norm.weight"] = state_dict.pop("mlm_layer.layer_norm.norm.weight")
|
| 210 |
+
|
| 211 |
+
for i in range(3):
|
| 212 |
+
state_dict[f"down_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
|
| 213 |
+
f"down_blocks.0.res_blocks.{i}.norm.norm.weight"
|
| 214 |
+
)
|
| 215 |
+
state_dict[f"down_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
|
| 216 |
+
f"down_blocks.0.res_blocks.{i}.channelwise.0.weight"
|
| 217 |
+
)
|
| 218 |
+
state_dict[f"down_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
|
| 219 |
+
f"down_blocks.0.res_blocks.{i}.channelwise.2.gamma"
|
| 220 |
+
)
|
| 221 |
+
state_dict[f"down_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
|
| 222 |
+
f"down_blocks.0.res_blocks.{i}.channelwise.2.beta"
|
| 223 |
+
)
|
| 224 |
+
state_dict[f"down_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
|
| 225 |
+
f"down_blocks.0.res_blocks.{i}.channelwise.4.weight"
|
| 226 |
+
)
|
| 227 |
+
state_dict[f"down_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
|
| 228 |
+
f"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
state_dict[f"down_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
|
| 232 |
+
f"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
|
| 233 |
+
)
|
| 234 |
+
state_dict[f"down_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
|
| 235 |
+
f"down_blocks.0.attention_blocks.{i}.attention.query.weight"
|
| 236 |
+
)
|
| 237 |
+
state_dict[f"down_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
|
| 238 |
+
f"down_blocks.0.attention_blocks.{i}.attention.key.weight"
|
| 239 |
+
)
|
| 240 |
+
state_dict[f"down_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
|
| 241 |
+
f"down_blocks.0.attention_blocks.{i}.attention.value.weight"
|
| 242 |
+
)
|
| 243 |
+
state_dict[f"down_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
|
| 244 |
+
f"down_blocks.0.attention_blocks.{i}.attention.out.weight"
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
state_dict[f"down_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
|
| 248 |
+
f"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
|
| 249 |
+
)
|
| 250 |
+
state_dict[f"down_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
|
| 251 |
+
f"down_blocks.0.attention_blocks.{i}.crossattention.query.weight"
|
| 252 |
+
)
|
| 253 |
+
state_dict[f"down_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
|
| 254 |
+
f"down_blocks.0.attention_blocks.{i}.crossattention.key.weight"
|
| 255 |
+
)
|
| 256 |
+
state_dict[f"down_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
|
| 257 |
+
f"down_blocks.0.attention_blocks.{i}.crossattention.value.weight"
|
| 258 |
+
)
|
| 259 |
+
state_dict[f"down_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
|
| 260 |
+
f"down_blocks.0.attention_blocks.{i}.crossattention.out.weight"
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
state_dict[f"up_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
|
| 264 |
+
f"up_blocks.0.res_blocks.{i}.norm.norm.weight"
|
| 265 |
+
)
|
| 266 |
+
state_dict[f"up_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
|
| 267 |
+
f"up_blocks.0.res_blocks.{i}.channelwise.0.weight"
|
| 268 |
+
)
|
| 269 |
+
state_dict[f"up_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
|
| 270 |
+
f"up_blocks.0.res_blocks.{i}.channelwise.2.gamma"
|
| 271 |
+
)
|
| 272 |
+
state_dict[f"up_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
|
| 273 |
+
f"up_blocks.0.res_blocks.{i}.channelwise.2.beta"
|
| 274 |
+
)
|
| 275 |
+
state_dict[f"up_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
|
| 276 |
+
f"up_blocks.0.res_blocks.{i}.channelwise.4.weight"
|
| 277 |
+
)
|
| 278 |
+
state_dict[f"up_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
|
| 279 |
+
f"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
state_dict[f"up_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
|
| 283 |
+
f"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
|
| 284 |
+
)
|
| 285 |
+
state_dict[f"up_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
|
| 286 |
+
f"up_blocks.0.attention_blocks.{i}.attention.query.weight"
|
| 287 |
+
)
|
| 288 |
+
state_dict[f"up_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
|
| 289 |
+
f"up_blocks.0.attention_blocks.{i}.attention.key.weight"
|
| 290 |
+
)
|
| 291 |
+
state_dict[f"up_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
|
| 292 |
+
f"up_blocks.0.attention_blocks.{i}.attention.value.weight"
|
| 293 |
+
)
|
| 294 |
+
state_dict[f"up_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
|
| 295 |
+
f"up_blocks.0.attention_blocks.{i}.attention.out.weight"
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
state_dict[f"up_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
|
| 299 |
+
f"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
|
| 300 |
+
)
|
| 301 |
+
state_dict[f"up_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
|
| 302 |
+
f"up_blocks.0.attention_blocks.{i}.crossattention.query.weight"
|
| 303 |
+
)
|
| 304 |
+
state_dict[f"up_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
|
| 305 |
+
f"up_blocks.0.attention_blocks.{i}.crossattention.key.weight"
|
| 306 |
+
)
|
| 307 |
+
state_dict[f"up_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
|
| 308 |
+
f"up_blocks.0.attention_blocks.{i}.crossattention.value.weight"
|
| 309 |
+
)
|
| 310 |
+
state_dict[f"up_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
|
| 311 |
+
f"up_blocks.0.attention_blocks.{i}.crossattention.out.weight"
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
for key in list(state_dict.keys()):
|
| 315 |
+
if key.startswith("up_blocks.0"):
|
| 316 |
+
key_ = "up_block." + ".".join(key.split(".")[2:])
|
| 317 |
+
state_dict[key_] = state_dict.pop(key)
|
| 318 |
+
|
| 319 |
+
if key.startswith("down_blocks.0"):
|
| 320 |
+
key_ = "down_block." + ".".join(key.split(".")[2:])
|
| 321 |
+
state_dict[key_] = state_dict.pop(key)
|
| 322 |
+
|
| 323 |
+
new_transformer.load_state_dict(state_dict)
|
| 324 |
+
|
| 325 |
+
input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device)
|
| 326 |
+
encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device)
|
| 327 |
+
cond_embeds = torch.randn((1, 768), device=old_transformer.device)
|
| 328 |
+
micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device)
|
| 329 |
+
|
| 330 |
+
old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds)
|
| 331 |
+
old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2)
|
| 332 |
+
|
| 333 |
+
new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds)
|
| 334 |
+
|
| 335 |
+
# NOTE: these differences are solely due to using the geglu block that has a single linear layer of
|
| 336 |
+
# double output dimension instead of two different linear layers
|
| 337 |
+
max_diff = (old_out - new_out).abs().max()
|
| 338 |
+
total_diff = (old_out - new_out).abs().sum()
|
| 339 |
+
print(f"Transformer max_diff: {max_diff} total_diff: {total_diff}")
|
| 340 |
+
assert max_diff < 0.01
|
| 341 |
+
assert total_diff < 1500
|
| 342 |
+
|
| 343 |
+
return new_transformer
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def make_vqvae(old_vae):
|
| 347 |
+
new_vae = VQModel(
|
| 348 |
+
act_fn="silu",
|
| 349 |
+
block_out_channels=[128, 256, 256, 512, 768],
|
| 350 |
+
down_block_types=[
|
| 351 |
+
"DownEncoderBlock2D",
|
| 352 |
+
"DownEncoderBlock2D",
|
| 353 |
+
"DownEncoderBlock2D",
|
| 354 |
+
"DownEncoderBlock2D",
|
| 355 |
+
"DownEncoderBlock2D",
|
| 356 |
+
],
|
| 357 |
+
in_channels=3,
|
| 358 |
+
latent_channels=64,
|
| 359 |
+
layers_per_block=2,
|
| 360 |
+
norm_num_groups=32,
|
| 361 |
+
num_vq_embeddings=8192,
|
| 362 |
+
out_channels=3,
|
| 363 |
+
sample_size=32,
|
| 364 |
+
up_block_types=[
|
| 365 |
+
"UpDecoderBlock2D",
|
| 366 |
+
"UpDecoderBlock2D",
|
| 367 |
+
"UpDecoderBlock2D",
|
| 368 |
+
"UpDecoderBlock2D",
|
| 369 |
+
"UpDecoderBlock2D",
|
| 370 |
+
],
|
| 371 |
+
mid_block_add_attention=False,
|
| 372 |
+
lookup_from_codebook=True,
|
| 373 |
+
)
|
| 374 |
+
new_vae.to(device)
|
| 375 |
+
|
| 376 |
+
# fmt: off
|
| 377 |
+
|
| 378 |
+
new_state_dict = {}
|
| 379 |
+
|
| 380 |
+
old_state_dict = old_vae.state_dict()
|
| 381 |
+
|
| 382 |
+
new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight")
|
| 383 |
+
new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias")
|
| 384 |
+
|
| 385 |
+
convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0")
|
| 386 |
+
convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1")
|
| 387 |
+
convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2")
|
| 388 |
+
convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3")
|
| 389 |
+
convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4")
|
| 390 |
+
|
| 391 |
+
new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight")
|
| 392 |
+
new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias")
|
| 393 |
+
new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight")
|
| 394 |
+
new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias")
|
| 395 |
+
new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight")
|
| 396 |
+
new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias")
|
| 397 |
+
new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight")
|
| 398 |
+
new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias")
|
| 399 |
+
new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight")
|
| 400 |
+
new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias")
|
| 401 |
+
new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight")
|
| 402 |
+
new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias")
|
| 403 |
+
new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight")
|
| 404 |
+
new_state_dict["encoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("encoder.mid.block_2.norm2.bias")
|
| 405 |
+
new_state_dict["encoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("encoder.mid.block_2.conv2.weight")
|
| 406 |
+
new_state_dict["encoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("encoder.mid.block_2.conv2.bias")
|
| 407 |
+
new_state_dict["encoder.conv_norm_out.weight"] = old_state_dict.pop("encoder.norm_out.weight")
|
| 408 |
+
new_state_dict["encoder.conv_norm_out.bias"] = old_state_dict.pop("encoder.norm_out.bias")
|
| 409 |
+
new_state_dict["encoder.conv_out.weight"] = old_state_dict.pop("encoder.conv_out.weight")
|
| 410 |
+
new_state_dict["encoder.conv_out.bias"] = old_state_dict.pop("encoder.conv_out.bias")
|
| 411 |
+
new_state_dict["quant_conv.weight"] = old_state_dict.pop("quant_conv.weight")
|
| 412 |
+
new_state_dict["quant_conv.bias"] = old_state_dict.pop("quant_conv.bias")
|
| 413 |
+
new_state_dict["quantize.embedding.weight"] = old_state_dict.pop("quantize.embedding.weight")
|
| 414 |
+
new_state_dict["post_quant_conv.weight"] = old_state_dict.pop("post_quant_conv.weight")
|
| 415 |
+
new_state_dict["post_quant_conv.bias"] = old_state_dict.pop("post_quant_conv.bias")
|
| 416 |
+
new_state_dict["decoder.conv_in.weight"] = old_state_dict.pop("decoder.conv_in.weight")
|
| 417 |
+
new_state_dict["decoder.conv_in.bias"] = old_state_dict.pop("decoder.conv_in.bias")
|
| 418 |
+
new_state_dict["decoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("decoder.mid.block_1.norm1.weight")
|
| 419 |
+
new_state_dict["decoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("decoder.mid.block_1.norm1.bias")
|
| 420 |
+
new_state_dict["decoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("decoder.mid.block_1.conv1.weight")
|
| 421 |
+
new_state_dict["decoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("decoder.mid.block_1.conv1.bias")
|
| 422 |
+
new_state_dict["decoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("decoder.mid.block_1.norm2.weight")
|
| 423 |
+
new_state_dict["decoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("decoder.mid.block_1.norm2.bias")
|
| 424 |
+
new_state_dict["decoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("decoder.mid.block_1.conv2.weight")
|
| 425 |
+
new_state_dict["decoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("decoder.mid.block_1.conv2.bias")
|
| 426 |
+
new_state_dict["decoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("decoder.mid.block_2.norm1.weight")
|
| 427 |
+
new_state_dict["decoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("decoder.mid.block_2.norm1.bias")
|
| 428 |
+
new_state_dict["decoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("decoder.mid.block_2.conv1.weight")
|
| 429 |
+
new_state_dict["decoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("decoder.mid.block_2.conv1.bias")
|
| 430 |
+
new_state_dict["decoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("decoder.mid.block_2.norm2.weight")
|
| 431 |
+
new_state_dict["decoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("decoder.mid.block_2.norm2.bias")
|
| 432 |
+
new_state_dict["decoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("decoder.mid.block_2.conv2.weight")
|
| 433 |
+
new_state_dict["decoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("decoder.mid.block_2.conv2.bias")
|
| 434 |
+
|
| 435 |
+
convert_vae_block_state_dict(old_state_dict, "decoder.up.0", new_state_dict, "decoder.up_blocks.4")
|
| 436 |
+
convert_vae_block_state_dict(old_state_dict, "decoder.up.1", new_state_dict, "decoder.up_blocks.3")
|
| 437 |
+
convert_vae_block_state_dict(old_state_dict, "decoder.up.2", new_state_dict, "decoder.up_blocks.2")
|
| 438 |
+
convert_vae_block_state_dict(old_state_dict, "decoder.up.3", new_state_dict, "decoder.up_blocks.1")
|
| 439 |
+
convert_vae_block_state_dict(old_state_dict, "decoder.up.4", new_state_dict, "decoder.up_blocks.0")
|
| 440 |
+
|
| 441 |
+
new_state_dict["decoder.conv_norm_out.weight"] = old_state_dict.pop("decoder.norm_out.weight")
|
| 442 |
+
new_state_dict["decoder.conv_norm_out.bias"] = old_state_dict.pop("decoder.norm_out.bias")
|
| 443 |
+
new_state_dict["decoder.conv_out.weight"] = old_state_dict.pop("decoder.conv_out.weight")
|
| 444 |
+
new_state_dict["decoder.conv_out.bias"] = old_state_dict.pop("decoder.conv_out.bias")
|
| 445 |
+
|
| 446 |
+
# fmt: on
|
| 447 |
+
|
| 448 |
+
assert len(old_state_dict.keys()) == 0
|
| 449 |
+
|
| 450 |
+
new_vae.load_state_dict(new_state_dict)
|
| 451 |
+
|
| 452 |
+
input = torch.randn((1, 3, 512, 512), device=device)
|
| 453 |
+
input = input.clamp(-1, 1)
|
| 454 |
+
|
| 455 |
+
old_encoder_output = old_vae.quant_conv(old_vae.encoder(input))
|
| 456 |
+
new_encoder_output = new_vae.quant_conv(new_vae.encoder(input))
|
| 457 |
+
assert (old_encoder_output == new_encoder_output).all()
|
| 458 |
+
|
| 459 |
+
old_decoder_output = old_vae.decoder(old_vae.post_quant_conv(old_encoder_output))
|
| 460 |
+
new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output))
|
| 461 |
+
|
| 462 |
+
# assert (old_decoder_output == new_decoder_output).all()
|
| 463 |
+
print("kipping vae decoder equivalence check")
|
| 464 |
+
print(f"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}")
|
| 465 |
+
|
| 466 |
+
old_output = old_vae(input)[0]
|
| 467 |
+
new_output = new_vae(input)[0]
|
| 468 |
+
|
| 469 |
+
# assert (old_output == new_output).all()
|
| 470 |
+
print("skipping full vae equivalence check")
|
| 471 |
+
print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
|
| 472 |
+
|
| 473 |
+
return new_vae
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to):
|
| 477 |
+
# fmt: off
|
| 478 |
+
|
| 479 |
+
new_state_dict[f"{prefix_to}.resnets.0.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.weight")
|
| 480 |
+
new_state_dict[f"{prefix_to}.resnets.0.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.bias")
|
| 481 |
+
new_state_dict[f"{prefix_to}.resnets.0.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.weight")
|
| 482 |
+
new_state_dict[f"{prefix_to}.resnets.0.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.bias")
|
| 483 |
+
new_state_dict[f"{prefix_to}.resnets.0.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.weight")
|
| 484 |
+
new_state_dict[f"{prefix_to}.resnets.0.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.bias")
|
| 485 |
+
new_state_dict[f"{prefix_to}.resnets.0.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.weight")
|
| 486 |
+
new_state_dict[f"{prefix_to}.resnets.0.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.bias")
|
| 487 |
+
|
| 488 |
+
if f"{prefix_from}.block.0.nin_shortcut.weight" in old_state_dict:
|
| 489 |
+
new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.weight")
|
| 490 |
+
new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.bias")
|
| 491 |
+
|
| 492 |
+
new_state_dict[f"{prefix_to}.resnets.1.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.weight")
|
| 493 |
+
new_state_dict[f"{prefix_to}.resnets.1.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.bias")
|
| 494 |
+
new_state_dict[f"{prefix_to}.resnets.1.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.weight")
|
| 495 |
+
new_state_dict[f"{prefix_to}.resnets.1.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.bias")
|
| 496 |
+
new_state_dict[f"{prefix_to}.resnets.1.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.weight")
|
| 497 |
+
new_state_dict[f"{prefix_to}.resnets.1.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.bias")
|
| 498 |
+
new_state_dict[f"{prefix_to}.resnets.1.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.weight")
|
| 499 |
+
new_state_dict[f"{prefix_to}.resnets.1.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.bias")
|
| 500 |
+
|
| 501 |
+
if f"{prefix_from}.downsample.conv.weight" in old_state_dict:
|
| 502 |
+
new_state_dict[f"{prefix_to}.downsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.weight")
|
| 503 |
+
new_state_dict[f"{prefix_to}.downsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.bias")
|
| 504 |
+
|
| 505 |
+
if f"{prefix_from}.upsample.conv.weight" in old_state_dict:
|
| 506 |
+
new_state_dict[f"{prefix_to}.upsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.weight")
|
| 507 |
+
new_state_dict[f"{prefix_to}.upsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.bias")
|
| 508 |
+
|
| 509 |
+
if f"{prefix_from}.block.2.norm1.weight" in old_state_dict:
|
| 510 |
+
new_state_dict[f"{prefix_to}.resnets.2.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.weight")
|
| 511 |
+
new_state_dict[f"{prefix_to}.resnets.2.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.bias")
|
| 512 |
+
new_state_dict[f"{prefix_to}.resnets.2.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.weight")
|
| 513 |
+
new_state_dict[f"{prefix_to}.resnets.2.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.bias")
|
| 514 |
+
new_state_dict[f"{prefix_to}.resnets.2.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.weight")
|
| 515 |
+
new_state_dict[f"{prefix_to}.resnets.2.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.bias")
|
| 516 |
+
new_state_dict[f"{prefix_to}.resnets.2.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.weight")
|
| 517 |
+
new_state_dict[f"{prefix_to}.resnets.2.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.bias")
|
| 518 |
+
|
| 519 |
+
# fmt: on
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
if __name__ == "__main__":
|
| 523 |
+
main()
|
diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
+
|
| 6 |
+
from diffusers import MotionAdapter
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def convert_motion_module(original_state_dict):
|
| 10 |
+
converted_state_dict = {}
|
| 11 |
+
for k, v in original_state_dict.items():
|
| 12 |
+
if "pos_encoder" in k:
|
| 13 |
+
continue
|
| 14 |
+
|
| 15 |
+
else:
|
| 16 |
+
converted_state_dict[
|
| 17 |
+
k.replace(".norms.0", ".norm1")
|
| 18 |
+
.replace(".norms.1", ".norm2")
|
| 19 |
+
.replace(".ff_norm", ".norm3")
|
| 20 |
+
.replace(".attention_blocks.0", ".attn1")
|
| 21 |
+
.replace(".attention_blocks.1", ".attn2")
|
| 22 |
+
.replace(".temporal_transformer", "")
|
| 23 |
+
] = v
|
| 24 |
+
|
| 25 |
+
return converted_state_dict
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_args():
|
| 29 |
+
parser = argparse.ArgumentParser()
|
| 30 |
+
parser.add_argument("--ckpt_path", type=str, required=True)
|
| 31 |
+
parser.add_argument("--output_path", type=str, required=True)
|
| 32 |
+
parser.add_argument("--use_motion_mid_block", action="store_true")
|
| 33 |
+
parser.add_argument("--motion_max_seq_length", type=int, default=32)
|
| 34 |
+
parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int)
|
| 35 |
+
parser.add_argument("--save_fp16", action="store_true")
|
| 36 |
+
|
| 37 |
+
return parser.parse_args()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
args = get_args()
|
| 42 |
+
|
| 43 |
+
if args.ckpt_path.endswith(".safetensors"):
|
| 44 |
+
state_dict = load_file(args.ckpt_path)
|
| 45 |
+
else:
|
| 46 |
+
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
| 47 |
+
|
| 48 |
+
if "state_dict" in state_dict.keys():
|
| 49 |
+
state_dict = state_dict["state_dict"]
|
| 50 |
+
|
| 51 |
+
conv_state_dict = convert_motion_module(state_dict)
|
| 52 |
+
adapter = MotionAdapter(
|
| 53 |
+
block_out_channels=args.block_out_channels,
|
| 54 |
+
use_motion_mid_block=args.use_motion_mid_block,
|
| 55 |
+
motion_max_seq_length=args.motion_max_seq_length,
|
| 56 |
+
)
|
| 57 |
+
# skip loading position embeddings
|
| 58 |
+
adapter.load_state_dict(conv_state_dict, strict=False)
|
| 59 |
+
adapter.save_pretrained(args.output_path)
|
| 60 |
+
|
| 61 |
+
if args.save_fp16:
|
| 62 |
+
adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant="fp16")
|
diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from typing import Dict
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from diffusers import SparseControlNetModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
KEYS_RENAME_MAPPING = {
|
| 11 |
+
".attention_blocks.0": ".attn1",
|
| 12 |
+
".attention_blocks.1": ".attn2",
|
| 13 |
+
".attn1.pos_encoder": ".pos_embed",
|
| 14 |
+
".ff_norm": ".norm3",
|
| 15 |
+
".norms.0": ".norm1",
|
| 16 |
+
".norms.1": ".norm2",
|
| 17 |
+
".temporal_transformer": "",
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def convert(original_state_dict: Dict[str, nn.Module]) -> Dict[str, nn.Module]:
|
| 22 |
+
converted_state_dict = {}
|
| 23 |
+
|
| 24 |
+
for key in list(original_state_dict.keys()):
|
| 25 |
+
renamed_key = key
|
| 26 |
+
for new_name, old_name in KEYS_RENAME_MAPPING.items():
|
| 27 |
+
renamed_key = renamed_key.replace(new_name, old_name)
|
| 28 |
+
converted_state_dict[renamed_key] = original_state_dict.pop(key)
|
| 29 |
+
|
| 30 |
+
return converted_state_dict
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_args():
|
| 34 |
+
parser = argparse.ArgumentParser()
|
| 35 |
+
parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
|
| 36 |
+
parser.add_argument("--output_path", type=str, required=True, help="Path to output directory")
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--max_motion_seq_length",
|
| 39 |
+
type=int,
|
| 40 |
+
default=32,
|
| 41 |
+
help="Max motion sequence length supported by the motion adapter",
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--conditioning_channels", type=int, default=4, help="Number of channels in conditioning input to controlnet"
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--use_simplified_condition_embedding",
|
| 48 |
+
action="store_true",
|
| 49 |
+
default=False,
|
| 50 |
+
help="Whether or not to use simplified condition embedding. When `conditioning_channels==4` i.e. latent inputs, set this to `True`. When `conditioning_channels==3` i.e. image inputs, set this to `False`",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--save_fp16",
|
| 54 |
+
action="store_true",
|
| 55 |
+
default=False,
|
| 56 |
+
help="Whether or not to save model in fp16 precision along with fp32",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--push_to_hub", action="store_true", default=False, help="Whether or not to push saved model to the HF hub"
|
| 60 |
+
)
|
| 61 |
+
return parser.parse_args()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
|
| 65 |
+
args = get_args()
|
| 66 |
+
|
| 67 |
+
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
| 68 |
+
if "state_dict" in state_dict.keys():
|
| 69 |
+
state_dict: dict = state_dict["state_dict"]
|
| 70 |
+
|
| 71 |
+
controlnet = SparseControlNetModel(
|
| 72 |
+
conditioning_channels=args.conditioning_channels,
|
| 73 |
+
motion_max_seq_length=args.max_motion_seq_length,
|
| 74 |
+
use_simplified_condition_embedding=args.use_simplified_condition_embedding,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
state_dict = convert(state_dict)
|
| 78 |
+
controlnet.load_state_dict(state_dict, strict=True)
|
| 79 |
+
|
| 80 |
+
controlnet.save_pretrained(args.output_path, push_to_hub=args.push_to_hub)
|
| 81 |
+
if args.save_fp16:
|
| 82 |
+
controlnet = controlnet.to(dtype=torch.float16)
|
| 83 |
+
controlnet.save_pretrained(args.output_path, variant="fp16", push_to_hub=args.push_to_hub)
|
diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any, Dict, Literal
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from diffusers import AsymmetricAutoencoderKL
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = {
|
| 12 |
+
"in_channels": 3,
|
| 13 |
+
"out_channels": 3,
|
| 14 |
+
"down_block_types": [
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D",
|
| 17 |
+
"DownEncoderBlock2D",
|
| 18 |
+
"DownEncoderBlock2D",
|
| 19 |
+
],
|
| 20 |
+
"down_block_out_channels": [128, 256, 512, 512],
|
| 21 |
+
"layers_per_down_block": 2,
|
| 22 |
+
"up_block_types": [
|
| 23 |
+
"UpDecoderBlock2D",
|
| 24 |
+
"UpDecoderBlock2D",
|
| 25 |
+
"UpDecoderBlock2D",
|
| 26 |
+
"UpDecoderBlock2D",
|
| 27 |
+
],
|
| 28 |
+
"up_block_out_channels": [192, 384, 768, 768],
|
| 29 |
+
"layers_per_up_block": 3,
|
| 30 |
+
"act_fn": "silu",
|
| 31 |
+
"latent_channels": 4,
|
| 32 |
+
"norm_num_groups": 32,
|
| 33 |
+
"sample_size": 256,
|
| 34 |
+
"scaling_factor": 0.18215,
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = {
|
| 38 |
+
"in_channels": 3,
|
| 39 |
+
"out_channels": 3,
|
| 40 |
+
"down_block_types": [
|
| 41 |
+
"DownEncoderBlock2D",
|
| 42 |
+
"DownEncoderBlock2D",
|
| 43 |
+
"DownEncoderBlock2D",
|
| 44 |
+
"DownEncoderBlock2D",
|
| 45 |
+
],
|
| 46 |
+
"down_block_out_channels": [128, 256, 512, 512],
|
| 47 |
+
"layers_per_down_block": 2,
|
| 48 |
+
"up_block_types": [
|
| 49 |
+
"UpDecoderBlock2D",
|
| 50 |
+
"UpDecoderBlock2D",
|
| 51 |
+
"UpDecoderBlock2D",
|
| 52 |
+
"UpDecoderBlock2D",
|
| 53 |
+
],
|
| 54 |
+
"up_block_out_channels": [256, 512, 1024, 1024],
|
| 55 |
+
"layers_per_up_block": 5,
|
| 56 |
+
"act_fn": "silu",
|
| 57 |
+
"latent_channels": 4,
|
| 58 |
+
"norm_num_groups": 32,
|
| 59 |
+
"sample_size": 256,
|
| 60 |
+
"scaling_factor": 0.18215,
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
| 65 |
+
converted_state_dict = {}
|
| 66 |
+
for k, v in original_state_dict.items():
|
| 67 |
+
if k.startswith("encoder."):
|
| 68 |
+
converted_state_dict[
|
| 69 |
+
k.replace("encoder.down.", "encoder.down_blocks.")
|
| 70 |
+
.replace("encoder.mid.", "encoder.mid_block.")
|
| 71 |
+
.replace("encoder.norm_out.", "encoder.conv_norm_out.")
|
| 72 |
+
.replace(".downsample.", ".downsamplers.0.")
|
| 73 |
+
.replace(".nin_shortcut.", ".conv_shortcut.")
|
| 74 |
+
.replace(".block.", ".resnets.")
|
| 75 |
+
.replace(".block_1.", ".resnets.0.")
|
| 76 |
+
.replace(".block_2.", ".resnets.1.")
|
| 77 |
+
.replace(".attn_1.k.", ".attentions.0.to_k.")
|
| 78 |
+
.replace(".attn_1.q.", ".attentions.0.to_q.")
|
| 79 |
+
.replace(".attn_1.v.", ".attentions.0.to_v.")
|
| 80 |
+
.replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
|
| 81 |
+
.replace(".attn_1.norm.", ".attentions.0.group_norm.")
|
| 82 |
+
] = v
|
| 83 |
+
elif k.startswith("decoder.") and "up_layers" not in k:
|
| 84 |
+
converted_state_dict[
|
| 85 |
+
k.replace("decoder.encoder.", "decoder.condition_encoder.")
|
| 86 |
+
.replace(".norm_out.", ".conv_norm_out.")
|
| 87 |
+
.replace(".up.0.", ".up_blocks.3.")
|
| 88 |
+
.replace(".up.1.", ".up_blocks.2.")
|
| 89 |
+
.replace(".up.2.", ".up_blocks.1.")
|
| 90 |
+
.replace(".up.3.", ".up_blocks.0.")
|
| 91 |
+
.replace(".block.", ".resnets.")
|
| 92 |
+
.replace("mid", "mid_block")
|
| 93 |
+
.replace(".0.upsample.", ".0.upsamplers.0.")
|
| 94 |
+
.replace(".1.upsample.", ".1.upsamplers.0.")
|
| 95 |
+
.replace(".2.upsample.", ".2.upsamplers.0.")
|
| 96 |
+
.replace(".nin_shortcut.", ".conv_shortcut.")
|
| 97 |
+
.replace(".block_1.", ".resnets.0.")
|
| 98 |
+
.replace(".block_2.", ".resnets.1.")
|
| 99 |
+
.replace(".attn_1.k.", ".attentions.0.to_k.")
|
| 100 |
+
.replace(".attn_1.q.", ".attentions.0.to_q.")
|
| 101 |
+
.replace(".attn_1.v.", ".attentions.0.to_v.")
|
| 102 |
+
.replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
|
| 103 |
+
.replace(".attn_1.norm.", ".attentions.0.group_norm.")
|
| 104 |
+
] = v
|
| 105 |
+
elif k.startswith("quant_conv."):
|
| 106 |
+
converted_state_dict[k] = v
|
| 107 |
+
elif k.startswith("post_quant_conv."):
|
| 108 |
+
converted_state_dict[k] = v
|
| 109 |
+
else:
|
| 110 |
+
print(f" skipping key `{k}`")
|
| 111 |
+
# fix weights shape
|
| 112 |
+
for k, v in converted_state_dict.items():
|
| 113 |
+
if (
|
| 114 |
+
(k.startswith("encoder.mid_block.attentions.0") or k.startswith("decoder.mid_block.attentions.0"))
|
| 115 |
+
and k.endswith("weight")
|
| 116 |
+
and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
|
| 117 |
+
):
|
| 118 |
+
converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0]
|
| 119 |
+
|
| 120 |
+
return converted_state_dict
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def get_asymmetric_autoencoder_kl_from_original_checkpoint(
|
| 124 |
+
scale: Literal["1.5", "2"], original_checkpoint_path: str, map_location: torch.device
|
| 125 |
+
) -> AsymmetricAutoencoderKL:
|
| 126 |
+
print("Loading original state_dict")
|
| 127 |
+
original_state_dict = torch.load(original_checkpoint_path, map_location=map_location)
|
| 128 |
+
original_state_dict = original_state_dict["state_dict"]
|
| 129 |
+
print("Converting state_dict")
|
| 130 |
+
converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict)
|
| 131 |
+
kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == "1.5" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG
|
| 132 |
+
print("Initializing AsymmetricAutoencoderKL model")
|
| 133 |
+
asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs)
|
| 134 |
+
print("Loading weight from converted state_dict")
|
| 135 |
+
asymmetric_autoencoder_kl.load_state_dict(converted_state_dict)
|
| 136 |
+
asymmetric_autoencoder_kl.eval()
|
| 137 |
+
print("AsymmetricAutoencoderKL successfully initialized")
|
| 138 |
+
return asymmetric_autoencoder_kl
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if __name__ == "__main__":
|
| 142 |
+
start = time.time()
|
| 143 |
+
parser = argparse.ArgumentParser()
|
| 144 |
+
parser.add_argument(
|
| 145 |
+
"--scale",
|
| 146 |
+
default=None,
|
| 147 |
+
type=str,
|
| 148 |
+
required=True,
|
| 149 |
+
help="Asymmetric VQGAN scale: `1.5` or `2`",
|
| 150 |
+
)
|
| 151 |
+
parser.add_argument(
|
| 152 |
+
"--original_checkpoint_path",
|
| 153 |
+
default=None,
|
| 154 |
+
type=str,
|
| 155 |
+
required=True,
|
| 156 |
+
help="Path to the original Asymmetric VQGAN checkpoint",
|
| 157 |
+
)
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--output_path",
|
| 160 |
+
default=None,
|
| 161 |
+
type=str,
|
| 162 |
+
required=True,
|
| 163 |
+
help="Path to save pretrained AsymmetricAutoencoderKL model",
|
| 164 |
+
)
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--map_location",
|
| 167 |
+
default="cpu",
|
| 168 |
+
type=str,
|
| 169 |
+
required=False,
|
| 170 |
+
help="The device passed to `map_location` when loading the checkpoint",
|
| 171 |
+
)
|
| 172 |
+
args = parser.parse_args()
|
| 173 |
+
|
| 174 |
+
assert args.scale in ["1.5", "2"], f"{args.scale} should be `1.5` of `2`"
|
| 175 |
+
assert Path(args.original_checkpoint_path).is_file()
|
| 176 |
+
|
| 177 |
+
asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint(
|
| 178 |
+
scale=args.scale,
|
| 179 |
+
original_checkpoint_path=args.original_checkpoint_path,
|
| 180 |
+
map_location=torch.device(args.map_location),
|
| 181 |
+
)
|
| 182 |
+
print("Saving pretrained AsymmetricAutoencoderKL")
|
| 183 |
+
asymmetric_autoencoder_kl.save_pretrained(args.output_path)
|
| 184 |
+
print(f"Done in {time.time() - start:.2f} seconds")
|
diffusers/scripts/convert_aura_flow_to_diffusers.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
+
|
| 6 |
+
from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_original_state_dict(args):
|
| 10 |
+
model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin")
|
| 11 |
+
state_dict = torch.load(model_pt, map_location="cpu")
|
| 12 |
+
return state_dict
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def calculate_layers(state_dict_keys, key_prefix):
|
| 16 |
+
dit_layers = set()
|
| 17 |
+
for k in state_dict_keys:
|
| 18 |
+
if key_prefix in k:
|
| 19 |
+
dit_layers.add(int(k.split(".")[2]))
|
| 20 |
+
print(f"{key_prefix}: {len(dit_layers)}")
|
| 21 |
+
return len(dit_layers)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# similar to SD3 but only for the last norm layer
|
| 25 |
+
def swap_scale_shift(weight, dim):
|
| 26 |
+
shift, scale = weight.chunk(2, dim=0)
|
| 27 |
+
new_weight = torch.cat([scale, shift], dim=0)
|
| 28 |
+
return new_weight
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def convert_transformer(state_dict):
|
| 32 |
+
converted_state_dict = {}
|
| 33 |
+
state_dict_keys = list(state_dict.keys())
|
| 34 |
+
|
| 35 |
+
converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens")
|
| 36 |
+
converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding")
|
| 37 |
+
converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight")
|
| 38 |
+
converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias")
|
| 39 |
+
|
| 40 |
+
converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight")
|
| 41 |
+
converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias")
|
| 42 |
+
converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight")
|
| 43 |
+
converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias")
|
| 44 |
+
|
| 45 |
+
converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight")
|
| 46 |
+
|
| 47 |
+
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
|
| 48 |
+
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
|
| 49 |
+
|
| 50 |
+
# MMDiT blocks 🎸.
|
| 51 |
+
for i in range(mmdit_layers):
|
| 52 |
+
# feed-forward
|
| 53 |
+
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
|
| 54 |
+
weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
|
| 55 |
+
for orig_k, diffuser_k in path_mapping.items():
|
| 56 |
+
for k, v in weight_mapping.items():
|
| 57 |
+
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop(
|
| 58 |
+
f"model.double_layers.{i}.{orig_k}.{k}.weight"
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
# norms
|
| 62 |
+
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
|
| 63 |
+
for orig_k, diffuser_k in path_mapping.items():
|
| 64 |
+
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop(
|
| 65 |
+
f"model.double_layers.{i}.{orig_k}.1.weight"
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# attns
|
| 69 |
+
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
|
| 70 |
+
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
|
| 71 |
+
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
|
| 72 |
+
for k, v in attn_mapping.items():
|
| 73 |
+
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
|
| 74 |
+
f"model.double_layers.{i}.attn.{k}.weight"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Single-DiT blocks.
|
| 78 |
+
for i in range(single_dit_layers):
|
| 79 |
+
# feed-forward
|
| 80 |
+
mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
|
| 81 |
+
for k, v in mapping.items():
|
| 82 |
+
converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop(
|
| 83 |
+
f"model.single_layers.{i}.mlp.{k}.weight"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# norms
|
| 87 |
+
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop(
|
| 88 |
+
f"model.single_layers.{i}.modCX.1.weight"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# attns
|
| 92 |
+
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
|
| 93 |
+
for k, v in x_attn_mapping.items():
|
| 94 |
+
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
|
| 95 |
+
f"model.single_layers.{i}.attn.{k}.weight"
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# Final blocks.
|
| 99 |
+
converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight")
|
| 100 |
+
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None)
|
| 101 |
+
|
| 102 |
+
return converted_state_dict
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@torch.no_grad()
|
| 106 |
+
def populate_state_dict(args):
|
| 107 |
+
original_state_dict = load_original_state_dict(args)
|
| 108 |
+
state_dict_keys = list(original_state_dict.keys())
|
| 109 |
+
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
|
| 110 |
+
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
|
| 111 |
+
|
| 112 |
+
converted_state_dict = convert_transformer(original_state_dict)
|
| 113 |
+
model_diffusers = AuraFlowTransformer2DModel(
|
| 114 |
+
num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers
|
| 115 |
+
)
|
| 116 |
+
model_diffusers.load_state_dict(converted_state_dict, strict=True)
|
| 117 |
+
|
| 118 |
+
return model_diffusers
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
|
| 122 |
+
parser = argparse.ArgumentParser()
|
| 123 |
+
parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str)
|
| 124 |
+
parser.add_argument("--dump_path", default="aura-flow", type=str)
|
| 125 |
+
parser.add_argument("--hub_id", default=None, type=str)
|
| 126 |
+
args = parser.parse_args()
|
| 127 |
+
|
| 128 |
+
model_diffusers = populate_state_dict(args)
|
| 129 |
+
model_diffusers.save_pretrained(args.dump_path)
|
| 130 |
+
if args.hub_id is not None:
|
| 131 |
+
model_diffusers.push_to_hub(args.hub_id)
|
diffusers/scripts/convert_blipdiffusion_to_diffusers.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from lavis.models import load_model_and_preprocess
|
| 11 |
+
from transformers import CLIPTokenizer
|
| 12 |
+
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
|
| 13 |
+
|
| 14 |
+
from diffusers import (
|
| 15 |
+
AutoencoderKL,
|
| 16 |
+
PNDMScheduler,
|
| 17 |
+
UNet2DConditionModel,
|
| 18 |
+
)
|
| 19 |
+
from diffusers.pipelines import BlipDiffusionPipeline
|
| 20 |
+
from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
|
| 21 |
+
from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
|
| 22 |
+
from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
BLIP2_CONFIG = {
|
| 26 |
+
"vision_config": {
|
| 27 |
+
"hidden_size": 1024,
|
| 28 |
+
"num_hidden_layers": 23,
|
| 29 |
+
"num_attention_heads": 16,
|
| 30 |
+
"image_size": 224,
|
| 31 |
+
"patch_size": 14,
|
| 32 |
+
"intermediate_size": 4096,
|
| 33 |
+
"hidden_act": "quick_gelu",
|
| 34 |
+
},
|
| 35 |
+
"qformer_config": {
|
| 36 |
+
"cross_attention_frequency": 1,
|
| 37 |
+
"encoder_hidden_size": 1024,
|
| 38 |
+
"vocab_size": 30523,
|
| 39 |
+
},
|
| 40 |
+
"num_query_tokens": 16,
|
| 41 |
+
}
|
| 42 |
+
blip2config = Blip2Config(**BLIP2_CONFIG)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def qformer_model_from_original_config():
|
| 46 |
+
qformer = Blip2QFormerModel(blip2config)
|
| 47 |
+
return qformer
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
|
| 51 |
+
embeddings = {}
|
| 52 |
+
embeddings.update(
|
| 53 |
+
{
|
| 54 |
+
f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
|
| 55 |
+
f"{original_embeddings_prefix}.word_embeddings.weight"
|
| 56 |
+
]
|
| 57 |
+
}
|
| 58 |
+
)
|
| 59 |
+
embeddings.update(
|
| 60 |
+
{
|
| 61 |
+
f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
|
| 62 |
+
f"{original_embeddings_prefix}.position_embeddings.weight"
|
| 63 |
+
]
|
| 64 |
+
}
|
| 65 |
+
)
|
| 66 |
+
embeddings.update(
|
| 67 |
+
{f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
|
| 68 |
+
)
|
| 69 |
+
embeddings.update(
|
| 70 |
+
{f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
|
| 71 |
+
)
|
| 72 |
+
return embeddings
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
|
| 76 |
+
proj_layer = {}
|
| 77 |
+
proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
|
| 78 |
+
proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
|
| 79 |
+
proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
|
| 80 |
+
proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
|
| 81 |
+
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
|
| 82 |
+
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
|
| 83 |
+
return proj_layer
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
|
| 87 |
+
attention = {}
|
| 88 |
+
attention.update(
|
| 89 |
+
{
|
| 90 |
+
f"{diffuser_attention_prefix}.attention.query.weight": model[
|
| 91 |
+
f"{original_attention_prefix}.self.query.weight"
|
| 92 |
+
]
|
| 93 |
+
}
|
| 94 |
+
)
|
| 95 |
+
attention.update(
|
| 96 |
+
{f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
|
| 97 |
+
)
|
| 98 |
+
attention.update(
|
| 99 |
+
{f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
|
| 100 |
+
)
|
| 101 |
+
attention.update(
|
| 102 |
+
{f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
|
| 103 |
+
)
|
| 104 |
+
attention.update(
|
| 105 |
+
{
|
| 106 |
+
f"{diffuser_attention_prefix}.attention.value.weight": model[
|
| 107 |
+
f"{original_attention_prefix}.self.value.weight"
|
| 108 |
+
]
|
| 109 |
+
}
|
| 110 |
+
)
|
| 111 |
+
attention.update(
|
| 112 |
+
{f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
|
| 113 |
+
)
|
| 114 |
+
attention.update(
|
| 115 |
+
{f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
|
| 116 |
+
)
|
| 117 |
+
attention.update(
|
| 118 |
+
{f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
|
| 119 |
+
)
|
| 120 |
+
attention.update(
|
| 121 |
+
{
|
| 122 |
+
f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
|
| 123 |
+
f"{original_attention_prefix}.output.LayerNorm.weight"
|
| 124 |
+
]
|
| 125 |
+
}
|
| 126 |
+
)
|
| 127 |
+
attention.update(
|
| 128 |
+
{
|
| 129 |
+
f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
|
| 130 |
+
f"{original_attention_prefix}.output.LayerNorm.bias"
|
| 131 |
+
]
|
| 132 |
+
}
|
| 133 |
+
)
|
| 134 |
+
return attention
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
|
| 138 |
+
output_layers = {}
|
| 139 |
+
output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
|
| 140 |
+
output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
|
| 141 |
+
output_layers.update(
|
| 142 |
+
{f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
|
| 143 |
+
)
|
| 144 |
+
output_layers.update(
|
| 145 |
+
{f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
|
| 146 |
+
)
|
| 147 |
+
return output_layers
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
|
| 151 |
+
encoder = {}
|
| 152 |
+
for i in range(blip2config.qformer_config.num_hidden_layers):
|
| 153 |
+
encoder.update(
|
| 154 |
+
attention_from_original_checkpoint(
|
| 155 |
+
model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
|
| 156 |
+
)
|
| 157 |
+
)
|
| 158 |
+
encoder.update(
|
| 159 |
+
attention_from_original_checkpoint(
|
| 160 |
+
model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
|
| 161 |
+
)
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
encoder.update(
|
| 165 |
+
{
|
| 166 |
+
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
|
| 167 |
+
f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
|
| 168 |
+
]
|
| 169 |
+
}
|
| 170 |
+
)
|
| 171 |
+
encoder.update(
|
| 172 |
+
{
|
| 173 |
+
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
|
| 174 |
+
f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
|
| 175 |
+
]
|
| 176 |
+
}
|
| 177 |
+
)
|
| 178 |
+
encoder.update(
|
| 179 |
+
{
|
| 180 |
+
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
|
| 181 |
+
f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
|
| 182 |
+
]
|
| 183 |
+
}
|
| 184 |
+
)
|
| 185 |
+
encoder.update(
|
| 186 |
+
{
|
| 187 |
+
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
|
| 188 |
+
f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
|
| 189 |
+
]
|
| 190 |
+
}
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
encoder.update(
|
| 194 |
+
output_layers_from_original_checkpoint(
|
| 195 |
+
model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
|
| 196 |
+
)
|
| 197 |
+
)
|
| 198 |
+
encoder.update(
|
| 199 |
+
output_layers_from_original_checkpoint(
|
| 200 |
+
model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
|
| 201 |
+
)
|
| 202 |
+
)
|
| 203 |
+
return encoder
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
|
| 207 |
+
visual_encoder_layer = {}
|
| 208 |
+
|
| 209 |
+
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
|
| 210 |
+
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
|
| 211 |
+
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
|
| 212 |
+
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
|
| 213 |
+
visual_encoder_layer.update(
|
| 214 |
+
{f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
|
| 215 |
+
)
|
| 216 |
+
visual_encoder_layer.update(
|
| 217 |
+
{f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
|
| 218 |
+
)
|
| 219 |
+
visual_encoder_layer.update(
|
| 220 |
+
{f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
|
| 221 |
+
)
|
| 222 |
+
visual_encoder_layer.update(
|
| 223 |
+
{f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
|
| 224 |
+
)
|
| 225 |
+
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
|
| 226 |
+
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
|
| 227 |
+
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
|
| 228 |
+
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
|
| 229 |
+
|
| 230 |
+
return visual_encoder_layer
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
|
| 234 |
+
visual_encoder = {}
|
| 235 |
+
|
| 236 |
+
visual_encoder.update(
|
| 237 |
+
{
|
| 238 |
+
f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
|
| 239 |
+
.unsqueeze(0)
|
| 240 |
+
.unsqueeze(0)
|
| 241 |
+
}
|
| 242 |
+
)
|
| 243 |
+
visual_encoder.update(
|
| 244 |
+
{
|
| 245 |
+
f"{diffuser_prefix}.embeddings.position_embedding": model[
|
| 246 |
+
f"{original_prefix}.positional_embedding"
|
| 247 |
+
].unsqueeze(0)
|
| 248 |
+
}
|
| 249 |
+
)
|
| 250 |
+
visual_encoder.update(
|
| 251 |
+
{f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
|
| 252 |
+
)
|
| 253 |
+
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
|
| 254 |
+
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
|
| 255 |
+
|
| 256 |
+
for i in range(blip2config.vision_config.num_hidden_layers):
|
| 257 |
+
visual_encoder.update(
|
| 258 |
+
visual_encoder_layer_from_original_checkpoint(
|
| 259 |
+
model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
|
| 260 |
+
)
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
|
| 264 |
+
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
|
| 265 |
+
|
| 266 |
+
return visual_encoder
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def qformer_original_checkpoint_to_diffusers_checkpoint(model):
|
| 270 |
+
qformer_checkpoint = {}
|
| 271 |
+
qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
|
| 272 |
+
qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
|
| 273 |
+
qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
|
| 274 |
+
qformer_checkpoint.update(
|
| 275 |
+
encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
|
| 276 |
+
)
|
| 277 |
+
qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
|
| 278 |
+
return qformer_checkpoint
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def get_qformer(model):
|
| 282 |
+
print("loading qformer")
|
| 283 |
+
|
| 284 |
+
qformer = qformer_model_from_original_config()
|
| 285 |
+
qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
|
| 286 |
+
|
| 287 |
+
load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
|
| 288 |
+
|
| 289 |
+
print("done loading qformer")
|
| 290 |
+
return qformer
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def load_checkpoint_to_model(checkpoint, model):
|
| 294 |
+
with tempfile.NamedTemporaryFile(delete=False) as file:
|
| 295 |
+
torch.save(checkpoint, file.name)
|
| 296 |
+
del checkpoint
|
| 297 |
+
model.load_state_dict(torch.load(file.name), strict=False)
|
| 298 |
+
|
| 299 |
+
os.remove(file.name)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def save_blip_diffusion_model(model, args):
|
| 303 |
+
qformer = get_qformer(model)
|
| 304 |
+
qformer.eval()
|
| 305 |
+
|
| 306 |
+
text_encoder = ContextCLIPTextModel.from_pretrained(
|
| 307 |
+
"stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="text_encoder"
|
| 308 |
+
)
|
| 309 |
+
vae = AutoencoderKL.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae")
|
| 310 |
+
unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
|
| 311 |
+
vae.eval()
|
| 312 |
+
text_encoder.eval()
|
| 313 |
+
scheduler = PNDMScheduler(
|
| 314 |
+
beta_start=0.00085,
|
| 315 |
+
beta_end=0.012,
|
| 316 |
+
beta_schedule="scaled_linear",
|
| 317 |
+
set_alpha_to_one=False,
|
| 318 |
+
skip_prk_steps=True,
|
| 319 |
+
)
|
| 320 |
+
tokenizer = CLIPTokenizer.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="tokenizer")
|
| 321 |
+
image_processor = BlipImageProcessor()
|
| 322 |
+
blip_diffusion = BlipDiffusionPipeline(
|
| 323 |
+
tokenizer=tokenizer,
|
| 324 |
+
text_encoder=text_encoder,
|
| 325 |
+
vae=vae,
|
| 326 |
+
unet=unet,
|
| 327 |
+
scheduler=scheduler,
|
| 328 |
+
qformer=qformer,
|
| 329 |
+
image_processor=image_processor,
|
| 330 |
+
)
|
| 331 |
+
blip_diffusion.save_pretrained(args.checkpoint_path)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def main(args):
|
| 335 |
+
model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
|
| 336 |
+
save_blip_diffusion_model(model.state_dict(), args)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
if __name__ == "__main__":
|
| 340 |
+
parser = argparse.ArgumentParser()
|
| 341 |
+
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 342 |
+
args = parser.parse_args()
|
| 343 |
+
|
| 344 |
+
main(args)
|
diffusers/scripts/convert_cogview3_to_diffusers.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Convert a CogView3 checkpoint to the Diffusers format.
|
| 3 |
+
|
| 4 |
+
This script converts a CogView3 checkpoint to the Diffusers format, which can then be used
|
| 5 |
+
with the Diffusers library.
|
| 6 |
+
|
| 7 |
+
Example usage:
|
| 8 |
+
python scripts/convert_cogview3_to_diffusers.py \
|
| 9 |
+
--transformer_checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
|
| 10 |
+
--vae_checkpoint_path 'your path/3plus_ae/imagekl_ch16.pt' \
|
| 11 |
+
--output_path "/raid/yiyi/cogview3_diffusers" \
|
| 12 |
+
--dtype "bf16"
|
| 13 |
+
|
| 14 |
+
Arguments:
|
| 15 |
+
--transformer_checkpoint_path: Path to Transformer state dict.
|
| 16 |
+
--vae_checkpoint_path: Path to VAE state dict.
|
| 17 |
+
--output_path: The path to save the converted model.
|
| 18 |
+
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
|
| 19 |
+
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
|
| 20 |
+
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
|
| 21 |
+
|
| 22 |
+
Default is "bf16" because CogView3 uses bfloat16 for Training.
|
| 23 |
+
|
| 24 |
+
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
from contextlib import nullcontext
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
from accelerate import init_empty_weights
|
| 32 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 33 |
+
|
| 34 |
+
from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
|
| 35 |
+
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
| 36 |
+
from diffusers.utils.import_utils import is_accelerate_available
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
| 40 |
+
|
| 41 |
+
TOKENIZER_MAX_LENGTH = 224
|
| 42 |
+
|
| 43 |
+
parser = argparse.ArgumentParser()
|
| 44 |
+
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
|
| 45 |
+
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
|
| 46 |
+
parser.add_argument("--output_path", required=True, type=str)
|
| 47 |
+
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
|
| 48 |
+
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
|
| 49 |
+
parser.add_argument("--dtype", type=str, default="bf16")
|
| 50 |
+
|
| 51 |
+
args = parser.parse_args()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# this is specific to `AdaLayerNormContinuous`:
|
| 55 |
+
# diffusers implementation split the linear projection into the scale, shift while CogView3 split it tino shift, scale
|
| 56 |
+
def swap_scale_shift(weight, dim):
|
| 57 |
+
shift, scale = weight.chunk(2, dim=0)
|
| 58 |
+
new_weight = torch.cat([scale, shift], dim=0)
|
| 59 |
+
return new_weight
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
|
| 63 |
+
original_state_dict = torch.load(ckpt_path, map_location="cpu")
|
| 64 |
+
original_state_dict = original_state_dict["module"]
|
| 65 |
+
original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
|
| 66 |
+
|
| 67 |
+
new_state_dict = {}
|
| 68 |
+
|
| 69 |
+
# Convert patch_embed
|
| 70 |
+
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
|
| 71 |
+
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
|
| 72 |
+
new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
|
| 73 |
+
new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
|
| 74 |
+
|
| 75 |
+
# Convert time_condition_embed
|
| 76 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
|
| 77 |
+
"time_embed.0.weight"
|
| 78 |
+
)
|
| 79 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
|
| 80 |
+
"time_embed.0.bias"
|
| 81 |
+
)
|
| 82 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
|
| 83 |
+
"time_embed.2.weight"
|
| 84 |
+
)
|
| 85 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
|
| 86 |
+
"time_embed.2.bias"
|
| 87 |
+
)
|
| 88 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
|
| 89 |
+
"label_emb.0.0.weight"
|
| 90 |
+
)
|
| 91 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
|
| 92 |
+
"label_emb.0.0.bias"
|
| 93 |
+
)
|
| 94 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
|
| 95 |
+
"label_emb.0.2.weight"
|
| 96 |
+
)
|
| 97 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
|
| 98 |
+
"label_emb.0.2.bias"
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Convert transformer blocks
|
| 102 |
+
for i in range(30):
|
| 103 |
+
block_prefix = f"transformer_blocks.{i}."
|
| 104 |
+
old_prefix = f"transformer.layers.{i}."
|
| 105 |
+
adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
|
| 106 |
+
|
| 107 |
+
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
|
| 108 |
+
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
|
| 109 |
+
|
| 110 |
+
qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
|
| 111 |
+
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
|
| 112 |
+
q, k, v = qkv_weight.chunk(3, dim=0)
|
| 113 |
+
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
|
| 114 |
+
|
| 115 |
+
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
| 116 |
+
new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
|
| 117 |
+
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
| 118 |
+
new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
|
| 119 |
+
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
| 120 |
+
new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
|
| 121 |
+
|
| 122 |
+
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
|
| 123 |
+
old_prefix + "attention.dense.weight"
|
| 124 |
+
)
|
| 125 |
+
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
|
| 126 |
+
old_prefix + "attention.dense.bias"
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
|
| 130 |
+
old_prefix + "mlp.dense_h_to_4h.weight"
|
| 131 |
+
)
|
| 132 |
+
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
|
| 133 |
+
old_prefix + "mlp.dense_h_to_4h.bias"
|
| 134 |
+
)
|
| 135 |
+
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
|
| 136 |
+
old_prefix + "mlp.dense_4h_to_h.weight"
|
| 137 |
+
)
|
| 138 |
+
new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
|
| 139 |
+
|
| 140 |
+
# Convert final norm and projection
|
| 141 |
+
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
|
| 142 |
+
original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
|
| 143 |
+
)
|
| 144 |
+
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
|
| 145 |
+
original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
|
| 146 |
+
)
|
| 147 |
+
new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
|
| 148 |
+
new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
|
| 149 |
+
|
| 150 |
+
return new_state_dict
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
|
| 154 |
+
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
| 155 |
+
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def main(args):
|
| 159 |
+
if args.dtype == "fp16":
|
| 160 |
+
dtype = torch.float16
|
| 161 |
+
elif args.dtype == "bf16":
|
| 162 |
+
dtype = torch.bfloat16
|
| 163 |
+
elif args.dtype == "fp32":
|
| 164 |
+
dtype = torch.float32
|
| 165 |
+
else:
|
| 166 |
+
raise ValueError(f"Unsupported dtype: {args.dtype}")
|
| 167 |
+
|
| 168 |
+
transformer = None
|
| 169 |
+
vae = None
|
| 170 |
+
|
| 171 |
+
if args.transformer_checkpoint_path is not None:
|
| 172 |
+
converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(
|
| 173 |
+
args.transformer_checkpoint_path
|
| 174 |
+
)
|
| 175 |
+
transformer = CogView3PlusTransformer2DModel()
|
| 176 |
+
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
|
| 177 |
+
if dtype is not None:
|
| 178 |
+
# Original checkpoint data type will be preserved
|
| 179 |
+
transformer = transformer.to(dtype=dtype)
|
| 180 |
+
|
| 181 |
+
if args.vae_checkpoint_path is not None:
|
| 182 |
+
vae_config = {
|
| 183 |
+
"in_channels": 3,
|
| 184 |
+
"out_channels": 3,
|
| 185 |
+
"down_block_types": ("DownEncoderBlock2D",) * 4,
|
| 186 |
+
"up_block_types": ("UpDecoderBlock2D",) * 4,
|
| 187 |
+
"block_out_channels": (128, 512, 1024, 1024),
|
| 188 |
+
"layers_per_block": 3,
|
| 189 |
+
"act_fn": "silu",
|
| 190 |
+
"latent_channels": 16,
|
| 191 |
+
"norm_num_groups": 32,
|
| 192 |
+
"sample_size": 1024,
|
| 193 |
+
"scaling_factor": 1.0,
|
| 194 |
+
"force_upcast": True,
|
| 195 |
+
"use_quant_conv": False,
|
| 196 |
+
"use_post_quant_conv": False,
|
| 197 |
+
"mid_block_add_attention": False,
|
| 198 |
+
}
|
| 199 |
+
converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
|
| 200 |
+
vae = AutoencoderKL(**vae_config)
|
| 201 |
+
vae.load_state_dict(converted_vae_state_dict, strict=True)
|
| 202 |
+
if dtype is not None:
|
| 203 |
+
vae = vae.to(dtype=dtype)
|
| 204 |
+
|
| 205 |
+
text_encoder_id = "google/t5-v1_1-xxl"
|
| 206 |
+
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
| 207 |
+
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
| 208 |
+
|
| 209 |
+
# Apparently, the conversion does not work anymore without this :shrug:
|
| 210 |
+
for param in text_encoder.parameters():
|
| 211 |
+
param.data = param.data.contiguous()
|
| 212 |
+
|
| 213 |
+
scheduler = CogVideoXDDIMScheduler.from_config(
|
| 214 |
+
{
|
| 215 |
+
"snr_shift_scale": 4.0,
|
| 216 |
+
"beta_end": 0.012,
|
| 217 |
+
"beta_schedule": "scaled_linear",
|
| 218 |
+
"beta_start": 0.00085,
|
| 219 |
+
"clip_sample": False,
|
| 220 |
+
"num_train_timesteps": 1000,
|
| 221 |
+
"prediction_type": "v_prediction",
|
| 222 |
+
"rescale_betas_zero_snr": True,
|
| 223 |
+
"set_alpha_to_one": True,
|
| 224 |
+
"timestep_spacing": "trailing",
|
| 225 |
+
}
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
pipe = CogView3PlusPipeline(
|
| 229 |
+
tokenizer=tokenizer,
|
| 230 |
+
text_encoder=text_encoder,
|
| 231 |
+
vae=vae,
|
| 232 |
+
transformer=transformer,
|
| 233 |
+
scheduler=scheduler,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can
|
| 237 |
+
# save some memory used for model loading.
|
| 238 |
+
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
if __name__ == "__main__":
|
| 242 |
+
main(args)
|
diffusers/scripts/convert_cogview4_to_diffusers.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Convert a CogView4 checkpoint from SAT(https://github.com/THUDM/SwissArmyTransformer) to the Diffusers format.
|
| 3 |
+
(deprecated Since 2025-02-07 and will remove it in later CogView4 version)
|
| 4 |
+
|
| 5 |
+
This script converts a CogView4 checkpoint to the Diffusers format, which can then be used
|
| 6 |
+
with the Diffusers library.
|
| 7 |
+
|
| 8 |
+
Example usage:
|
| 9 |
+
python scripts/convert_cogview4_to_diffusers.py \
|
| 10 |
+
--transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \
|
| 11 |
+
--vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
|
| 12 |
+
--output_path "THUDM/CogView4-6B" \
|
| 13 |
+
--dtype "bf16"
|
| 14 |
+
|
| 15 |
+
Arguments:
|
| 16 |
+
--transformer_checkpoint_path: Path to Transformer state dict.
|
| 17 |
+
--vae_checkpoint_path: Path to VAE state dict.
|
| 18 |
+
--output_path: The path to save the converted model.
|
| 19 |
+
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
|
| 20 |
+
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
|
| 21 |
+
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
|
| 22 |
+
|
| 23 |
+
Default is "bf16" because CogView4 uses bfloat16 for Training.
|
| 24 |
+
|
| 25 |
+
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
from contextlib import nullcontext
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
from accelerate import init_empty_weights
|
| 33 |
+
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
|
| 34 |
+
|
| 35 |
+
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
| 36 |
+
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
| 37 |
+
from diffusers.utils.import_utils import is_accelerate_available
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
| 41 |
+
|
| 42 |
+
parser = argparse.ArgumentParser()
|
| 43 |
+
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
|
| 44 |
+
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
|
| 45 |
+
parser.add_argument("--output_path", required=True, type=str)
|
| 46 |
+
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
|
| 47 |
+
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
|
| 48 |
+
parser.add_argument("--dtype", type=str, default="bf16")
|
| 49 |
+
|
| 50 |
+
args = parser.parse_args()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# this is specific to `AdaLayerNormContinuous`:
|
| 54 |
+
# diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
|
| 55 |
+
def swap_scale_shift(weight, dim):
|
| 56 |
+
"""
|
| 57 |
+
Swap the scale and shift components in the weight tensor.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
weight (torch.Tensor): The original weight tensor.
|
| 61 |
+
dim (int): The dimension along which to split.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
torch.Tensor: The modified weight tensor with scale and shift swapped.
|
| 65 |
+
"""
|
| 66 |
+
shift, scale = weight.chunk(2, dim=dim)
|
| 67 |
+
new_weight = torch.cat([scale, shift], dim=dim)
|
| 68 |
+
return new_weight
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
|
| 72 |
+
original_state_dict = torch.load(ckpt_path, map_location="cpu")
|
| 73 |
+
original_state_dict = original_state_dict["module"]
|
| 74 |
+
original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
|
| 75 |
+
|
| 76 |
+
new_state_dict = {}
|
| 77 |
+
|
| 78 |
+
# Convert patch_embed
|
| 79 |
+
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
|
| 80 |
+
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
|
| 81 |
+
new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
|
| 82 |
+
new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
|
| 83 |
+
|
| 84 |
+
# Convert time_condition_embed
|
| 85 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
|
| 86 |
+
"time_embed.0.weight"
|
| 87 |
+
)
|
| 88 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
|
| 89 |
+
"time_embed.0.bias"
|
| 90 |
+
)
|
| 91 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
|
| 92 |
+
"time_embed.2.weight"
|
| 93 |
+
)
|
| 94 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
|
| 95 |
+
"time_embed.2.bias"
|
| 96 |
+
)
|
| 97 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
|
| 98 |
+
"label_emb.0.0.weight"
|
| 99 |
+
)
|
| 100 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
|
| 101 |
+
"label_emb.0.0.bias"
|
| 102 |
+
)
|
| 103 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
|
| 104 |
+
"label_emb.0.2.weight"
|
| 105 |
+
)
|
| 106 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
|
| 107 |
+
"label_emb.0.2.bias"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# Convert transformer blocks, for cogview4 is 28 blocks
|
| 111 |
+
for i in range(28):
|
| 112 |
+
block_prefix = f"transformer_blocks.{i}."
|
| 113 |
+
old_prefix = f"transformer.layers.{i}."
|
| 114 |
+
adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
|
| 115 |
+
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
|
| 116 |
+
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
|
| 117 |
+
|
| 118 |
+
qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
|
| 119 |
+
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
|
| 120 |
+
q, k, v = qkv_weight.chunk(3, dim=0)
|
| 121 |
+
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
|
| 122 |
+
|
| 123 |
+
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
| 124 |
+
new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
|
| 125 |
+
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
| 126 |
+
new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
|
| 127 |
+
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
| 128 |
+
new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
|
| 129 |
+
|
| 130 |
+
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
|
| 131 |
+
old_prefix + "attention.dense.weight"
|
| 132 |
+
)
|
| 133 |
+
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
|
| 134 |
+
old_prefix + "attention.dense.bias"
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
|
| 138 |
+
old_prefix + "mlp.dense_h_to_4h.weight"
|
| 139 |
+
)
|
| 140 |
+
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
|
| 141 |
+
old_prefix + "mlp.dense_h_to_4h.bias"
|
| 142 |
+
)
|
| 143 |
+
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
|
| 144 |
+
old_prefix + "mlp.dense_4h_to_h.weight"
|
| 145 |
+
)
|
| 146 |
+
new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
|
| 147 |
+
|
| 148 |
+
# Convert final norm and projection
|
| 149 |
+
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
|
| 150 |
+
original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
|
| 151 |
+
)
|
| 152 |
+
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
|
| 153 |
+
original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
|
| 154 |
+
)
|
| 155 |
+
new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
|
| 156 |
+
new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
|
| 157 |
+
|
| 158 |
+
return new_state_dict
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
|
| 162 |
+
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
| 163 |
+
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def main(args):
|
| 167 |
+
if args.dtype == "fp16":
|
| 168 |
+
dtype = torch.float16
|
| 169 |
+
elif args.dtype == "bf16":
|
| 170 |
+
dtype = torch.bfloat16
|
| 171 |
+
elif args.dtype == "fp32":
|
| 172 |
+
dtype = torch.float32
|
| 173 |
+
else:
|
| 174 |
+
raise ValueError(f"Unsupported dtype: {args.dtype}")
|
| 175 |
+
|
| 176 |
+
transformer = None
|
| 177 |
+
vae = None
|
| 178 |
+
|
| 179 |
+
if args.transformer_checkpoint_path is not None:
|
| 180 |
+
converted_transformer_state_dict = convert_cogview4_transformer_checkpoint_to_diffusers(
|
| 181 |
+
args.transformer_checkpoint_path
|
| 182 |
+
)
|
| 183 |
+
transformer = CogView4Transformer2DModel(
|
| 184 |
+
patch_size=2,
|
| 185 |
+
in_channels=16,
|
| 186 |
+
num_layers=28,
|
| 187 |
+
attention_head_dim=128,
|
| 188 |
+
num_attention_heads=32,
|
| 189 |
+
out_channels=16,
|
| 190 |
+
text_embed_dim=4096,
|
| 191 |
+
time_embed_dim=512,
|
| 192 |
+
condition_dim=256,
|
| 193 |
+
pos_embed_max_size=128,
|
| 194 |
+
)
|
| 195 |
+
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
|
| 196 |
+
if dtype is not None:
|
| 197 |
+
# Original checkpoint data type will be preserved
|
| 198 |
+
transformer = transformer.to(dtype=dtype)
|
| 199 |
+
|
| 200 |
+
if args.vae_checkpoint_path is not None:
|
| 201 |
+
vae_config = {
|
| 202 |
+
"in_channels": 3,
|
| 203 |
+
"out_channels": 3,
|
| 204 |
+
"down_block_types": ("DownEncoderBlock2D",) * 4,
|
| 205 |
+
"up_block_types": ("UpDecoderBlock2D",) * 4,
|
| 206 |
+
"block_out_channels": (128, 512, 1024, 1024),
|
| 207 |
+
"layers_per_block": 3,
|
| 208 |
+
"act_fn": "silu",
|
| 209 |
+
"latent_channels": 16,
|
| 210 |
+
"norm_num_groups": 32,
|
| 211 |
+
"sample_size": 1024,
|
| 212 |
+
"scaling_factor": 1.0,
|
| 213 |
+
"shift_factor": 0.0,
|
| 214 |
+
"force_upcast": True,
|
| 215 |
+
"use_quant_conv": False,
|
| 216 |
+
"use_post_quant_conv": False,
|
| 217 |
+
"mid_block_add_attention": False,
|
| 218 |
+
}
|
| 219 |
+
converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
|
| 220 |
+
vae = AutoencoderKL(**vae_config)
|
| 221 |
+
vae.load_state_dict(converted_vae_state_dict, strict=True)
|
| 222 |
+
if dtype is not None:
|
| 223 |
+
vae = vae.to(dtype=dtype)
|
| 224 |
+
|
| 225 |
+
text_encoder_id = "THUDM/glm-4-9b-hf"
|
| 226 |
+
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
|
| 227 |
+
text_encoder = GlmForCausalLM.from_pretrained(
|
| 228 |
+
text_encoder_id,
|
| 229 |
+
cache_dir=args.text_encoder_cache_dir,
|
| 230 |
+
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
for param in text_encoder.parameters():
|
| 234 |
+
param.data = param.data.contiguous()
|
| 235 |
+
|
| 236 |
+
scheduler = FlowMatchEulerDiscreteScheduler(
|
| 237 |
+
base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
pipe = CogView4Pipeline(
|
| 241 |
+
tokenizer=tokenizer,
|
| 242 |
+
text_encoder=text_encoder,
|
| 243 |
+
vae=vae,
|
| 244 |
+
transformer=transformer,
|
| 245 |
+
scheduler=scheduler,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can
|
| 249 |
+
# save some memory used for model loading.
|
| 250 |
+
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
|
| 254 |
+
main(args)
|
diffusers/scripts/convert_cogview4_to_diffusers_megatron.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Convert a CogView4 checkpoint from Megatron to the Diffusers format.
|
| 3 |
+
|
| 4 |
+
Example usage:
|
| 5 |
+
python scripts/convert_cogview4_to_diffusers.py \
|
| 6 |
+
--transformer_checkpoint_path 'your path/cogview4_6b/mp_rank_00/model_optim_rng.pt' \
|
| 7 |
+
--vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
|
| 8 |
+
--output_path "THUDM/CogView4-6B" \
|
| 9 |
+
--dtype "bf16"
|
| 10 |
+
|
| 11 |
+
Arguments:
|
| 12 |
+
--transformer_checkpoint_path: Path to Transformer state dict.
|
| 13 |
+
--vae_checkpoint_path: Path to VAE state dict.
|
| 14 |
+
--output_path: The path to save the converted model.
|
| 15 |
+
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
|
| 16 |
+
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used.
|
| 17 |
+
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
|
| 18 |
+
|
| 19 |
+
Default is "bf16" because CogView4 uses bfloat16 for training.
|
| 20 |
+
|
| 21 |
+
Note: You must provide either --transformer_checkpoint_path or --vae_checkpoint_path.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
from transformers import GlmModel, PreTrainedTokenizerFast
|
| 29 |
+
|
| 30 |
+
from diffusers import (
|
| 31 |
+
AutoencoderKL,
|
| 32 |
+
CogView4ControlPipeline,
|
| 33 |
+
CogView4Pipeline,
|
| 34 |
+
CogView4Transformer2DModel,
|
| 35 |
+
FlowMatchEulerDiscreteScheduler,
|
| 36 |
+
)
|
| 37 |
+
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
parser = argparse.ArgumentParser()
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--transformer_checkpoint_path",
|
| 43 |
+
default=None,
|
| 44 |
+
type=str,
|
| 45 |
+
help="Path to Megatron (not SAT) Transformer checkpoint, e.g., 'model_optim_rng.pt'.",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--vae_checkpoint_path",
|
| 49 |
+
default=None,
|
| 50 |
+
type=str,
|
| 51 |
+
help="(Optional) Path to VAE checkpoint, e.g., 'imagekl_ch16.pt'.",
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"--output_path",
|
| 55 |
+
required=True,
|
| 56 |
+
type=str,
|
| 57 |
+
help="Directory to save the final Diffusers format pipeline.",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--push_to_hub",
|
| 61 |
+
action="store_true",
|
| 62 |
+
default=False,
|
| 63 |
+
help="Whether to push the converted model to the HuggingFace Hub.",
|
| 64 |
+
)
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--text_encoder_cache_dir",
|
| 67 |
+
type=str,
|
| 68 |
+
default=None,
|
| 69 |
+
help="Specify the cache directory for the text encoder.",
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--dtype",
|
| 73 |
+
type=str,
|
| 74 |
+
default="bf16",
|
| 75 |
+
choices=["fp16", "bf16", "fp32"],
|
| 76 |
+
help="Data type to save the model in.",
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
parser.add_argument(
|
| 80 |
+
"--num_layers",
|
| 81 |
+
type=int,
|
| 82 |
+
default=28,
|
| 83 |
+
help="Number of Transformer layers (e.g., 28, 48...).",
|
| 84 |
+
)
|
| 85 |
+
parser.add_argument(
|
| 86 |
+
"--num_heads",
|
| 87 |
+
type=int,
|
| 88 |
+
default=32,
|
| 89 |
+
help="Number of attention heads.",
|
| 90 |
+
)
|
| 91 |
+
parser.add_argument(
|
| 92 |
+
"--hidden_size",
|
| 93 |
+
type=int,
|
| 94 |
+
default=4096,
|
| 95 |
+
help="Transformer hidden dimension size.",
|
| 96 |
+
)
|
| 97 |
+
parser.add_argument(
|
| 98 |
+
"--attention_head_dim",
|
| 99 |
+
type=int,
|
| 100 |
+
default=128,
|
| 101 |
+
help="Dimension of each attention head.",
|
| 102 |
+
)
|
| 103 |
+
parser.add_argument(
|
| 104 |
+
"--time_embed_dim",
|
| 105 |
+
type=int,
|
| 106 |
+
default=512,
|
| 107 |
+
help="Dimension of time embeddings.",
|
| 108 |
+
)
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"--condition_dim",
|
| 111 |
+
type=int,
|
| 112 |
+
default=256,
|
| 113 |
+
help="Dimension of condition embeddings.",
|
| 114 |
+
)
|
| 115 |
+
parser.add_argument(
|
| 116 |
+
"--pos_embed_max_size",
|
| 117 |
+
type=int,
|
| 118 |
+
default=128,
|
| 119 |
+
help="Maximum size for positional embeddings.",
|
| 120 |
+
)
|
| 121 |
+
parser.add_argument(
|
| 122 |
+
"--control",
|
| 123 |
+
action="store_true",
|
| 124 |
+
default=False,
|
| 125 |
+
help="Whether to use control model.",
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
args = parser.parse_args()
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def swap_scale_shift(weight, dim):
|
| 132 |
+
"""
|
| 133 |
+
Swap the scale and shift components in the weight tensor.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
weight (torch.Tensor): The original weight tensor.
|
| 137 |
+
dim (int): The dimension along which to split.
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
torch.Tensor: The modified weight tensor with scale and shift swapped.
|
| 141 |
+
"""
|
| 142 |
+
shift, scale = weight.chunk(2, dim=dim)
|
| 143 |
+
new_weight = torch.cat([scale, shift], dim=dim)
|
| 144 |
+
return new_weight
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def convert_megatron_transformer_checkpoint_to_diffusers(
|
| 148 |
+
ckpt_path: str,
|
| 149 |
+
num_layers: int,
|
| 150 |
+
num_heads: int,
|
| 151 |
+
hidden_size: int,
|
| 152 |
+
):
|
| 153 |
+
"""
|
| 154 |
+
Convert a Megatron Transformer checkpoint to Diffusers format.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
ckpt_path (str): Path to the Megatron Transformer checkpoint.
|
| 158 |
+
num_layers (int): Number of Transformer layers.
|
| 159 |
+
num_heads (int): Number of attention heads.
|
| 160 |
+
hidden_size (int): Hidden size of the Transformer.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
dict: The converted state dictionary compatible with Diffusers.
|
| 164 |
+
"""
|
| 165 |
+
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
| 166 |
+
mega = ckpt["model"]
|
| 167 |
+
|
| 168 |
+
new_state_dict = {}
|
| 169 |
+
|
| 170 |
+
# Patch Embedding
|
| 171 |
+
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
|
| 172 |
+
hidden_size, 128 if args.control else 64
|
| 173 |
+
)
|
| 174 |
+
new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
|
| 175 |
+
new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
|
| 176 |
+
new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
|
| 177 |
+
|
| 178 |
+
# Time Condition Embedding
|
| 179 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = mega[
|
| 180 |
+
"time_embedding.time_embed.0.weight"
|
| 181 |
+
]
|
| 182 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = mega["time_embedding.time_embed.0.bias"]
|
| 183 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = mega[
|
| 184 |
+
"time_embedding.time_embed.2.weight"
|
| 185 |
+
]
|
| 186 |
+
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = mega["time_embedding.time_embed.2.bias"]
|
| 187 |
+
|
| 188 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = mega[
|
| 189 |
+
"label_embedding.label_embed.0.weight"
|
| 190 |
+
]
|
| 191 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = mega[
|
| 192 |
+
"label_embedding.label_embed.0.bias"
|
| 193 |
+
]
|
| 194 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = mega[
|
| 195 |
+
"label_embedding.label_embed.2.weight"
|
| 196 |
+
]
|
| 197 |
+
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = mega[
|
| 198 |
+
"label_embedding.label_embed.2.bias"
|
| 199 |
+
]
|
| 200 |
+
|
| 201 |
+
# Convert each Transformer layer
|
| 202 |
+
for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"):
|
| 203 |
+
block_prefix = f"transformer_blocks.{i}."
|
| 204 |
+
|
| 205 |
+
# AdaLayerNorm
|
| 206 |
+
new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"]
|
| 207 |
+
new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"]
|
| 208 |
+
qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
|
| 209 |
+
qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
|
| 210 |
+
|
| 211 |
+
# Reshape to match SAT logic
|
| 212 |
+
qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size)
|
| 213 |
+
qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size)
|
| 214 |
+
|
| 215 |
+
qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads)
|
| 216 |
+
qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size)
|
| 217 |
+
|
| 218 |
+
# Assign to Diffusers keys
|
| 219 |
+
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
|
| 220 |
+
qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0)
|
| 221 |
+
|
| 222 |
+
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
| 223 |
+
new_state_dict[block_prefix + "attn1.to_q.bias"] = qb
|
| 224 |
+
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
| 225 |
+
new_state_dict[block_prefix + "attn1.to_k.bias"] = kb
|
| 226 |
+
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
| 227 |
+
new_state_dict[block_prefix + "attn1.to_v.bias"] = vb
|
| 228 |
+
|
| 229 |
+
# Attention Output
|
| 230 |
+
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
|
| 231 |
+
f"decoder.layers.{i}.self_attention.linear_proj.weight"
|
| 232 |
+
]
|
| 233 |
+
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
|
| 234 |
+
f"decoder.layers.{i}.self_attention.linear_proj.bias"
|
| 235 |
+
]
|
| 236 |
+
|
| 237 |
+
# MLP
|
| 238 |
+
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.weight"]
|
| 239 |
+
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.bias"]
|
| 240 |
+
new_state_dict[block_prefix + "ff.net.2.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc2.weight"]
|
| 241 |
+
new_state_dict[block_prefix + "ff.net.2.bias"] = mega[f"decoder.layers.{i}.mlp.linear_fc2.bias"]
|
| 242 |
+
|
| 243 |
+
# Final Layers
|
| 244 |
+
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(mega["adaln_final.weight"], dim=0)
|
| 245 |
+
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(mega["adaln_final.bias"], dim=0)
|
| 246 |
+
new_state_dict["proj_out.weight"] = mega["output_projector.weight"]
|
| 247 |
+
new_state_dict["proj_out.bias"] = mega["output_projector.bias"]
|
| 248 |
+
|
| 249 |
+
return new_state_dict
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
|
| 253 |
+
"""
|
| 254 |
+
Convert a CogView4 VAE checkpoint to Diffusers format.
|
| 255 |
+
|
| 256 |
+
Args:
|
| 257 |
+
ckpt_path (str): Path to the VAE checkpoint.
|
| 258 |
+
vae_config (dict): Configuration dictionary for the VAE.
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
dict: The converted VAE state dictionary compatible with Diffusers.
|
| 262 |
+
"""
|
| 263 |
+
original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
|
| 264 |
+
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def main(args):
|
| 268 |
+
"""
|
| 269 |
+
Main function to convert CogView4 checkpoints to Diffusers format.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
args (argparse.Namespace): Parsed command-line arguments.
|
| 273 |
+
"""
|
| 274 |
+
# Determine the desired data type
|
| 275 |
+
if args.dtype == "fp16":
|
| 276 |
+
dtype = torch.float16
|
| 277 |
+
elif args.dtype == "bf16":
|
| 278 |
+
dtype = torch.bfloat16
|
| 279 |
+
elif args.dtype == "fp32":
|
| 280 |
+
dtype = torch.float32
|
| 281 |
+
else:
|
| 282 |
+
raise ValueError(f"Unsupported dtype: {args.dtype}")
|
| 283 |
+
|
| 284 |
+
transformer = None
|
| 285 |
+
vae = None
|
| 286 |
+
|
| 287 |
+
# Convert Transformer checkpoint if provided
|
| 288 |
+
if args.transformer_checkpoint_path is not None:
|
| 289 |
+
converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers(
|
| 290 |
+
ckpt_path=args.transformer_checkpoint_path,
|
| 291 |
+
num_layers=args.num_layers,
|
| 292 |
+
num_heads=args.num_heads,
|
| 293 |
+
hidden_size=args.hidden_size,
|
| 294 |
+
)
|
| 295 |
+
transformer = CogView4Transformer2DModel(
|
| 296 |
+
patch_size=2,
|
| 297 |
+
in_channels=32 if args.control else 16,
|
| 298 |
+
num_layers=args.num_layers,
|
| 299 |
+
attention_head_dim=args.attention_head_dim,
|
| 300 |
+
num_attention_heads=args.num_heads,
|
| 301 |
+
out_channels=16,
|
| 302 |
+
text_embed_dim=args.hidden_size,
|
| 303 |
+
time_embed_dim=args.time_embed_dim,
|
| 304 |
+
condition_dim=args.condition_dim,
|
| 305 |
+
pos_embed_max_size=args.pos_embed_max_size,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
|
| 309 |
+
|
| 310 |
+
# Convert to the specified dtype
|
| 311 |
+
if dtype is not None:
|
| 312 |
+
transformer = transformer.to(dtype=dtype)
|
| 313 |
+
|
| 314 |
+
# Convert VAE checkpoint if provided
|
| 315 |
+
if args.vae_checkpoint_path is not None:
|
| 316 |
+
vae_config = {
|
| 317 |
+
"in_channels": 3,
|
| 318 |
+
"out_channels": 3,
|
| 319 |
+
"down_block_types": ("DownEncoderBlock2D",) * 4,
|
| 320 |
+
"up_block_types": ("UpDecoderBlock2D",) * 4,
|
| 321 |
+
"block_out_channels": (128, 512, 1024, 1024),
|
| 322 |
+
"layers_per_block": 3,
|
| 323 |
+
"act_fn": "silu",
|
| 324 |
+
"latent_channels": 16,
|
| 325 |
+
"norm_num_groups": 32,
|
| 326 |
+
"sample_size": 1024,
|
| 327 |
+
"scaling_factor": 1.0,
|
| 328 |
+
"shift_factor": 0.0,
|
| 329 |
+
"force_upcast": True,
|
| 330 |
+
"use_quant_conv": False,
|
| 331 |
+
"use_post_quant_conv": False,
|
| 332 |
+
"mid_block_add_attention": False,
|
| 333 |
+
}
|
| 334 |
+
converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
|
| 335 |
+
vae = AutoencoderKL(**vae_config)
|
| 336 |
+
vae.load_state_dict(converted_vae_state_dict, strict=True)
|
| 337 |
+
if dtype is not None:
|
| 338 |
+
vae = vae.to(dtype=dtype)
|
| 339 |
+
|
| 340 |
+
# Load the text encoder and tokenizer
|
| 341 |
+
text_encoder_id = "THUDM/glm-4-9b-hf"
|
| 342 |
+
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
|
| 343 |
+
text_encoder = GlmModel.from_pretrained(
|
| 344 |
+
text_encoder_id,
|
| 345 |
+
cache_dir=args.text_encoder_cache_dir,
|
| 346 |
+
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
|
| 347 |
+
)
|
| 348 |
+
for param in text_encoder.parameters():
|
| 349 |
+
param.data = param.data.contiguous()
|
| 350 |
+
|
| 351 |
+
# Initialize the scheduler
|
| 352 |
+
scheduler = FlowMatchEulerDiscreteScheduler(
|
| 353 |
+
base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# Create the pipeline
|
| 357 |
+
if args.control:
|
| 358 |
+
pipe = CogView4ControlPipeline(
|
| 359 |
+
tokenizer=tokenizer,
|
| 360 |
+
text_encoder=text_encoder,
|
| 361 |
+
vae=vae,
|
| 362 |
+
transformer=transformer,
|
| 363 |
+
scheduler=scheduler,
|
| 364 |
+
)
|
| 365 |
+
else:
|
| 366 |
+
pipe = CogView4Pipeline(
|
| 367 |
+
tokenizer=tokenizer,
|
| 368 |
+
text_encoder=text_encoder,
|
| 369 |
+
vae=vae,
|
| 370 |
+
transformer=transformer,
|
| 371 |
+
scheduler=scheduler,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
# Save the converted pipeline
|
| 375 |
+
pipe.save_pretrained(
|
| 376 |
+
args.output_path,
|
| 377 |
+
safe_serialization=True,
|
| 378 |
+
max_shard_size="5GB",
|
| 379 |
+
push_to_hub=args.push_to_hub,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
if __name__ == "__main__":
|
| 384 |
+
main(args)
|
diffusers/scripts/convert_consistency_to_diffusers.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from diffusers import (
|
| 7 |
+
CMStochasticIterativeScheduler,
|
| 8 |
+
ConsistencyModelPipeline,
|
| 9 |
+
UNet2DModel,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
TEST_UNET_CONFIG = {
|
| 14 |
+
"sample_size": 32,
|
| 15 |
+
"in_channels": 3,
|
| 16 |
+
"out_channels": 3,
|
| 17 |
+
"layers_per_block": 2,
|
| 18 |
+
"num_class_embeds": 1000,
|
| 19 |
+
"block_out_channels": [32, 64],
|
| 20 |
+
"attention_head_dim": 8,
|
| 21 |
+
"down_block_types": [
|
| 22 |
+
"ResnetDownsampleBlock2D",
|
| 23 |
+
"AttnDownBlock2D",
|
| 24 |
+
],
|
| 25 |
+
"up_block_types": [
|
| 26 |
+
"AttnUpBlock2D",
|
| 27 |
+
"ResnetUpsampleBlock2D",
|
| 28 |
+
],
|
| 29 |
+
"resnet_time_scale_shift": "scale_shift",
|
| 30 |
+
"attn_norm_num_groups": 32,
|
| 31 |
+
"upsample_type": "resnet",
|
| 32 |
+
"downsample_type": "resnet",
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
IMAGENET_64_UNET_CONFIG = {
|
| 36 |
+
"sample_size": 64,
|
| 37 |
+
"in_channels": 3,
|
| 38 |
+
"out_channels": 3,
|
| 39 |
+
"layers_per_block": 3,
|
| 40 |
+
"num_class_embeds": 1000,
|
| 41 |
+
"block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4],
|
| 42 |
+
"attention_head_dim": 64,
|
| 43 |
+
"down_block_types": [
|
| 44 |
+
"ResnetDownsampleBlock2D",
|
| 45 |
+
"AttnDownBlock2D",
|
| 46 |
+
"AttnDownBlock2D",
|
| 47 |
+
"AttnDownBlock2D",
|
| 48 |
+
],
|
| 49 |
+
"up_block_types": [
|
| 50 |
+
"AttnUpBlock2D",
|
| 51 |
+
"AttnUpBlock2D",
|
| 52 |
+
"AttnUpBlock2D",
|
| 53 |
+
"ResnetUpsampleBlock2D",
|
| 54 |
+
],
|
| 55 |
+
"resnet_time_scale_shift": "scale_shift",
|
| 56 |
+
"attn_norm_num_groups": 32,
|
| 57 |
+
"upsample_type": "resnet",
|
| 58 |
+
"downsample_type": "resnet",
|
| 59 |
+
}
|
| 60 |
+
|
| 61 |
+
LSUN_256_UNET_CONFIG = {
|
| 62 |
+
"sample_size": 256,
|
| 63 |
+
"in_channels": 3,
|
| 64 |
+
"out_channels": 3,
|
| 65 |
+
"layers_per_block": 2,
|
| 66 |
+
"num_class_embeds": None,
|
| 67 |
+
"block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4],
|
| 68 |
+
"attention_head_dim": 64,
|
| 69 |
+
"down_block_types": [
|
| 70 |
+
"ResnetDownsampleBlock2D",
|
| 71 |
+
"ResnetDownsampleBlock2D",
|
| 72 |
+
"ResnetDownsampleBlock2D",
|
| 73 |
+
"AttnDownBlock2D",
|
| 74 |
+
"AttnDownBlock2D",
|
| 75 |
+
"AttnDownBlock2D",
|
| 76 |
+
],
|
| 77 |
+
"up_block_types": [
|
| 78 |
+
"AttnUpBlock2D",
|
| 79 |
+
"AttnUpBlock2D",
|
| 80 |
+
"AttnUpBlock2D",
|
| 81 |
+
"ResnetUpsampleBlock2D",
|
| 82 |
+
"ResnetUpsampleBlock2D",
|
| 83 |
+
"ResnetUpsampleBlock2D",
|
| 84 |
+
],
|
| 85 |
+
"resnet_time_scale_shift": "default",
|
| 86 |
+
"upsample_type": "resnet",
|
| 87 |
+
"downsample_type": "resnet",
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
CD_SCHEDULER_CONFIG = {
|
| 91 |
+
"num_train_timesteps": 40,
|
| 92 |
+
"sigma_min": 0.002,
|
| 93 |
+
"sigma_max": 80.0,
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
CT_IMAGENET_64_SCHEDULER_CONFIG = {
|
| 97 |
+
"num_train_timesteps": 201,
|
| 98 |
+
"sigma_min": 0.002,
|
| 99 |
+
"sigma_max": 80.0,
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
CT_LSUN_256_SCHEDULER_CONFIG = {
|
| 103 |
+
"num_train_timesteps": 151,
|
| 104 |
+
"sigma_min": 0.002,
|
| 105 |
+
"sigma_max": 80.0,
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def str2bool(v):
|
| 110 |
+
"""
|
| 111 |
+
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
|
| 112 |
+
"""
|
| 113 |
+
if isinstance(v, bool):
|
| 114 |
+
return v
|
| 115 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
| 116 |
+
return True
|
| 117 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
| 118 |
+
return False
|
| 119 |
+
else:
|
| 120 |
+
raise argparse.ArgumentTypeError("boolean value expected")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False):
|
| 124 |
+
new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"]
|
| 125 |
+
new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"]
|
| 126 |
+
new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"]
|
| 127 |
+
new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"]
|
| 128 |
+
new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"]
|
| 129 |
+
new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"]
|
| 130 |
+
new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"]
|
| 131 |
+
new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"]
|
| 132 |
+
new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"]
|
| 133 |
+
new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"]
|
| 134 |
+
|
| 135 |
+
if has_skip:
|
| 136 |
+
new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"]
|
| 137 |
+
new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"]
|
| 138 |
+
|
| 139 |
+
return new_checkpoint
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None):
|
| 143 |
+
weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0)
|
| 144 |
+
bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0)
|
| 145 |
+
|
| 146 |
+
new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"]
|
| 147 |
+
new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"]
|
| 148 |
+
|
| 149 |
+
new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1)
|
| 150 |
+
new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1)
|
| 151 |
+
new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1)
|
| 152 |
+
new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1)
|
| 153 |
+
new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1)
|
| 154 |
+
new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1)
|
| 155 |
+
|
| 156 |
+
new_checkpoint[f"{new_prefix}.to_out.0.weight"] = (
|
| 157 |
+
checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1)
|
| 158 |
+
)
|
| 159 |
+
new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1)
|
| 160 |
+
|
| 161 |
+
return new_checkpoint
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def con_pt_to_diffuser(checkpoint_path: str, unet_config):
|
| 165 |
+
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
| 166 |
+
new_checkpoint = {}
|
| 167 |
+
|
| 168 |
+
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
|
| 169 |
+
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
|
| 170 |
+
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
|
| 171 |
+
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
|
| 172 |
+
|
| 173 |
+
if unet_config["num_class_embeds"] is not None:
|
| 174 |
+
new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"]
|
| 175 |
+
|
| 176 |
+
new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
|
| 177 |
+
new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
|
| 178 |
+
|
| 179 |
+
down_block_types = unet_config["down_block_types"]
|
| 180 |
+
layers_per_block = unet_config["layers_per_block"]
|
| 181 |
+
attention_head_dim = unet_config["attention_head_dim"]
|
| 182 |
+
channels_list = unet_config["block_out_channels"]
|
| 183 |
+
current_layer = 1
|
| 184 |
+
prev_channels = channels_list[0]
|
| 185 |
+
|
| 186 |
+
for i, layer_type in enumerate(down_block_types):
|
| 187 |
+
current_channels = channels_list[i]
|
| 188 |
+
downsample_block_has_skip = current_channels != prev_channels
|
| 189 |
+
if layer_type == "ResnetDownsampleBlock2D":
|
| 190 |
+
for j in range(layers_per_block):
|
| 191 |
+
new_prefix = f"down_blocks.{i}.resnets.{j}"
|
| 192 |
+
old_prefix = f"input_blocks.{current_layer}.0"
|
| 193 |
+
has_skip = True if j == 0 and downsample_block_has_skip else False
|
| 194 |
+
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
|
| 195 |
+
current_layer += 1
|
| 196 |
+
|
| 197 |
+
elif layer_type == "AttnDownBlock2D":
|
| 198 |
+
for j in range(layers_per_block):
|
| 199 |
+
new_prefix = f"down_blocks.{i}.resnets.{j}"
|
| 200 |
+
old_prefix = f"input_blocks.{current_layer}.0"
|
| 201 |
+
has_skip = True if j == 0 and downsample_block_has_skip else False
|
| 202 |
+
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
|
| 203 |
+
new_prefix = f"down_blocks.{i}.attentions.{j}"
|
| 204 |
+
old_prefix = f"input_blocks.{current_layer}.1"
|
| 205 |
+
new_checkpoint = convert_attention(
|
| 206 |
+
checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
|
| 207 |
+
)
|
| 208 |
+
current_layer += 1
|
| 209 |
+
|
| 210 |
+
if i != len(down_block_types) - 1:
|
| 211 |
+
new_prefix = f"down_blocks.{i}.downsamplers.0"
|
| 212 |
+
old_prefix = f"input_blocks.{current_layer}.0"
|
| 213 |
+
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
| 214 |
+
current_layer += 1
|
| 215 |
+
|
| 216 |
+
prev_channels = current_channels
|
| 217 |
+
|
| 218 |
+
# hardcoded the mid-block for now
|
| 219 |
+
new_prefix = "mid_block.resnets.0"
|
| 220 |
+
old_prefix = "middle_block.0"
|
| 221 |
+
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
| 222 |
+
new_prefix = "mid_block.attentions.0"
|
| 223 |
+
old_prefix = "middle_block.1"
|
| 224 |
+
new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim)
|
| 225 |
+
new_prefix = "mid_block.resnets.1"
|
| 226 |
+
old_prefix = "middle_block.2"
|
| 227 |
+
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
| 228 |
+
|
| 229 |
+
current_layer = 0
|
| 230 |
+
up_block_types = unet_config["up_block_types"]
|
| 231 |
+
|
| 232 |
+
for i, layer_type in enumerate(up_block_types):
|
| 233 |
+
if layer_type == "ResnetUpsampleBlock2D":
|
| 234 |
+
for j in range(layers_per_block + 1):
|
| 235 |
+
new_prefix = f"up_blocks.{i}.resnets.{j}"
|
| 236 |
+
old_prefix = f"output_blocks.{current_layer}.0"
|
| 237 |
+
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
|
| 238 |
+
current_layer += 1
|
| 239 |
+
|
| 240 |
+
if i != len(up_block_types) - 1:
|
| 241 |
+
new_prefix = f"up_blocks.{i}.upsamplers.0"
|
| 242 |
+
old_prefix = f"output_blocks.{current_layer - 1}.1"
|
| 243 |
+
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
| 244 |
+
elif layer_type == "AttnUpBlock2D":
|
| 245 |
+
for j in range(layers_per_block + 1):
|
| 246 |
+
new_prefix = f"up_blocks.{i}.resnets.{j}"
|
| 247 |
+
old_prefix = f"output_blocks.{current_layer}.0"
|
| 248 |
+
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
|
| 249 |
+
new_prefix = f"up_blocks.{i}.attentions.{j}"
|
| 250 |
+
old_prefix = f"output_blocks.{current_layer}.1"
|
| 251 |
+
new_checkpoint = convert_attention(
|
| 252 |
+
checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
|
| 253 |
+
)
|
| 254 |
+
current_layer += 1
|
| 255 |
+
|
| 256 |
+
if i != len(up_block_types) - 1:
|
| 257 |
+
new_prefix = f"up_blocks.{i}.upsamplers.0"
|
| 258 |
+
old_prefix = f"output_blocks.{current_layer - 1}.2"
|
| 259 |
+
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
| 260 |
+
|
| 261 |
+
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
|
| 262 |
+
new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
|
| 263 |
+
new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
|
| 264 |
+
new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
|
| 265 |
+
|
| 266 |
+
return new_checkpoint
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
if __name__ == "__main__":
|
| 270 |
+
parser = argparse.ArgumentParser()
|
| 271 |
+
|
| 272 |
+
parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.")
|
| 273 |
+
parser.add_argument(
|
| 274 |
+
"--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model."
|
| 275 |
+
)
|
| 276 |
+
parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.")
|
| 277 |
+
|
| 278 |
+
args = parser.parse_args()
|
| 279 |
+
args.class_cond = str2bool(args.class_cond)
|
| 280 |
+
|
| 281 |
+
ckpt_name = os.path.basename(args.unet_path)
|
| 282 |
+
print(f"Checkpoint: {ckpt_name}")
|
| 283 |
+
|
| 284 |
+
# Get U-Net config
|
| 285 |
+
if "imagenet64" in ckpt_name:
|
| 286 |
+
unet_config = IMAGENET_64_UNET_CONFIG
|
| 287 |
+
elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
|
| 288 |
+
unet_config = LSUN_256_UNET_CONFIG
|
| 289 |
+
elif "test" in ckpt_name:
|
| 290 |
+
unet_config = TEST_UNET_CONFIG
|
| 291 |
+
else:
|
| 292 |
+
raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
|
| 293 |
+
|
| 294 |
+
if not args.class_cond:
|
| 295 |
+
unet_config["num_class_embeds"] = None
|
| 296 |
+
|
| 297 |
+
converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config)
|
| 298 |
+
|
| 299 |
+
image_unet = UNet2DModel(**unet_config)
|
| 300 |
+
image_unet.load_state_dict(converted_unet_ckpt)
|
| 301 |
+
|
| 302 |
+
# Get scheduler config
|
| 303 |
+
if "cd" in ckpt_name or "test" in ckpt_name:
|
| 304 |
+
scheduler_config = CD_SCHEDULER_CONFIG
|
| 305 |
+
elif "ct" in ckpt_name and "imagenet64" in ckpt_name:
|
| 306 |
+
scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG
|
| 307 |
+
elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
|
| 308 |
+
scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG
|
| 309 |
+
else:
|
| 310 |
+
raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
|
| 311 |
+
|
| 312 |
+
cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config)
|
| 313 |
+
|
| 314 |
+
consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler)
|
| 315 |
+
consistency_model.save_pretrained(args.dump_path)
|
diffusers/scripts/convert_cosmos_to_diffusers.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import pathlib
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from accelerate import init_empty_weights
|
| 7 |
+
from huggingface_hub import snapshot_download
|
| 8 |
+
from transformers import T5EncoderModel, T5TokenizerFast
|
| 9 |
+
|
| 10 |
+
from diffusers import (
|
| 11 |
+
AutoencoderKLCosmos,
|
| 12 |
+
AutoencoderKLWan,
|
| 13 |
+
Cosmos2TextToImagePipeline,
|
| 14 |
+
Cosmos2VideoToWorldPipeline,
|
| 15 |
+
CosmosTextToWorldPipeline,
|
| 16 |
+
CosmosTransformer3DModel,
|
| 17 |
+
CosmosVideoToWorldPipeline,
|
| 18 |
+
EDMEulerScheduler,
|
| 19 |
+
FlowMatchEulerDiscreteScheduler,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
| 24 |
+
state_dict.pop(key)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
| 28 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
|
| 32 |
+
block_index = int(key.split(".")[1].removeprefix("block"))
|
| 33 |
+
new_key = key
|
| 34 |
+
|
| 35 |
+
old_prefix = f"blocks.block{block_index}"
|
| 36 |
+
new_prefix = f"transformer_blocks.{block_index}"
|
| 37 |
+
new_key = new_prefix + new_key.removeprefix(old_prefix)
|
| 38 |
+
|
| 39 |
+
state_dict[new_key] = state_dict.pop(key)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
|
| 43 |
+
"t_embedder.1": "time_embed.t_embedder",
|
| 44 |
+
"affline_norm": "time_embed.norm",
|
| 45 |
+
".blocks.0.block.attn": ".attn1",
|
| 46 |
+
".blocks.1.block.attn": ".attn2",
|
| 47 |
+
".blocks.2.block": ".ff",
|
| 48 |
+
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
|
| 49 |
+
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
|
| 50 |
+
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
|
| 51 |
+
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
|
| 52 |
+
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
|
| 53 |
+
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
|
| 54 |
+
"to_q.0": "to_q",
|
| 55 |
+
"to_q.1": "norm_q",
|
| 56 |
+
"to_k.0": "to_k",
|
| 57 |
+
"to_k.1": "norm_k",
|
| 58 |
+
"to_v.0": "to_v",
|
| 59 |
+
"layer1": "net.0.proj",
|
| 60 |
+
"layer2": "net.2",
|
| 61 |
+
"proj.1": "proj",
|
| 62 |
+
"x_embedder": "patch_embed",
|
| 63 |
+
"extra_pos_embedder": "learnable_pos_embed",
|
| 64 |
+
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
|
| 65 |
+
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
|
| 66 |
+
"final_layer.linear": "proj_out",
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
|
| 70 |
+
"blocks.block": rename_transformer_blocks_,
|
| 71 |
+
"logvar.0.freqs": remove_keys_,
|
| 72 |
+
"logvar.0.phases": remove_keys_,
|
| 73 |
+
"logvar.1.weight": remove_keys_,
|
| 74 |
+
"pos_embedder.seq": remove_keys_,
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
|
| 78 |
+
"t_embedder.1": "time_embed.t_embedder",
|
| 79 |
+
"t_embedding_norm": "time_embed.norm",
|
| 80 |
+
"blocks": "transformer_blocks",
|
| 81 |
+
"adaln_modulation_self_attn.1": "norm1.linear_1",
|
| 82 |
+
"adaln_modulation_self_attn.2": "norm1.linear_2",
|
| 83 |
+
"adaln_modulation_cross_attn.1": "norm2.linear_1",
|
| 84 |
+
"adaln_modulation_cross_attn.2": "norm2.linear_2",
|
| 85 |
+
"adaln_modulation_mlp.1": "norm3.linear_1",
|
| 86 |
+
"adaln_modulation_mlp.2": "norm3.linear_2",
|
| 87 |
+
"self_attn": "attn1",
|
| 88 |
+
"cross_attn": "attn2",
|
| 89 |
+
"q_proj": "to_q",
|
| 90 |
+
"k_proj": "to_k",
|
| 91 |
+
"v_proj": "to_v",
|
| 92 |
+
"output_proj": "to_out.0",
|
| 93 |
+
"q_norm": "norm_q",
|
| 94 |
+
"k_norm": "norm_k",
|
| 95 |
+
"mlp.layer1": "ff.net.0.proj",
|
| 96 |
+
"mlp.layer2": "ff.net.2",
|
| 97 |
+
"x_embedder.proj.1": "patch_embed.proj",
|
| 98 |
+
"final_layer.adaln_modulation.1": "norm_out.linear_1",
|
| 99 |
+
"final_layer.adaln_modulation.2": "norm_out.linear_2",
|
| 100 |
+
"final_layer.linear": "proj_out",
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
|
| 104 |
+
"accum_video_sample_counter": remove_keys_,
|
| 105 |
+
"accum_image_sample_counter": remove_keys_,
|
| 106 |
+
"accum_iteration": remove_keys_,
|
| 107 |
+
"accum_train_in_hours": remove_keys_,
|
| 108 |
+
"pos_embedder.seq": remove_keys_,
|
| 109 |
+
"pos_embedder.dim_spatial_range": remove_keys_,
|
| 110 |
+
"pos_embedder.dim_temporal_range": remove_keys_,
|
| 111 |
+
"_extra_state": remove_keys_,
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
TRANSFORMER_CONFIGS = {
|
| 116 |
+
"Cosmos-1.0-Diffusion-7B-Text2World": {
|
| 117 |
+
"in_channels": 16,
|
| 118 |
+
"out_channels": 16,
|
| 119 |
+
"num_attention_heads": 32,
|
| 120 |
+
"attention_head_dim": 128,
|
| 121 |
+
"num_layers": 28,
|
| 122 |
+
"mlp_ratio": 4.0,
|
| 123 |
+
"text_embed_dim": 1024,
|
| 124 |
+
"adaln_lora_dim": 256,
|
| 125 |
+
"max_size": (128, 240, 240),
|
| 126 |
+
"patch_size": (1, 2, 2),
|
| 127 |
+
"rope_scale": (2.0, 1.0, 1.0),
|
| 128 |
+
"concat_padding_mask": True,
|
| 129 |
+
"extra_pos_embed_type": "learnable",
|
| 130 |
+
},
|
| 131 |
+
"Cosmos-1.0-Diffusion-7B-Video2World": {
|
| 132 |
+
"in_channels": 16 + 1,
|
| 133 |
+
"out_channels": 16,
|
| 134 |
+
"num_attention_heads": 32,
|
| 135 |
+
"attention_head_dim": 128,
|
| 136 |
+
"num_layers": 28,
|
| 137 |
+
"mlp_ratio": 4.0,
|
| 138 |
+
"text_embed_dim": 1024,
|
| 139 |
+
"adaln_lora_dim": 256,
|
| 140 |
+
"max_size": (128, 240, 240),
|
| 141 |
+
"patch_size": (1, 2, 2),
|
| 142 |
+
"rope_scale": (2.0, 1.0, 1.0),
|
| 143 |
+
"concat_padding_mask": True,
|
| 144 |
+
"extra_pos_embed_type": "learnable",
|
| 145 |
+
},
|
| 146 |
+
"Cosmos-1.0-Diffusion-14B-Text2World": {
|
| 147 |
+
"in_channels": 16,
|
| 148 |
+
"out_channels": 16,
|
| 149 |
+
"num_attention_heads": 40,
|
| 150 |
+
"attention_head_dim": 128,
|
| 151 |
+
"num_layers": 36,
|
| 152 |
+
"mlp_ratio": 4.0,
|
| 153 |
+
"text_embed_dim": 1024,
|
| 154 |
+
"adaln_lora_dim": 256,
|
| 155 |
+
"max_size": (128, 240, 240),
|
| 156 |
+
"patch_size": (1, 2, 2),
|
| 157 |
+
"rope_scale": (2.0, 2.0, 2.0),
|
| 158 |
+
"concat_padding_mask": True,
|
| 159 |
+
"extra_pos_embed_type": "learnable",
|
| 160 |
+
},
|
| 161 |
+
"Cosmos-1.0-Diffusion-14B-Video2World": {
|
| 162 |
+
"in_channels": 16 + 1,
|
| 163 |
+
"out_channels": 16,
|
| 164 |
+
"num_attention_heads": 40,
|
| 165 |
+
"attention_head_dim": 128,
|
| 166 |
+
"num_layers": 36,
|
| 167 |
+
"mlp_ratio": 4.0,
|
| 168 |
+
"text_embed_dim": 1024,
|
| 169 |
+
"adaln_lora_dim": 256,
|
| 170 |
+
"max_size": (128, 240, 240),
|
| 171 |
+
"patch_size": (1, 2, 2),
|
| 172 |
+
"rope_scale": (2.0, 2.0, 2.0),
|
| 173 |
+
"concat_padding_mask": True,
|
| 174 |
+
"extra_pos_embed_type": "learnable",
|
| 175 |
+
},
|
| 176 |
+
"Cosmos-2.0-Diffusion-2B-Text2Image": {
|
| 177 |
+
"in_channels": 16,
|
| 178 |
+
"out_channels": 16,
|
| 179 |
+
"num_attention_heads": 16,
|
| 180 |
+
"attention_head_dim": 128,
|
| 181 |
+
"num_layers": 28,
|
| 182 |
+
"mlp_ratio": 4.0,
|
| 183 |
+
"text_embed_dim": 1024,
|
| 184 |
+
"adaln_lora_dim": 256,
|
| 185 |
+
"max_size": (128, 240, 240),
|
| 186 |
+
"patch_size": (1, 2, 2),
|
| 187 |
+
"rope_scale": (1.0, 4.0, 4.0),
|
| 188 |
+
"concat_padding_mask": True,
|
| 189 |
+
"extra_pos_embed_type": None,
|
| 190 |
+
},
|
| 191 |
+
"Cosmos-2.0-Diffusion-14B-Text2Image": {
|
| 192 |
+
"in_channels": 16,
|
| 193 |
+
"out_channels": 16,
|
| 194 |
+
"num_attention_heads": 40,
|
| 195 |
+
"attention_head_dim": 128,
|
| 196 |
+
"num_layers": 36,
|
| 197 |
+
"mlp_ratio": 4.0,
|
| 198 |
+
"text_embed_dim": 1024,
|
| 199 |
+
"adaln_lora_dim": 256,
|
| 200 |
+
"max_size": (128, 240, 240),
|
| 201 |
+
"patch_size": (1, 2, 2),
|
| 202 |
+
"rope_scale": (1.0, 4.0, 4.0),
|
| 203 |
+
"concat_padding_mask": True,
|
| 204 |
+
"extra_pos_embed_type": None,
|
| 205 |
+
},
|
| 206 |
+
"Cosmos-2.0-Diffusion-2B-Video2World": {
|
| 207 |
+
"in_channels": 16 + 1,
|
| 208 |
+
"out_channels": 16,
|
| 209 |
+
"num_attention_heads": 16,
|
| 210 |
+
"attention_head_dim": 128,
|
| 211 |
+
"num_layers": 28,
|
| 212 |
+
"mlp_ratio": 4.0,
|
| 213 |
+
"text_embed_dim": 1024,
|
| 214 |
+
"adaln_lora_dim": 256,
|
| 215 |
+
"max_size": (128, 240, 240),
|
| 216 |
+
"patch_size": (1, 2, 2),
|
| 217 |
+
"rope_scale": (1.0, 3.0, 3.0),
|
| 218 |
+
"concat_padding_mask": True,
|
| 219 |
+
"extra_pos_embed_type": None,
|
| 220 |
+
},
|
| 221 |
+
"Cosmos-2.0-Diffusion-14B-Video2World": {
|
| 222 |
+
"in_channels": 16 + 1,
|
| 223 |
+
"out_channels": 16,
|
| 224 |
+
"num_attention_heads": 40,
|
| 225 |
+
"attention_head_dim": 128,
|
| 226 |
+
"num_layers": 36,
|
| 227 |
+
"mlp_ratio": 4.0,
|
| 228 |
+
"text_embed_dim": 1024,
|
| 229 |
+
"adaln_lora_dim": 256,
|
| 230 |
+
"max_size": (128, 240, 240),
|
| 231 |
+
"patch_size": (1, 2, 2),
|
| 232 |
+
"rope_scale": (20 / 24, 2.0, 2.0),
|
| 233 |
+
"concat_padding_mask": True,
|
| 234 |
+
"extra_pos_embed_type": None,
|
| 235 |
+
},
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
VAE_KEYS_RENAME_DICT = {
|
| 239 |
+
"down.0": "down_blocks.0",
|
| 240 |
+
"down.1": "down_blocks.1",
|
| 241 |
+
"down.2": "down_blocks.2",
|
| 242 |
+
"up.0": "up_blocks.2",
|
| 243 |
+
"up.1": "up_blocks.1",
|
| 244 |
+
"up.2": "up_blocks.0",
|
| 245 |
+
".block.": ".resnets.",
|
| 246 |
+
"downsample": "downsamplers.0",
|
| 247 |
+
"upsample": "upsamplers.0",
|
| 248 |
+
"mid.block_1": "mid_block.resnets.0",
|
| 249 |
+
"mid.attn_1.0": "mid_block.attentions.0",
|
| 250 |
+
"mid.attn_1.1": "mid_block.temp_attentions.0",
|
| 251 |
+
"mid.block_2": "mid_block.resnets.1",
|
| 252 |
+
".q.conv3d": ".to_q",
|
| 253 |
+
".k.conv3d": ".to_k",
|
| 254 |
+
".v.conv3d": ".to_v",
|
| 255 |
+
".proj_out.conv3d": ".to_out.0",
|
| 256 |
+
".0.conv3d": ".conv_s",
|
| 257 |
+
".1.conv3d": ".conv_t",
|
| 258 |
+
"conv1.conv3d": "conv1",
|
| 259 |
+
"conv2.conv3d": "conv2",
|
| 260 |
+
"conv3.conv3d": "conv3",
|
| 261 |
+
"nin_shortcut.conv3d": "conv_shortcut",
|
| 262 |
+
"quant_conv.conv3d": "quant_conv",
|
| 263 |
+
"post_quant_conv.conv3d": "post_quant_conv",
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
VAE_SPECIAL_KEYS_REMAP = {
|
| 267 |
+
"wavelets": remove_keys_,
|
| 268 |
+
"_arange": remove_keys_,
|
| 269 |
+
"patch_size_buffer": remove_keys_,
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
VAE_CONFIGS = {
|
| 273 |
+
"CV8x8x8-0.1": {
|
| 274 |
+
"name": "nvidia/Cosmos-0.1-Tokenizer-CV8x8x8",
|
| 275 |
+
"diffusers_config": {
|
| 276 |
+
"in_channels": 3,
|
| 277 |
+
"out_channels": 3,
|
| 278 |
+
"latent_channels": 16,
|
| 279 |
+
"encoder_block_out_channels": (128, 256, 512, 512),
|
| 280 |
+
"decode_block_out_channels": (256, 512, 512, 512),
|
| 281 |
+
"attention_resolutions": (32,),
|
| 282 |
+
"resolution": 1024,
|
| 283 |
+
"num_layers": 2,
|
| 284 |
+
"patch_size": 4,
|
| 285 |
+
"patch_type": "haar",
|
| 286 |
+
"scaling_factor": 1.0,
|
| 287 |
+
"spatial_compression_ratio": 8,
|
| 288 |
+
"temporal_compression_ratio": 8,
|
| 289 |
+
"latents_mean": None,
|
| 290 |
+
"latents_std": None,
|
| 291 |
+
},
|
| 292 |
+
},
|
| 293 |
+
"CV8x8x8-1.0": {
|
| 294 |
+
"name": "nvidia/Cosmos-1.0-Tokenizer-CV8x8x8",
|
| 295 |
+
"diffusers_config": {
|
| 296 |
+
"in_channels": 3,
|
| 297 |
+
"out_channels": 3,
|
| 298 |
+
"latent_channels": 16,
|
| 299 |
+
"encoder_block_out_channels": (128, 256, 512, 512),
|
| 300 |
+
"decode_block_out_channels": (256, 512, 512, 512),
|
| 301 |
+
"attention_resolutions": (32,),
|
| 302 |
+
"resolution": 1024,
|
| 303 |
+
"num_layers": 2,
|
| 304 |
+
"patch_size": 4,
|
| 305 |
+
"patch_type": "haar",
|
| 306 |
+
"scaling_factor": 1.0,
|
| 307 |
+
"spatial_compression_ratio": 8,
|
| 308 |
+
"temporal_compression_ratio": 8,
|
| 309 |
+
"latents_mean": None,
|
| 310 |
+
"latents_std": None,
|
| 311 |
+
},
|
| 312 |
+
},
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
| 317 |
+
state_dict = saved_dict
|
| 318 |
+
if "model" in saved_dict.keys():
|
| 319 |
+
state_dict = state_dict["model"]
|
| 320 |
+
if "module" in saved_dict.keys():
|
| 321 |
+
state_dict = state_dict["module"]
|
| 322 |
+
if "state_dict" in saved_dict.keys():
|
| 323 |
+
state_dict = state_dict["state_dict"]
|
| 324 |
+
return state_dict
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
|
| 328 |
+
PREFIX_KEY = "net."
|
| 329 |
+
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
|
| 330 |
+
|
| 331 |
+
if "Cosmos-1.0" in transformer_type:
|
| 332 |
+
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
|
| 333 |
+
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
|
| 334 |
+
elif "Cosmos-2.0" in transformer_type:
|
| 335 |
+
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
| 336 |
+
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
| 337 |
+
else:
|
| 338 |
+
assert False
|
| 339 |
+
|
| 340 |
+
with init_empty_weights():
|
| 341 |
+
config = TRANSFORMER_CONFIGS[transformer_type]
|
| 342 |
+
transformer = CosmosTransformer3DModel(**config)
|
| 343 |
+
|
| 344 |
+
for key in list(original_state_dict.keys()):
|
| 345 |
+
new_key = key[:]
|
| 346 |
+
if new_key.startswith(PREFIX_KEY):
|
| 347 |
+
new_key = new_key.removeprefix(PREFIX_KEY)
|
| 348 |
+
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
| 349 |
+
new_key = new_key.replace(replace_key, rename_key)
|
| 350 |
+
update_state_dict_(original_state_dict, key, new_key)
|
| 351 |
+
|
| 352 |
+
for key in list(original_state_dict.keys()):
|
| 353 |
+
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
| 354 |
+
if special_key not in key:
|
| 355 |
+
continue
|
| 356 |
+
handler_fn_inplace(key, original_state_dict)
|
| 357 |
+
|
| 358 |
+
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
| 359 |
+
return transformer
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def convert_vae(vae_type: str):
|
| 363 |
+
model_name = VAE_CONFIGS[vae_type]["name"]
|
| 364 |
+
snapshot_directory = snapshot_download(model_name, repo_type="model")
|
| 365 |
+
directory = pathlib.Path(snapshot_directory)
|
| 366 |
+
|
| 367 |
+
autoencoder_file = directory / "autoencoder.jit"
|
| 368 |
+
mean_std_file = directory / "mean_std.pt"
|
| 369 |
+
|
| 370 |
+
original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()
|
| 371 |
+
if mean_std_file.exists():
|
| 372 |
+
mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
|
| 373 |
+
else:
|
| 374 |
+
mean_std = (None, None)
|
| 375 |
+
|
| 376 |
+
config = VAE_CONFIGS[vae_type]["diffusers_config"]
|
| 377 |
+
config.update(
|
| 378 |
+
{
|
| 379 |
+
"latents_mean": mean_std[0].detach().cpu().numpy().tolist(),
|
| 380 |
+
"latents_std": mean_std[1].detach().cpu().numpy().tolist(),
|
| 381 |
+
}
|
| 382 |
+
)
|
| 383 |
+
vae = AutoencoderKLCosmos(**config)
|
| 384 |
+
|
| 385 |
+
for key in list(original_state_dict.keys()):
|
| 386 |
+
new_key = key[:]
|
| 387 |
+
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
| 388 |
+
new_key = new_key.replace(replace_key, rename_key)
|
| 389 |
+
update_state_dict_(original_state_dict, key, new_key)
|
| 390 |
+
|
| 391 |
+
for key in list(original_state_dict.keys()):
|
| 392 |
+
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
|
| 393 |
+
if special_key not in key:
|
| 394 |
+
continue
|
| 395 |
+
handler_fn_inplace(key, original_state_dict)
|
| 396 |
+
|
| 397 |
+
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
| 398 |
+
return vae
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def save_pipeline_cosmos_1_0(args, transformer, vae):
|
| 402 |
+
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
|
| 403 |
+
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
|
| 404 |
+
# The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
|
| 405 |
+
# So, the sigma_min values that is used is the default value of 0.002.
|
| 406 |
+
scheduler = EDMEulerScheduler(
|
| 407 |
+
sigma_min=0.002,
|
| 408 |
+
sigma_max=80,
|
| 409 |
+
sigma_data=0.5,
|
| 410 |
+
sigma_schedule="karras",
|
| 411 |
+
num_train_timesteps=1000,
|
| 412 |
+
prediction_type="epsilon",
|
| 413 |
+
rho=7.0,
|
| 414 |
+
final_sigmas_type="sigma_min",
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
pipe_cls = CosmosTextToWorldPipeline if "Text2World" in args.transformer_type else CosmosVideoToWorldPipeline
|
| 418 |
+
pipe = pipe_cls(
|
| 419 |
+
text_encoder=text_encoder,
|
| 420 |
+
tokenizer=tokenizer,
|
| 421 |
+
transformer=transformer,
|
| 422 |
+
vae=vae,
|
| 423 |
+
scheduler=scheduler,
|
| 424 |
+
safety_checker=lambda *args, **kwargs: None,
|
| 425 |
+
)
|
| 426 |
+
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def save_pipeline_cosmos_2_0(args, transformer, vae):
|
| 430 |
+
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
|
| 431 |
+
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
|
| 432 |
+
|
| 433 |
+
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
|
| 434 |
+
|
| 435 |
+
pipe_cls = Cosmos2TextToImagePipeline if "Text2Image" in args.transformer_type else Cosmos2VideoToWorldPipeline
|
| 436 |
+
pipe = pipe_cls(
|
| 437 |
+
text_encoder=text_encoder,
|
| 438 |
+
tokenizer=tokenizer,
|
| 439 |
+
transformer=transformer,
|
| 440 |
+
vae=vae,
|
| 441 |
+
scheduler=scheduler,
|
| 442 |
+
safety_checker=lambda *args, **kwargs: None,
|
| 443 |
+
)
|
| 444 |
+
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def get_args():
|
| 448 |
+
parser = argparse.ArgumentParser()
|
| 449 |
+
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
|
| 450 |
+
parser.add_argument(
|
| 451 |
+
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
| 452 |
+
)
|
| 453 |
+
parser.add_argument(
|
| 454 |
+
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
|
| 455 |
+
)
|
| 456 |
+
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
|
| 457 |
+
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
|
| 458 |
+
parser.add_argument("--save_pipeline", action="store_true")
|
| 459 |
+
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
| 460 |
+
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
| 461 |
+
return parser.parse_args()
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
DTYPE_MAPPING = {
|
| 465 |
+
"fp32": torch.float32,
|
| 466 |
+
"fp16": torch.float16,
|
| 467 |
+
"bf16": torch.bfloat16,
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
if __name__ == "__main__":
|
| 472 |
+
args = get_args()
|
| 473 |
+
|
| 474 |
+
transformer = None
|
| 475 |
+
dtype = DTYPE_MAPPING[args.dtype]
|
| 476 |
+
|
| 477 |
+
if args.save_pipeline:
|
| 478 |
+
assert args.transformer_ckpt_path is not None
|
| 479 |
+
assert args.vae_type is not None
|
| 480 |
+
assert args.text_encoder_path is not None
|
| 481 |
+
assert args.tokenizer_path is not None
|
| 482 |
+
|
| 483 |
+
if args.transformer_ckpt_path is not None:
|
| 484 |
+
weights_only = "Cosmos-1.0" in args.transformer_type
|
| 485 |
+
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
|
| 486 |
+
transformer = transformer.to(dtype=dtype)
|
| 487 |
+
if not args.save_pipeline:
|
| 488 |
+
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
| 489 |
+
|
| 490 |
+
if args.vae_type is not None:
|
| 491 |
+
if "Cosmos-1.0" in args.transformer_type:
|
| 492 |
+
vae = convert_vae(args.vae_type)
|
| 493 |
+
else:
|
| 494 |
+
vae = AutoencoderKLWan.from_pretrained(
|
| 495 |
+
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
| 496 |
+
)
|
| 497 |
+
if not args.save_pipeline:
|
| 498 |
+
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
| 499 |
+
|
| 500 |
+
if args.save_pipeline:
|
| 501 |
+
if "Cosmos-1.0" in args.transformer_type:
|
| 502 |
+
save_pipeline_cosmos_1_0(args, transformer, vae)
|
| 503 |
+
elif "Cosmos-2.0" in args.transformer_type:
|
| 504 |
+
save_pipeline_cosmos_2_0(args, transformer, vae)
|
| 505 |
+
else:
|
| 506 |
+
assert False
|
diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
| 10 |
+
"""
|
| 11 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
| 12 |
+
"""
|
| 13 |
+
if n_shave_prefix_segments >= 0:
|
| 14 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
| 15 |
+
else:
|
| 16 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 20 |
+
mapping = []
|
| 21 |
+
for old_item in old_list:
|
| 22 |
+
new_item = old_item
|
| 23 |
+
new_item = new_item.replace("block.", "resnets.")
|
| 24 |
+
new_item = new_item.replace("conv_shorcut", "conv1")
|
| 25 |
+
new_item = new_item.replace("in_shortcut", "conv_shortcut")
|
| 26 |
+
new_item = new_item.replace("temb_proj", "time_emb_proj")
|
| 27 |
+
|
| 28 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 29 |
+
|
| 30 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 31 |
+
|
| 32 |
+
return mapping
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
|
| 36 |
+
mapping = []
|
| 37 |
+
for old_item in old_list:
|
| 38 |
+
new_item = old_item
|
| 39 |
+
|
| 40 |
+
# In `model.mid`, the layer is called `attn`.
|
| 41 |
+
if not in_mid:
|
| 42 |
+
new_item = new_item.replace("attn", "attentions")
|
| 43 |
+
new_item = new_item.replace(".k.", ".key.")
|
| 44 |
+
new_item = new_item.replace(".v.", ".value.")
|
| 45 |
+
new_item = new_item.replace(".q.", ".query.")
|
| 46 |
+
|
| 47 |
+
new_item = new_item.replace("proj_out", "proj_attn")
|
| 48 |
+
new_item = new_item.replace("norm", "group_norm")
|
| 49 |
+
|
| 50 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 51 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 52 |
+
|
| 53 |
+
return mapping
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def assign_to_checkpoint(
|
| 57 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
| 58 |
+
):
|
| 59 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
| 60 |
+
|
| 61 |
+
if attention_paths_to_split is not None:
|
| 62 |
+
if config is None:
|
| 63 |
+
raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")
|
| 64 |
+
|
| 65 |
+
for path, path_map in attention_paths_to_split.items():
|
| 66 |
+
old_tensor = old_checkpoint[path]
|
| 67 |
+
channels = old_tensor.shape[0] // 3
|
| 68 |
+
|
| 69 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
| 70 |
+
|
| 71 |
+
num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3
|
| 72 |
+
|
| 73 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
| 74 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
| 75 |
+
|
| 76 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
|
| 77 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
|
| 78 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()
|
| 79 |
+
|
| 80 |
+
for path in paths:
|
| 81 |
+
new_path = path["new"]
|
| 82 |
+
|
| 83 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
| 84 |
+
continue
|
| 85 |
+
|
| 86 |
+
new_path = new_path.replace("down.", "down_blocks.")
|
| 87 |
+
new_path = new_path.replace("up.", "up_blocks.")
|
| 88 |
+
|
| 89 |
+
if additional_replacements is not None:
|
| 90 |
+
for replacement in additional_replacements:
|
| 91 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
| 92 |
+
|
| 93 |
+
if "attentions" in new_path:
|
| 94 |
+
checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
|
| 95 |
+
else:
|
| 96 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def convert_ddpm_checkpoint(checkpoint, config):
|
| 100 |
+
"""
|
| 101 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
| 102 |
+
"""
|
| 103 |
+
new_checkpoint = {}
|
| 104 |
+
|
| 105 |
+
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
|
| 106 |
+
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
|
| 107 |
+
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
|
| 108 |
+
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]
|
| 109 |
+
|
| 110 |
+
new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
|
| 111 |
+
new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]
|
| 112 |
+
|
| 113 |
+
new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
|
| 114 |
+
new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
|
| 115 |
+
new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
|
| 116 |
+
new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]
|
| 117 |
+
|
| 118 |
+
num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
|
| 119 |
+
down_blocks = {
|
| 120 |
+
layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
|
| 124 |
+
up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
| 125 |
+
|
| 126 |
+
for i in range(num_down_blocks):
|
| 127 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
| 128 |
+
|
| 129 |
+
if any("downsample" in layer for layer in down_blocks[i]):
|
| 130 |
+
new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
|
| 131 |
+
f"down.{i}.downsample.op.weight"
|
| 132 |
+
]
|
| 133 |
+
new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
|
| 134 |
+
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
|
| 135 |
+
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
|
| 136 |
+
|
| 137 |
+
if any("block" in layer for layer in down_blocks[i]):
|
| 138 |
+
num_blocks = len(
|
| 139 |
+
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
|
| 140 |
+
)
|
| 141 |
+
blocks = {
|
| 142 |
+
layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
|
| 143 |
+
for layer_id in range(num_blocks)
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
if num_blocks > 0:
|
| 147 |
+
for j in range(config["layers_per_block"]):
|
| 148 |
+
paths = renew_resnet_paths(blocks[j])
|
| 149 |
+
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
|
| 150 |
+
|
| 151 |
+
if any("attn" in layer for layer in down_blocks[i]):
|
| 152 |
+
num_attn = len(
|
| 153 |
+
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
|
| 154 |
+
)
|
| 155 |
+
attns = {
|
| 156 |
+
layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
|
| 157 |
+
for layer_id in range(num_blocks)
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
if num_attn > 0:
|
| 161 |
+
for j in range(config["layers_per_block"]):
|
| 162 |
+
paths = renew_attention_paths(attns[j])
|
| 163 |
+
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
|
| 164 |
+
|
| 165 |
+
mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
|
| 166 |
+
mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
|
| 167 |
+
mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
|
| 168 |
+
|
| 169 |
+
# Mid new 2
|
| 170 |
+
paths = renew_resnet_paths(mid_block_1_layers)
|
| 171 |
+
assign_to_checkpoint(
|
| 172 |
+
paths,
|
| 173 |
+
new_checkpoint,
|
| 174 |
+
checkpoint,
|
| 175 |
+
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
paths = renew_resnet_paths(mid_block_2_layers)
|
| 179 |
+
assign_to_checkpoint(
|
| 180 |
+
paths,
|
| 181 |
+
new_checkpoint,
|
| 182 |
+
checkpoint,
|
| 183 |
+
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
|
| 187 |
+
assign_to_checkpoint(
|
| 188 |
+
paths,
|
| 189 |
+
new_checkpoint,
|
| 190 |
+
checkpoint,
|
| 191 |
+
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
for i in range(num_up_blocks):
|
| 195 |
+
block_id = num_up_blocks - 1 - i
|
| 196 |
+
|
| 197 |
+
if any("upsample" in layer for layer in up_blocks[i]):
|
| 198 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
|
| 199 |
+
f"up.{i}.upsample.conv.weight"
|
| 200 |
+
]
|
| 201 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]
|
| 202 |
+
|
| 203 |
+
if any("block" in layer for layer in up_blocks[i]):
|
| 204 |
+
num_blocks = len(
|
| 205 |
+
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
|
| 206 |
+
)
|
| 207 |
+
blocks = {
|
| 208 |
+
layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
if num_blocks > 0:
|
| 212 |
+
for j in range(config["layers_per_block"] + 1):
|
| 213 |
+
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
| 214 |
+
paths = renew_resnet_paths(blocks[j])
|
| 215 |
+
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
| 216 |
+
|
| 217 |
+
if any("attn" in layer for layer in up_blocks[i]):
|
| 218 |
+
num_attn = len(
|
| 219 |
+
{".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
|
| 220 |
+
)
|
| 221 |
+
attns = {
|
| 222 |
+
layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
if num_attn > 0:
|
| 226 |
+
for j in range(config["layers_per_block"] + 1):
|
| 227 |
+
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
| 228 |
+
paths = renew_attention_paths(attns[j])
|
| 229 |
+
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
| 230 |
+
|
| 231 |
+
new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
|
| 232 |
+
return new_checkpoint
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def convert_vq_autoenc_checkpoint(checkpoint, config):
|
| 236 |
+
"""
|
| 237 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
| 238 |
+
"""
|
| 239 |
+
new_checkpoint = {}
|
| 240 |
+
|
| 241 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
|
| 242 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]
|
| 243 |
+
|
| 244 |
+
new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
|
| 245 |
+
new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
|
| 246 |
+
new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
|
| 247 |
+
new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]
|
| 248 |
+
|
| 249 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
|
| 250 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]
|
| 251 |
+
|
| 252 |
+
new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
|
| 253 |
+
new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
|
| 254 |
+
new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
|
| 255 |
+
new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]
|
| 256 |
+
|
| 257 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
|
| 258 |
+
down_blocks = {
|
| 259 |
+
layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
|
| 263 |
+
up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
| 264 |
+
|
| 265 |
+
for i in range(num_down_blocks):
|
| 266 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
| 267 |
+
|
| 268 |
+
if any("downsample" in layer for layer in down_blocks[i]):
|
| 269 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
|
| 270 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
| 271 |
+
]
|
| 272 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
|
| 273 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
| 274 |
+
]
|
| 275 |
+
|
| 276 |
+
if any("block" in layer for layer in down_blocks[i]):
|
| 277 |
+
num_blocks = len(
|
| 278 |
+
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
|
| 279 |
+
)
|
| 280 |
+
blocks = {
|
| 281 |
+
layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
|
| 282 |
+
for layer_id in range(num_blocks)
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
if num_blocks > 0:
|
| 286 |
+
for j in range(config["layers_per_block"]):
|
| 287 |
+
paths = renew_resnet_paths(blocks[j])
|
| 288 |
+
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
|
| 289 |
+
|
| 290 |
+
if any("attn" in layer for layer in down_blocks[i]):
|
| 291 |
+
num_attn = len(
|
| 292 |
+
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
|
| 293 |
+
)
|
| 294 |
+
attns = {
|
| 295 |
+
layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
|
| 296 |
+
for layer_id in range(num_blocks)
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
if num_attn > 0:
|
| 300 |
+
for j in range(config["layers_per_block"]):
|
| 301 |
+
paths = renew_attention_paths(attns[j])
|
| 302 |
+
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
|
| 303 |
+
|
| 304 |
+
mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
|
| 305 |
+
mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
|
| 306 |
+
mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
|
| 307 |
+
|
| 308 |
+
# Mid new 2
|
| 309 |
+
paths = renew_resnet_paths(mid_block_1_layers)
|
| 310 |
+
assign_to_checkpoint(
|
| 311 |
+
paths,
|
| 312 |
+
new_checkpoint,
|
| 313 |
+
checkpoint,
|
| 314 |
+
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
paths = renew_resnet_paths(mid_block_2_layers)
|
| 318 |
+
assign_to_checkpoint(
|
| 319 |
+
paths,
|
| 320 |
+
new_checkpoint,
|
| 321 |
+
checkpoint,
|
| 322 |
+
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
|
| 326 |
+
assign_to_checkpoint(
|
| 327 |
+
paths,
|
| 328 |
+
new_checkpoint,
|
| 329 |
+
checkpoint,
|
| 330 |
+
additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
for i in range(num_up_blocks):
|
| 334 |
+
block_id = num_up_blocks - 1 - i
|
| 335 |
+
|
| 336 |
+
if any("upsample" in layer for layer in up_blocks[i]):
|
| 337 |
+
new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
|
| 338 |
+
f"decoder.up.{i}.upsample.conv.weight"
|
| 339 |
+
]
|
| 340 |
+
new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
|
| 341 |
+
f"decoder.up.{i}.upsample.conv.bias"
|
| 342 |
+
]
|
| 343 |
+
|
| 344 |
+
if any("block" in layer for layer in up_blocks[i]):
|
| 345 |
+
num_blocks = len(
|
| 346 |
+
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
|
| 347 |
+
)
|
| 348 |
+
blocks = {
|
| 349 |
+
layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
if num_blocks > 0:
|
| 353 |
+
for j in range(config["layers_per_block"] + 1):
|
| 354 |
+
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
| 355 |
+
paths = renew_resnet_paths(blocks[j])
|
| 356 |
+
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
| 357 |
+
|
| 358 |
+
if any("attn" in layer for layer in up_blocks[i]):
|
| 359 |
+
num_attn = len(
|
| 360 |
+
{".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
|
| 361 |
+
)
|
| 362 |
+
attns = {
|
| 363 |
+
layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
if num_attn > 0:
|
| 367 |
+
for j in range(config["layers_per_block"] + 1):
|
| 368 |
+
replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
|
| 369 |
+
paths = renew_attention_paths(attns[j])
|
| 370 |
+
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
|
| 371 |
+
|
| 372 |
+
new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
|
| 373 |
+
new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
|
| 374 |
+
new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
|
| 375 |
+
if "quantize.embedding.weight" in checkpoint:
|
| 376 |
+
new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
|
| 377 |
+
new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
|
| 378 |
+
new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]
|
| 379 |
+
|
| 380 |
+
return new_checkpoint
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
if __name__ == "__main__":
|
| 384 |
+
parser = argparse.ArgumentParser()
|
| 385 |
+
|
| 386 |
+
parser.add_argument(
|
| 387 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
parser.add_argument(
|
| 391 |
+
"--config_file",
|
| 392 |
+
default=None,
|
| 393 |
+
type=str,
|
| 394 |
+
required=True,
|
| 395 |
+
help="The config json file corresponding to the architecture.",
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 399 |
+
|
| 400 |
+
args = parser.parse_args()
|
| 401 |
+
checkpoint = torch.load(args.checkpoint_path)
|
| 402 |
+
|
| 403 |
+
with open(args.config_file) as f:
|
| 404 |
+
config = json.loads(f.read())
|
| 405 |
+
|
| 406 |
+
# unet case
|
| 407 |
+
key_prefix_set = {key.split(".")[0] for key in checkpoint.keys()}
|
| 408 |
+
if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
|
| 409 |
+
converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
|
| 410 |
+
else:
|
| 411 |
+
converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
|
| 412 |
+
|
| 413 |
+
if "ddpm" in config:
|
| 414 |
+
del config["ddpm"]
|
| 415 |
+
|
| 416 |
+
if config["_class_name"] == "VQModel":
|
| 417 |
+
model = VQModel(**config)
|
| 418 |
+
model.load_state_dict(converted_checkpoint)
|
| 419 |
+
model.save_pretrained(args.dump_path)
|
| 420 |
+
elif config["_class_name"] == "AutoencoderKL":
|
| 421 |
+
model = AutoencoderKL(**config)
|
| 422 |
+
model.load_state_dict(converted_checkpoint)
|
| 423 |
+
model.save_pretrained(args.dump_path)
|
| 424 |
+
else:
|
| 425 |
+
model = UNet2DModel(**config)
|
| 426 |
+
model.load_state_dict(converted_checkpoint)
|
| 427 |
+
|
| 428 |
+
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
|
| 429 |
+
|
| 430 |
+
pipe = DDPMPipeline(unet=model, scheduler=scheduler)
|
| 431 |
+
pipe.save_pretrained(args.dump_path)
|
diffusers/scripts/convert_diffusers_to_original_sdxl.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
|
| 2 |
+
# *Only* converts the UNet, VAE, and Text Encoder.
|
| 3 |
+
# Does not convert optimizer state or any other thing.
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import os.path as osp
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from safetensors.torch import load_file, save_file
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# =================#
|
| 14 |
+
# UNet Conversion #
|
| 15 |
+
# =================#
|
| 16 |
+
|
| 17 |
+
unet_conversion_map = [
|
| 18 |
+
# (stable-diffusion, HF Diffusers)
|
| 19 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
| 20 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
| 21 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
| 22 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
| 23 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
| 24 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
| 25 |
+
("out.0.weight", "conv_norm_out.weight"),
|
| 26 |
+
("out.0.bias", "conv_norm_out.bias"),
|
| 27 |
+
("out.2.weight", "conv_out.weight"),
|
| 28 |
+
("out.2.bias", "conv_out.bias"),
|
| 29 |
+
# the following are for sdxl
|
| 30 |
+
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
|
| 31 |
+
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
|
| 32 |
+
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
|
| 33 |
+
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
unet_conversion_map_resnet = [
|
| 37 |
+
# (stable-diffusion, HF Diffusers)
|
| 38 |
+
("in_layers.0", "norm1"),
|
| 39 |
+
("in_layers.2", "conv1"),
|
| 40 |
+
("out_layers.0", "norm2"),
|
| 41 |
+
("out_layers.3", "conv2"),
|
| 42 |
+
("emb_layers.1", "time_emb_proj"),
|
| 43 |
+
("skip_connection", "conv_shortcut"),
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
unet_conversion_map_layer = []
|
| 47 |
+
# hardcoded number of downblocks and resnets/attentions...
|
| 48 |
+
# would need smarter logic for other networks.
|
| 49 |
+
for i in range(3):
|
| 50 |
+
# loop over downblocks/upblocks
|
| 51 |
+
|
| 52 |
+
for j in range(2):
|
| 53 |
+
# loop over resnets/attentions for downblocks
|
| 54 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
| 55 |
+
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
| 56 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
| 57 |
+
|
| 58 |
+
if i > 0:
|
| 59 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
| 60 |
+
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
| 61 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
| 62 |
+
|
| 63 |
+
for j in range(4):
|
| 64 |
+
# loop over resnets/attentions for upblocks
|
| 65 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
| 66 |
+
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
| 67 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
| 68 |
+
|
| 69 |
+
if i < 2:
|
| 70 |
+
# no attention layers in up_blocks.0
|
| 71 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
| 72 |
+
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
| 73 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
| 74 |
+
|
| 75 |
+
if i < 3:
|
| 76 |
+
# no downsample in down_blocks.3
|
| 77 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
| 78 |
+
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
| 79 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 80 |
+
|
| 81 |
+
# no upsample in up_blocks.3
|
| 82 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
| 83 |
+
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
| 84 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
| 85 |
+
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
|
| 86 |
+
|
| 87 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
| 88 |
+
sd_mid_atn_prefix = "middle_block.1."
|
| 89 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
| 90 |
+
for j in range(2):
|
| 91 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
| 92 |
+
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
| 93 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def convert_unet_state_dict(unet_state_dict):
|
| 97 |
+
# buyer beware: this is a *brittle* function,
|
| 98 |
+
# and correct output requires that all of these pieces interact in
|
| 99 |
+
# the exact order in which I have arranged them.
|
| 100 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
| 101 |
+
for sd_name, hf_name in unet_conversion_map:
|
| 102 |
+
mapping[hf_name] = sd_name
|
| 103 |
+
for k, v in mapping.items():
|
| 104 |
+
if "resnets" in k:
|
| 105 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
| 106 |
+
v = v.replace(hf_part, sd_part)
|
| 107 |
+
mapping[k] = v
|
| 108 |
+
for k, v in mapping.items():
|
| 109 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
| 110 |
+
v = v.replace(hf_part, sd_part)
|
| 111 |
+
mapping[k] = v
|
| 112 |
+
new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()}
|
| 113 |
+
return new_state_dict
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ================#
|
| 117 |
+
# VAE Conversion #
|
| 118 |
+
# ================#
|
| 119 |
+
|
| 120 |
+
vae_conversion_map = [
|
| 121 |
+
# (stable-diffusion, HF Diffusers)
|
| 122 |
+
("nin_shortcut", "conv_shortcut"),
|
| 123 |
+
("norm_out", "conv_norm_out"),
|
| 124 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
| 125 |
+
]
|
| 126 |
+
|
| 127 |
+
for i in range(4):
|
| 128 |
+
# down_blocks have two resnets
|
| 129 |
+
for j in range(2):
|
| 130 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
| 131 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
| 132 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
| 133 |
+
|
| 134 |
+
if i < 3:
|
| 135 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
| 136 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
| 137 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 138 |
+
|
| 139 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
| 140 |
+
sd_upsample_prefix = f"up.{3 - i}.upsample."
|
| 141 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
| 142 |
+
|
| 143 |
+
# up_blocks have three resnets
|
| 144 |
+
# also, up blocks in hf are numbered in reverse from sd
|
| 145 |
+
for j in range(3):
|
| 146 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
| 147 |
+
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
|
| 148 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
| 149 |
+
|
| 150 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
| 151 |
+
for i in range(2):
|
| 152 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
| 153 |
+
sd_mid_res_prefix = f"mid.block_{i + 1}."
|
| 154 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
vae_conversion_map_attn = [
|
| 158 |
+
# (stable-diffusion, HF Diffusers)
|
| 159 |
+
("norm.", "group_norm."),
|
| 160 |
+
# the following are for SDXL
|
| 161 |
+
("q.", "to_q."),
|
| 162 |
+
("k.", "to_k."),
|
| 163 |
+
("v.", "to_v."),
|
| 164 |
+
("proj_out.", "to_out.0."),
|
| 165 |
+
]
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def reshape_weight_for_sd(w):
|
| 169 |
+
# convert HF linear weights to SD conv2d weights
|
| 170 |
+
if not w.ndim == 1:
|
| 171 |
+
return w.reshape(*w.shape, 1, 1)
|
| 172 |
+
else:
|
| 173 |
+
return w
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def convert_vae_state_dict(vae_state_dict):
|
| 177 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
| 178 |
+
for k, v in mapping.items():
|
| 179 |
+
for sd_part, hf_part in vae_conversion_map:
|
| 180 |
+
v = v.replace(hf_part, sd_part)
|
| 181 |
+
mapping[k] = v
|
| 182 |
+
for k, v in mapping.items():
|
| 183 |
+
if "attentions" in k:
|
| 184 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
| 185 |
+
v = v.replace(hf_part, sd_part)
|
| 186 |
+
mapping[k] = v
|
| 187 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
| 188 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
| 189 |
+
for k, v in new_state_dict.items():
|
| 190 |
+
for weight_name in weights_to_convert:
|
| 191 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
| 192 |
+
print(f"Reshaping {k} for SD format")
|
| 193 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
| 194 |
+
return new_state_dict
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# =========================#
|
| 198 |
+
# Text Encoder Conversion #
|
| 199 |
+
# =========================#
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
textenc_conversion_lst = [
|
| 203 |
+
# (stable-diffusion, HF Diffusers)
|
| 204 |
+
("transformer.resblocks.", "text_model.encoder.layers."),
|
| 205 |
+
("ln_1", "layer_norm1"),
|
| 206 |
+
("ln_2", "layer_norm2"),
|
| 207 |
+
(".c_fc.", ".fc1."),
|
| 208 |
+
(".c_proj.", ".fc2."),
|
| 209 |
+
(".attn", ".self_attn"),
|
| 210 |
+
("ln_final.", "text_model.final_layer_norm."),
|
| 211 |
+
("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
| 212 |
+
("positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
| 213 |
+
]
|
| 214 |
+
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
|
| 215 |
+
textenc_pattern = re.compile("|".join(protected.keys()))
|
| 216 |
+
|
| 217 |
+
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
| 218 |
+
code2idx = {"q": 0, "k": 1, "v": 2}
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def convert_openclip_text_enc_state_dict(text_enc_dict):
|
| 222 |
+
new_state_dict = {}
|
| 223 |
+
capture_qkv_weight = {}
|
| 224 |
+
capture_qkv_bias = {}
|
| 225 |
+
for k, v in text_enc_dict.items():
|
| 226 |
+
if (
|
| 227 |
+
k.endswith(".self_attn.q_proj.weight")
|
| 228 |
+
or k.endswith(".self_attn.k_proj.weight")
|
| 229 |
+
or k.endswith(".self_attn.v_proj.weight")
|
| 230 |
+
):
|
| 231 |
+
k_pre = k[: -len(".q_proj.weight")]
|
| 232 |
+
k_code = k[-len("q_proj.weight")]
|
| 233 |
+
if k_pre not in capture_qkv_weight:
|
| 234 |
+
capture_qkv_weight[k_pre] = [None, None, None]
|
| 235 |
+
capture_qkv_weight[k_pre][code2idx[k_code]] = v
|
| 236 |
+
continue
|
| 237 |
+
|
| 238 |
+
if (
|
| 239 |
+
k.endswith(".self_attn.q_proj.bias")
|
| 240 |
+
or k.endswith(".self_attn.k_proj.bias")
|
| 241 |
+
or k.endswith(".self_attn.v_proj.bias")
|
| 242 |
+
):
|
| 243 |
+
k_pre = k[: -len(".q_proj.bias")]
|
| 244 |
+
k_code = k[-len("q_proj.bias")]
|
| 245 |
+
if k_pre not in capture_qkv_bias:
|
| 246 |
+
capture_qkv_bias[k_pre] = [None, None, None]
|
| 247 |
+
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
| 248 |
+
continue
|
| 249 |
+
|
| 250 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
| 251 |
+
new_state_dict[relabelled_key] = v
|
| 252 |
+
|
| 253 |
+
for k_pre, tensors in capture_qkv_weight.items():
|
| 254 |
+
if None in tensors:
|
| 255 |
+
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
| 256 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
| 257 |
+
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
|
| 258 |
+
|
| 259 |
+
for k_pre, tensors in capture_qkv_bias.items():
|
| 260 |
+
if None in tensors:
|
| 261 |
+
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
| 262 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
| 263 |
+
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
|
| 264 |
+
|
| 265 |
+
return new_state_dict
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def convert_openai_text_enc_state_dict(text_enc_dict):
|
| 269 |
+
return text_enc_dict
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
if __name__ == "__main__":
|
| 273 |
+
parser = argparse.ArgumentParser()
|
| 274 |
+
|
| 275 |
+
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
|
| 276 |
+
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 277 |
+
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
| 278 |
+
parser.add_argument(
|
| 279 |
+
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
args = parser.parse_args()
|
| 283 |
+
|
| 284 |
+
assert args.model_path is not None, "Must provide a model path!"
|
| 285 |
+
|
| 286 |
+
assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
|
| 287 |
+
|
| 288 |
+
# Path for safetensors
|
| 289 |
+
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
|
| 290 |
+
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
|
| 291 |
+
text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
|
| 292 |
+
text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "model.safetensors")
|
| 293 |
+
|
| 294 |
+
# Load models from safetensors if it exists, if it doesn't pytorch
|
| 295 |
+
if osp.exists(unet_path):
|
| 296 |
+
unet_state_dict = load_file(unet_path, device="cpu")
|
| 297 |
+
else:
|
| 298 |
+
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
|
| 299 |
+
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
| 300 |
+
|
| 301 |
+
if osp.exists(vae_path):
|
| 302 |
+
vae_state_dict = load_file(vae_path, device="cpu")
|
| 303 |
+
else:
|
| 304 |
+
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
|
| 305 |
+
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
| 306 |
+
|
| 307 |
+
if osp.exists(text_enc_path):
|
| 308 |
+
text_enc_dict = load_file(text_enc_path, device="cpu")
|
| 309 |
+
else:
|
| 310 |
+
text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
|
| 311 |
+
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
| 312 |
+
|
| 313 |
+
if osp.exists(text_enc_2_path):
|
| 314 |
+
text_enc_2_dict = load_file(text_enc_2_path, device="cpu")
|
| 315 |
+
else:
|
| 316 |
+
text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "pytorch_model.bin")
|
| 317 |
+
text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu")
|
| 318 |
+
|
| 319 |
+
# Convert the UNet model
|
| 320 |
+
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
| 321 |
+
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
| 322 |
+
|
| 323 |
+
# Convert the VAE model
|
| 324 |
+
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
| 325 |
+
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
| 326 |
+
|
| 327 |
+
# Convert text encoder 1
|
| 328 |
+
text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
|
| 329 |
+
text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
|
| 330 |
+
|
| 331 |
+
# Convert text encoder 2
|
| 332 |
+
text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
|
| 333 |
+
text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
|
| 334 |
+
# We call the `.T.contiguous()` to match what's done in
|
| 335 |
+
# https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085
|
| 336 |
+
text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop(
|
| 337 |
+
"conditioner.embedders.1.model.text_projection.weight"
|
| 338 |
+
).T.contiguous()
|
| 339 |
+
|
| 340 |
+
# Put together new checkpoint
|
| 341 |
+
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
|
| 342 |
+
|
| 343 |
+
if args.half:
|
| 344 |
+
state_dict = {k: v.half() for k, v in state_dict.items()}
|
| 345 |
+
|
| 346 |
+
if args.use_safetensors:
|
| 347 |
+
save_file(state_dict, args.checkpoint_path)
|
| 348 |
+
else:
|
| 349 |
+
state_dict = {"state_dict": state_dict}
|
| 350 |
+
torch.save(state_dict, args.checkpoint_path)
|
diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
|
| 2 |
+
# *Only* converts the UNet, VAE, and Text Encoder.
|
| 3 |
+
# Does not convert optimizer state or any other thing.
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import os.path as osp
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from safetensors.torch import load_file, save_file
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# =================#
|
| 14 |
+
# UNet Conversion #
|
| 15 |
+
# =================#
|
| 16 |
+
|
| 17 |
+
unet_conversion_map = [
|
| 18 |
+
# (stable-diffusion, HF Diffusers)
|
| 19 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
| 20 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
| 21 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
| 22 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
| 23 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
| 24 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
| 25 |
+
("out.0.weight", "conv_norm_out.weight"),
|
| 26 |
+
("out.0.bias", "conv_norm_out.bias"),
|
| 27 |
+
("out.2.weight", "conv_out.weight"),
|
| 28 |
+
("out.2.bias", "conv_out.bias"),
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
unet_conversion_map_resnet = [
|
| 32 |
+
# (stable-diffusion, HF Diffusers)
|
| 33 |
+
("in_layers.0", "norm1"),
|
| 34 |
+
("in_layers.2", "conv1"),
|
| 35 |
+
("out_layers.0", "norm2"),
|
| 36 |
+
("out_layers.3", "conv2"),
|
| 37 |
+
("emb_layers.1", "time_emb_proj"),
|
| 38 |
+
("skip_connection", "conv_shortcut"),
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
unet_conversion_map_layer = []
|
| 42 |
+
# hardcoded number of downblocks and resnets/attentions...
|
| 43 |
+
# would need smarter logic for other networks.
|
| 44 |
+
for i in range(4):
|
| 45 |
+
# loop over downblocks/upblocks
|
| 46 |
+
|
| 47 |
+
for j in range(2):
|
| 48 |
+
# loop over resnets/attentions for downblocks
|
| 49 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
| 50 |
+
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
| 51 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
| 52 |
+
|
| 53 |
+
if i < 3:
|
| 54 |
+
# no attention layers in down_blocks.3
|
| 55 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
| 56 |
+
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
| 57 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
| 58 |
+
|
| 59 |
+
for j in range(3):
|
| 60 |
+
# loop over resnets/attentions for upblocks
|
| 61 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
| 62 |
+
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
| 63 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
| 64 |
+
|
| 65 |
+
if i > 0:
|
| 66 |
+
# no attention layers in up_blocks.0
|
| 67 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
| 68 |
+
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
| 69 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
| 70 |
+
|
| 71 |
+
if i < 3:
|
| 72 |
+
# no downsample in down_blocks.3
|
| 73 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
| 74 |
+
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
| 75 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 76 |
+
|
| 77 |
+
# no upsample in up_blocks.3
|
| 78 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
| 79 |
+
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
| 80 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
| 81 |
+
|
| 82 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
| 83 |
+
sd_mid_atn_prefix = "middle_block.1."
|
| 84 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
| 85 |
+
|
| 86 |
+
for j in range(2):
|
| 87 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
| 88 |
+
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
| 89 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def convert_unet_state_dict(unet_state_dict):
|
| 93 |
+
# buyer beware: this is a *brittle* function,
|
| 94 |
+
# and correct output requires that all of these pieces interact in
|
| 95 |
+
# the exact order in which I have arranged them.
|
| 96 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
| 97 |
+
for sd_name, hf_name in unet_conversion_map:
|
| 98 |
+
mapping[hf_name] = sd_name
|
| 99 |
+
for k, v in mapping.items():
|
| 100 |
+
if "resnets" in k:
|
| 101 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
| 102 |
+
v = v.replace(hf_part, sd_part)
|
| 103 |
+
mapping[k] = v
|
| 104 |
+
for k, v in mapping.items():
|
| 105 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
| 106 |
+
v = v.replace(hf_part, sd_part)
|
| 107 |
+
mapping[k] = v
|
| 108 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
| 109 |
+
return new_state_dict
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ================#
|
| 113 |
+
# VAE Conversion #
|
| 114 |
+
# ================#
|
| 115 |
+
|
| 116 |
+
vae_conversion_map = [
|
| 117 |
+
# (stable-diffusion, HF Diffusers)
|
| 118 |
+
("nin_shortcut", "conv_shortcut"),
|
| 119 |
+
("norm_out", "conv_norm_out"),
|
| 120 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
| 121 |
+
]
|
| 122 |
+
|
| 123 |
+
for i in range(4):
|
| 124 |
+
# down_blocks have two resnets
|
| 125 |
+
for j in range(2):
|
| 126 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
| 127 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
| 128 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
| 129 |
+
|
| 130 |
+
if i < 3:
|
| 131 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
| 132 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
| 133 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 134 |
+
|
| 135 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
| 136 |
+
sd_upsample_prefix = f"up.{3 - i}.upsample."
|
| 137 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
| 138 |
+
|
| 139 |
+
# up_blocks have three resnets
|
| 140 |
+
# also, up blocks in hf are numbered in reverse from sd
|
| 141 |
+
for j in range(3):
|
| 142 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
| 143 |
+
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
|
| 144 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
| 145 |
+
|
| 146 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
| 147 |
+
for i in range(2):
|
| 148 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
| 149 |
+
sd_mid_res_prefix = f"mid.block_{i + 1}."
|
| 150 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
vae_conversion_map_attn = [
|
| 154 |
+
# (stable-diffusion, HF Diffusers)
|
| 155 |
+
("norm.", "group_norm."),
|
| 156 |
+
("q.", "query."),
|
| 157 |
+
("k.", "key."),
|
| 158 |
+
("v.", "value."),
|
| 159 |
+
("proj_out.", "proj_attn."),
|
| 160 |
+
]
|
| 161 |
+
|
| 162 |
+
# This is probably not the most ideal solution, but it does work.
|
| 163 |
+
vae_extra_conversion_map = [
|
| 164 |
+
("to_q", "q"),
|
| 165 |
+
("to_k", "k"),
|
| 166 |
+
("to_v", "v"),
|
| 167 |
+
("to_out.0", "proj_out"),
|
| 168 |
+
]
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def reshape_weight_for_sd(w):
|
| 172 |
+
# convert HF linear weights to SD conv2d weights
|
| 173 |
+
if not w.ndim == 1:
|
| 174 |
+
return w.reshape(*w.shape, 1, 1)
|
| 175 |
+
else:
|
| 176 |
+
return w
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def convert_vae_state_dict(vae_state_dict):
|
| 180 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
| 181 |
+
for k, v in mapping.items():
|
| 182 |
+
for sd_part, hf_part in vae_conversion_map:
|
| 183 |
+
v = v.replace(hf_part, sd_part)
|
| 184 |
+
mapping[k] = v
|
| 185 |
+
for k, v in mapping.items():
|
| 186 |
+
if "attentions" in k:
|
| 187 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
| 188 |
+
v = v.replace(hf_part, sd_part)
|
| 189 |
+
mapping[k] = v
|
| 190 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
| 191 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
| 192 |
+
keys_to_rename = {}
|
| 193 |
+
for k, v in new_state_dict.items():
|
| 194 |
+
for weight_name in weights_to_convert:
|
| 195 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
| 196 |
+
print(f"Reshaping {k} for SD format")
|
| 197 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
| 198 |
+
for weight_name, real_weight_name in vae_extra_conversion_map:
|
| 199 |
+
if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
|
| 200 |
+
keys_to_rename[k] = k.replace(weight_name, real_weight_name)
|
| 201 |
+
for k, v in keys_to_rename.items():
|
| 202 |
+
if k in new_state_dict:
|
| 203 |
+
print(f"Renaming {k} to {v}")
|
| 204 |
+
new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
|
| 205 |
+
del new_state_dict[k]
|
| 206 |
+
return new_state_dict
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# =========================#
|
| 210 |
+
# Text Encoder Conversion #
|
| 211 |
+
# =========================#
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
textenc_conversion_lst = [
|
| 215 |
+
# (stable-diffusion, HF Diffusers)
|
| 216 |
+
("resblocks.", "text_model.encoder.layers."),
|
| 217 |
+
("ln_1", "layer_norm1"),
|
| 218 |
+
("ln_2", "layer_norm2"),
|
| 219 |
+
(".c_fc.", ".fc1."),
|
| 220 |
+
(".c_proj.", ".fc2."),
|
| 221 |
+
(".attn", ".self_attn"),
|
| 222 |
+
("ln_final.", "transformer.text_model.final_layer_norm."),
|
| 223 |
+
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
| 224 |
+
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
| 225 |
+
]
|
| 226 |
+
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
|
| 227 |
+
textenc_pattern = re.compile("|".join(protected.keys()))
|
| 228 |
+
|
| 229 |
+
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
|
| 230 |
+
code2idx = {"q": 0, "k": 1, "v": 2}
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def convert_text_enc_state_dict_v20(text_enc_dict):
|
| 234 |
+
new_state_dict = {}
|
| 235 |
+
capture_qkv_weight = {}
|
| 236 |
+
capture_qkv_bias = {}
|
| 237 |
+
for k, v in text_enc_dict.items():
|
| 238 |
+
if (
|
| 239 |
+
k.endswith(".self_attn.q_proj.weight")
|
| 240 |
+
or k.endswith(".self_attn.k_proj.weight")
|
| 241 |
+
or k.endswith(".self_attn.v_proj.weight")
|
| 242 |
+
):
|
| 243 |
+
k_pre = k[: -len(".q_proj.weight")]
|
| 244 |
+
k_code = k[-len("q_proj.weight")]
|
| 245 |
+
if k_pre not in capture_qkv_weight:
|
| 246 |
+
capture_qkv_weight[k_pre] = [None, None, None]
|
| 247 |
+
capture_qkv_weight[k_pre][code2idx[k_code]] = v
|
| 248 |
+
continue
|
| 249 |
+
|
| 250 |
+
if (
|
| 251 |
+
k.endswith(".self_attn.q_proj.bias")
|
| 252 |
+
or k.endswith(".self_attn.k_proj.bias")
|
| 253 |
+
or k.endswith(".self_attn.v_proj.bias")
|
| 254 |
+
):
|
| 255 |
+
k_pre = k[: -len(".q_proj.bias")]
|
| 256 |
+
k_code = k[-len("q_proj.bias")]
|
| 257 |
+
if k_pre not in capture_qkv_bias:
|
| 258 |
+
capture_qkv_bias[k_pre] = [None, None, None]
|
| 259 |
+
capture_qkv_bias[k_pre][code2idx[k_code]] = v
|
| 260 |
+
continue
|
| 261 |
+
|
| 262 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
|
| 263 |
+
new_state_dict[relabelled_key] = v
|
| 264 |
+
|
| 265 |
+
for k_pre, tensors in capture_qkv_weight.items():
|
| 266 |
+
if None in tensors:
|
| 267 |
+
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
| 268 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
| 269 |
+
new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
|
| 270 |
+
|
| 271 |
+
for k_pre, tensors in capture_qkv_bias.items():
|
| 272 |
+
if None in tensors:
|
| 273 |
+
raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
|
| 274 |
+
relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
|
| 275 |
+
new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
|
| 276 |
+
|
| 277 |
+
return new_state_dict
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def convert_text_enc_state_dict(text_enc_dict):
|
| 281 |
+
return text_enc_dict
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
if __name__ == "__main__":
|
| 285 |
+
parser = argparse.ArgumentParser()
|
| 286 |
+
|
| 287 |
+
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
|
| 288 |
+
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 289 |
+
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
| 290 |
+
parser.add_argument(
|
| 291 |
+
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
args = parser.parse_args()
|
| 295 |
+
|
| 296 |
+
assert args.model_path is not None, "Must provide a model path!"
|
| 297 |
+
|
| 298 |
+
assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
|
| 299 |
+
|
| 300 |
+
# Path for safetensors
|
| 301 |
+
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
|
| 302 |
+
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
|
| 303 |
+
text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
|
| 304 |
+
|
| 305 |
+
# Load models from safetensors if it exists, if it doesn't pytorch
|
| 306 |
+
if osp.exists(unet_path):
|
| 307 |
+
unet_state_dict = load_file(unet_path, device="cpu")
|
| 308 |
+
else:
|
| 309 |
+
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
|
| 310 |
+
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
| 311 |
+
|
| 312 |
+
if osp.exists(vae_path):
|
| 313 |
+
vae_state_dict = load_file(vae_path, device="cpu")
|
| 314 |
+
else:
|
| 315 |
+
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
|
| 316 |
+
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
| 317 |
+
|
| 318 |
+
if osp.exists(text_enc_path):
|
| 319 |
+
text_enc_dict = load_file(text_enc_path, device="cpu")
|
| 320 |
+
else:
|
| 321 |
+
text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
|
| 322 |
+
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
| 323 |
+
|
| 324 |
+
# Convert the UNet model
|
| 325 |
+
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
| 326 |
+
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
| 327 |
+
|
| 328 |
+
# Convert the VAE model
|
| 329 |
+
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
| 330 |
+
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
| 331 |
+
|
| 332 |
+
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
| 333 |
+
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
| 334 |
+
|
| 335 |
+
if is_v20_model:
|
| 336 |
+
# Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
|
| 337 |
+
text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
|
| 338 |
+
text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
|
| 339 |
+
text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
|
| 340 |
+
else:
|
| 341 |
+
text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
|
| 342 |
+
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
|
| 343 |
+
|
| 344 |
+
# Put together new checkpoint
|
| 345 |
+
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
|
| 346 |
+
if args.half:
|
| 347 |
+
state_dict = {k: v.half() for k, v in state_dict.items()}
|
| 348 |
+
|
| 349 |
+
if args.use_safetensors:
|
| 350 |
+
save_file(state_dict, args.checkpoint_path)
|
| 351 |
+
else:
|
| 352 |
+
state_dict = {"state_dict": state_dict}
|
| 353 |
+
torch.save(state_dict, args.checkpoint_path)
|
diffusers/scripts/convert_dit_to_diffusers.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torchvision.datasets.utils import download_url
|
| 6 |
+
|
| 7 |
+
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def download_model(model_name):
|
| 14 |
+
"""
|
| 15 |
+
Downloads a pre-trained DiT model from the web.
|
| 16 |
+
"""
|
| 17 |
+
local_path = f"pretrained_models/{model_name}"
|
| 18 |
+
if not os.path.isfile(local_path):
|
| 19 |
+
os.makedirs("pretrained_models", exist_ok=True)
|
| 20 |
+
web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
|
| 21 |
+
download_url(web_path, "pretrained_models")
|
| 22 |
+
model = torch.load(local_path, map_location=lambda storage, loc: storage)
|
| 23 |
+
return model
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main(args):
|
| 27 |
+
state_dict = download_model(pretrained_models[args.image_size])
|
| 28 |
+
|
| 29 |
+
state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
|
| 30 |
+
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
|
| 31 |
+
state_dict.pop("x_embedder.proj.weight")
|
| 32 |
+
state_dict.pop("x_embedder.proj.bias")
|
| 33 |
+
|
| 34 |
+
for depth in range(28):
|
| 35 |
+
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
|
| 36 |
+
"t_embedder.mlp.0.weight"
|
| 37 |
+
]
|
| 38 |
+
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
|
| 39 |
+
"t_embedder.mlp.0.bias"
|
| 40 |
+
]
|
| 41 |
+
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
|
| 42 |
+
"t_embedder.mlp.2.weight"
|
| 43 |
+
]
|
| 44 |
+
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
|
| 45 |
+
"t_embedder.mlp.2.bias"
|
| 46 |
+
]
|
| 47 |
+
state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
|
| 48 |
+
"y_embedder.embedding_table.weight"
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
|
| 52 |
+
f"blocks.{depth}.adaLN_modulation.1.weight"
|
| 53 |
+
]
|
| 54 |
+
state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
|
| 55 |
+
f"blocks.{depth}.adaLN_modulation.1.bias"
|
| 56 |
+
]
|
| 57 |
+
|
| 58 |
+
q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
|
| 59 |
+
q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)
|
| 60 |
+
|
| 61 |
+
state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
|
| 62 |
+
state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
|
| 63 |
+
state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
|
| 64 |
+
state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
|
| 65 |
+
state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
|
| 66 |
+
state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
|
| 67 |
+
|
| 68 |
+
state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
|
| 69 |
+
f"blocks.{depth}.attn.proj.weight"
|
| 70 |
+
]
|
| 71 |
+
state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]
|
| 72 |
+
|
| 73 |
+
state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
|
| 74 |
+
state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
|
| 75 |
+
state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
|
| 76 |
+
state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]
|
| 77 |
+
|
| 78 |
+
state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
|
| 79 |
+
state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
|
| 80 |
+
state_dict.pop(f"blocks.{depth}.attn.proj.weight")
|
| 81 |
+
state_dict.pop(f"blocks.{depth}.attn.proj.bias")
|
| 82 |
+
state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
|
| 83 |
+
state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
|
| 84 |
+
state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
|
| 85 |
+
state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
|
| 86 |
+
state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
|
| 87 |
+
state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")
|
| 88 |
+
|
| 89 |
+
state_dict.pop("t_embedder.mlp.0.weight")
|
| 90 |
+
state_dict.pop("t_embedder.mlp.0.bias")
|
| 91 |
+
state_dict.pop("t_embedder.mlp.2.weight")
|
| 92 |
+
state_dict.pop("t_embedder.mlp.2.bias")
|
| 93 |
+
state_dict.pop("y_embedder.embedding_table.weight")
|
| 94 |
+
|
| 95 |
+
state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
|
| 96 |
+
state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
|
| 97 |
+
state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
|
| 98 |
+
state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]
|
| 99 |
+
|
| 100 |
+
state_dict.pop("final_layer.linear.weight")
|
| 101 |
+
state_dict.pop("final_layer.linear.bias")
|
| 102 |
+
state_dict.pop("final_layer.adaLN_modulation.1.weight")
|
| 103 |
+
state_dict.pop("final_layer.adaLN_modulation.1.bias")
|
| 104 |
+
|
| 105 |
+
# DiT XL/2
|
| 106 |
+
transformer = Transformer2DModel(
|
| 107 |
+
sample_size=args.image_size // 8,
|
| 108 |
+
num_layers=28,
|
| 109 |
+
attention_head_dim=72,
|
| 110 |
+
in_channels=4,
|
| 111 |
+
out_channels=8,
|
| 112 |
+
patch_size=2,
|
| 113 |
+
attention_bias=True,
|
| 114 |
+
num_attention_heads=16,
|
| 115 |
+
activation_fn="gelu-approximate",
|
| 116 |
+
num_embeds_ada_norm=1000,
|
| 117 |
+
norm_type="ada_norm_zero",
|
| 118 |
+
norm_elementwise_affine=False,
|
| 119 |
+
)
|
| 120 |
+
transformer.load_state_dict(state_dict, strict=True)
|
| 121 |
+
|
| 122 |
+
scheduler = DDIMScheduler(
|
| 123 |
+
num_train_timesteps=1000,
|
| 124 |
+
beta_schedule="linear",
|
| 125 |
+
prediction_type="epsilon",
|
| 126 |
+
clip_sample=False,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
vae = AutoencoderKL.from_pretrained(args.vae_model)
|
| 130 |
+
|
| 131 |
+
pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)
|
| 132 |
+
|
| 133 |
+
if args.save:
|
| 134 |
+
pipeline.save_pretrained(args.checkpoint_path)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
if __name__ == "__main__":
|
| 138 |
+
parser = argparse.ArgumentParser()
|
| 139 |
+
|
| 140 |
+
parser.add_argument(
|
| 141 |
+
"--image_size",
|
| 142 |
+
default=256,
|
| 143 |
+
type=int,
|
| 144 |
+
required=False,
|
| 145 |
+
help="Image size of pretrained model, either 256 or 512.",
|
| 146 |
+
)
|
| 147 |
+
parser.add_argument(
|
| 148 |
+
"--vae_model",
|
| 149 |
+
default="stabilityai/sd-vae-ft-ema",
|
| 150 |
+
type=str,
|
| 151 |
+
required=False,
|
| 152 |
+
help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
|
| 153 |
+
)
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
args = parser.parse_args()
|
| 162 |
+
main(args)
|
diffusers/scripts/convert_flux_to_diffusers.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from contextlib import nullcontext
|
| 3 |
+
|
| 4 |
+
import safetensors.torch
|
| 5 |
+
import torch
|
| 6 |
+
from accelerate import init_empty_weights
|
| 7 |
+
from huggingface_hub import hf_hub_download
|
| 8 |
+
|
| 9 |
+
from diffusers import AutoencoderKL, FluxTransformer2DModel
|
| 10 |
+
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
| 11 |
+
from diffusers.utils.import_utils import is_accelerate_available
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
"""
|
| 15 |
+
# Transformer
|
| 16 |
+
|
| 17 |
+
python scripts/convert_flux_to_diffusers.py \
|
| 18 |
+
--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
|
| 19 |
+
--filename "flux1-schnell.sft"
|
| 20 |
+
--output_path "flux-schnell" \
|
| 21 |
+
--transformer
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
"""
|
| 25 |
+
# VAE
|
| 26 |
+
|
| 27 |
+
python scripts/convert_flux_to_diffusers.py \
|
| 28 |
+
--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
|
| 29 |
+
--filename "ae.sft"
|
| 30 |
+
--output_path "flux-schnell" \
|
| 31 |
+
--vae
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
| 35 |
+
|
| 36 |
+
parser = argparse.ArgumentParser()
|
| 37 |
+
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
|
| 38 |
+
parser.add_argument("--filename", default="flux.safetensors", type=str)
|
| 39 |
+
parser.add_argument("--checkpoint_path", default=None, type=str)
|
| 40 |
+
parser.add_argument("--in_channels", type=int, default=64)
|
| 41 |
+
parser.add_argument("--out_channels", type=int, default=None)
|
| 42 |
+
parser.add_argument("--vae", action="store_true")
|
| 43 |
+
parser.add_argument("--transformer", action="store_true")
|
| 44 |
+
parser.add_argument("--output_path", type=str)
|
| 45 |
+
parser.add_argument("--dtype", type=str, default="bf16")
|
| 46 |
+
|
| 47 |
+
args = parser.parse_args()
|
| 48 |
+
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_original_checkpoint(args):
|
| 52 |
+
if args.original_state_dict_repo_id is not None:
|
| 53 |
+
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
|
| 54 |
+
elif args.checkpoint_path is not None:
|
| 55 |
+
ckpt_path = args.checkpoint_path
|
| 56 |
+
else:
|
| 57 |
+
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
| 58 |
+
|
| 59 |
+
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
| 60 |
+
return original_state_dict
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
|
| 64 |
+
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
|
| 65 |
+
def swap_scale_shift(weight):
|
| 66 |
+
shift, scale = weight.chunk(2, dim=0)
|
| 67 |
+
new_weight = torch.cat([scale, shift], dim=0)
|
| 68 |
+
return new_weight
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def convert_flux_transformer_checkpoint_to_diffusers(
|
| 72 |
+
original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
|
| 73 |
+
):
|
| 74 |
+
converted_state_dict = {}
|
| 75 |
+
|
| 76 |
+
## time_text_embed.timestep_embedder <- time_in
|
| 77 |
+
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
|
| 78 |
+
"time_in.in_layer.weight"
|
| 79 |
+
)
|
| 80 |
+
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
|
| 81 |
+
"time_in.in_layer.bias"
|
| 82 |
+
)
|
| 83 |
+
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
|
| 84 |
+
"time_in.out_layer.weight"
|
| 85 |
+
)
|
| 86 |
+
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
|
| 87 |
+
"time_in.out_layer.bias"
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
## time_text_embed.text_embedder <- vector_in
|
| 91 |
+
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
|
| 92 |
+
"vector_in.in_layer.weight"
|
| 93 |
+
)
|
| 94 |
+
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
|
| 95 |
+
"vector_in.in_layer.bias"
|
| 96 |
+
)
|
| 97 |
+
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
|
| 98 |
+
"vector_in.out_layer.weight"
|
| 99 |
+
)
|
| 100 |
+
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
|
| 101 |
+
"vector_in.out_layer.bias"
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# guidance
|
| 105 |
+
has_guidance = any("guidance" in k for k in original_state_dict)
|
| 106 |
+
if has_guidance:
|
| 107 |
+
converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
|
| 108 |
+
"guidance_in.in_layer.weight"
|
| 109 |
+
)
|
| 110 |
+
converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
|
| 111 |
+
"guidance_in.in_layer.bias"
|
| 112 |
+
)
|
| 113 |
+
converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
|
| 114 |
+
"guidance_in.out_layer.weight"
|
| 115 |
+
)
|
| 116 |
+
converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
|
| 117 |
+
"guidance_in.out_layer.bias"
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# context_embedder
|
| 121 |
+
converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight")
|
| 122 |
+
converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias")
|
| 123 |
+
|
| 124 |
+
# x_embedder
|
| 125 |
+
converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
|
| 126 |
+
converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")
|
| 127 |
+
|
| 128 |
+
# double transformer blocks
|
| 129 |
+
for i in range(num_layers):
|
| 130 |
+
block_prefix = f"transformer_blocks.{i}."
|
| 131 |
+
# norms.
|
| 132 |
+
## norm1
|
| 133 |
+
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
|
| 134 |
+
f"double_blocks.{i}.img_mod.lin.weight"
|
| 135 |
+
)
|
| 136 |
+
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
|
| 137 |
+
f"double_blocks.{i}.img_mod.lin.bias"
|
| 138 |
+
)
|
| 139 |
+
## norm1_context
|
| 140 |
+
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
|
| 141 |
+
f"double_blocks.{i}.txt_mod.lin.weight"
|
| 142 |
+
)
|
| 143 |
+
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
|
| 144 |
+
f"double_blocks.{i}.txt_mod.lin.bias"
|
| 145 |
+
)
|
| 146 |
+
# Q, K, V
|
| 147 |
+
sample_q, sample_k, sample_v = torch.chunk(
|
| 148 |
+
original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
|
| 149 |
+
)
|
| 150 |
+
context_q, context_k, context_v = torch.chunk(
|
| 151 |
+
original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
|
| 152 |
+
)
|
| 153 |
+
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
|
| 154 |
+
original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
|
| 155 |
+
)
|
| 156 |
+
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
|
| 157 |
+
original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
|
| 158 |
+
)
|
| 159 |
+
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
|
| 160 |
+
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
|
| 161 |
+
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
|
| 162 |
+
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
|
| 163 |
+
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
|
| 164 |
+
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
|
| 165 |
+
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
|
| 166 |
+
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
|
| 167 |
+
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
|
| 168 |
+
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
|
| 169 |
+
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
|
| 170 |
+
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
|
| 171 |
+
# qk_norm
|
| 172 |
+
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
|
| 173 |
+
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
|
| 174 |
+
)
|
| 175 |
+
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
|
| 176 |
+
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
|
| 177 |
+
)
|
| 178 |
+
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
|
| 179 |
+
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
|
| 180 |
+
)
|
| 181 |
+
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
|
| 182 |
+
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
|
| 183 |
+
)
|
| 184 |
+
# ff img_mlp
|
| 185 |
+
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
|
| 186 |
+
f"double_blocks.{i}.img_mlp.0.weight"
|
| 187 |
+
)
|
| 188 |
+
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
|
| 189 |
+
f"double_blocks.{i}.img_mlp.0.bias"
|
| 190 |
+
)
|
| 191 |
+
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
|
| 192 |
+
f"double_blocks.{i}.img_mlp.2.weight"
|
| 193 |
+
)
|
| 194 |
+
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
|
| 195 |
+
f"double_blocks.{i}.img_mlp.2.bias"
|
| 196 |
+
)
|
| 197 |
+
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
|
| 198 |
+
f"double_blocks.{i}.txt_mlp.0.weight"
|
| 199 |
+
)
|
| 200 |
+
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
|
| 201 |
+
f"double_blocks.{i}.txt_mlp.0.bias"
|
| 202 |
+
)
|
| 203 |
+
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
|
| 204 |
+
f"double_blocks.{i}.txt_mlp.2.weight"
|
| 205 |
+
)
|
| 206 |
+
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
|
| 207 |
+
f"double_blocks.{i}.txt_mlp.2.bias"
|
| 208 |
+
)
|
| 209 |
+
# output projections.
|
| 210 |
+
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
|
| 211 |
+
f"double_blocks.{i}.img_attn.proj.weight"
|
| 212 |
+
)
|
| 213 |
+
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
|
| 214 |
+
f"double_blocks.{i}.img_attn.proj.bias"
|
| 215 |
+
)
|
| 216 |
+
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
|
| 217 |
+
f"double_blocks.{i}.txt_attn.proj.weight"
|
| 218 |
+
)
|
| 219 |
+
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
|
| 220 |
+
f"double_blocks.{i}.txt_attn.proj.bias"
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# single transformer blocks
|
| 224 |
+
for i in range(num_single_layers):
|
| 225 |
+
block_prefix = f"single_transformer_blocks.{i}."
|
| 226 |
+
# norm.linear <- single_blocks.0.modulation.lin
|
| 227 |
+
converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
|
| 228 |
+
f"single_blocks.{i}.modulation.lin.weight"
|
| 229 |
+
)
|
| 230 |
+
converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
|
| 231 |
+
f"single_blocks.{i}.modulation.lin.bias"
|
| 232 |
+
)
|
| 233 |
+
# Q, K, V, mlp
|
| 234 |
+
mlp_hidden_dim = int(inner_dim * mlp_ratio)
|
| 235 |
+
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
|
| 236 |
+
q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
|
| 237 |
+
q_bias, k_bias, v_bias, mlp_bias = torch.split(
|
| 238 |
+
original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
|
| 239 |
+
)
|
| 240 |
+
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
|
| 241 |
+
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
|
| 242 |
+
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
|
| 243 |
+
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
|
| 244 |
+
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
|
| 245 |
+
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
|
| 246 |
+
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
|
| 247 |
+
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
|
| 248 |
+
# qk norm
|
| 249 |
+
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
|
| 250 |
+
f"single_blocks.{i}.norm.query_norm.scale"
|
| 251 |
+
)
|
| 252 |
+
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
|
| 253 |
+
f"single_blocks.{i}.norm.key_norm.scale"
|
| 254 |
+
)
|
| 255 |
+
# output projections.
|
| 256 |
+
converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
|
| 257 |
+
f"single_blocks.{i}.linear2.weight"
|
| 258 |
+
)
|
| 259 |
+
converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
|
| 260 |
+
f"single_blocks.{i}.linear2.bias"
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
|
| 264 |
+
converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
|
| 265 |
+
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
|
| 266 |
+
original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
|
| 267 |
+
)
|
| 268 |
+
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
|
| 269 |
+
original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
return converted_state_dict
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def main(args):
|
| 276 |
+
original_ckpt = load_original_checkpoint(args)
|
| 277 |
+
has_guidance = any("guidance" in k for k in original_ckpt)
|
| 278 |
+
|
| 279 |
+
if args.transformer:
|
| 280 |
+
num_layers = 19
|
| 281 |
+
num_single_layers = 38
|
| 282 |
+
inner_dim = 3072
|
| 283 |
+
mlp_ratio = 4.0
|
| 284 |
+
|
| 285 |
+
converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
|
| 286 |
+
original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
|
| 287 |
+
)
|
| 288 |
+
transformer = FluxTransformer2DModel(
|
| 289 |
+
in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance
|
| 290 |
+
)
|
| 291 |
+
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
|
| 292 |
+
|
| 293 |
+
print(
|
| 294 |
+
f"Saving Flux Transformer in Diffusers format. Variant: {'guidance-distilled' if has_guidance else 'timestep-distilled'}"
|
| 295 |
+
)
|
| 296 |
+
transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
|
| 297 |
+
|
| 298 |
+
if args.vae:
|
| 299 |
+
config = AutoencoderKL.load_config("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
|
| 300 |
+
vae = AutoencoderKL.from_config(config, scaling_factor=0.3611, shift_factor=0.1159).to(torch.bfloat16)
|
| 301 |
+
|
| 302 |
+
converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
|
| 303 |
+
vae.load_state_dict(converted_vae_state_dict, strict=True)
|
| 304 |
+
vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
if __name__ == "__main__":
|
| 308 |
+
main(args)
|
diffusers/scripts/convert_gligen_to_diffusers.py
ADDED
|
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import yaml
|
| 6 |
+
from transformers import (
|
| 7 |
+
CLIPProcessor,
|
| 8 |
+
CLIPTextModel,
|
| 9 |
+
CLIPTokenizer,
|
| 10 |
+
CLIPVisionModelWithProjection,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from diffusers import (
|
| 14 |
+
AutoencoderKL,
|
| 15 |
+
DDIMScheduler,
|
| 16 |
+
StableDiffusionGLIGENPipeline,
|
| 17 |
+
StableDiffusionGLIGENTextImagePipeline,
|
| 18 |
+
UNet2DConditionModel,
|
| 19 |
+
)
|
| 20 |
+
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
| 21 |
+
assign_to_checkpoint,
|
| 22 |
+
conv_attn_to_linear,
|
| 23 |
+
protected,
|
| 24 |
+
renew_attention_paths,
|
| 25 |
+
renew_resnet_paths,
|
| 26 |
+
renew_vae_attention_paths,
|
| 27 |
+
renew_vae_resnet_paths,
|
| 28 |
+
shave_segments,
|
| 29 |
+
textenc_conversion_map,
|
| 30 |
+
textenc_pattern,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def convert_open_clip_checkpoint(checkpoint):
|
| 35 |
+
checkpoint = checkpoint["text_encoder"]
|
| 36 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
| 37 |
+
|
| 38 |
+
keys = list(checkpoint.keys())
|
| 39 |
+
|
| 40 |
+
text_model_dict = {}
|
| 41 |
+
|
| 42 |
+
if "cond_stage_model.model.text_projection" in checkpoint:
|
| 43 |
+
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
| 44 |
+
else:
|
| 45 |
+
d_model = 1024
|
| 46 |
+
|
| 47 |
+
for key in keys:
|
| 48 |
+
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
|
| 49 |
+
continue
|
| 50 |
+
if key in textenc_conversion_map:
|
| 51 |
+
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
| 52 |
+
# if key.startswith("cond_stage_model.model.transformer."):
|
| 53 |
+
new_key = key[len("transformer.") :]
|
| 54 |
+
if new_key.endswith(".in_proj_weight"):
|
| 55 |
+
new_key = new_key[: -len(".in_proj_weight")]
|
| 56 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
| 57 |
+
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
| 58 |
+
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
|
| 59 |
+
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
|
| 60 |
+
elif new_key.endswith(".in_proj_bias"):
|
| 61 |
+
new_key = new_key[: -len(".in_proj_bias")]
|
| 62 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
| 63 |
+
text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
|
| 64 |
+
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
|
| 65 |
+
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
|
| 66 |
+
else:
|
| 67 |
+
if key != "transformer.text_model.embeddings.position_ids":
|
| 68 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
| 69 |
+
|
| 70 |
+
text_model_dict[new_key] = checkpoint[key]
|
| 71 |
+
|
| 72 |
+
if key == "transformer.text_model.embeddings.token_embedding.weight":
|
| 73 |
+
text_model_dict["text_model.embeddings.token_embedding.weight"] = checkpoint[key]
|
| 74 |
+
|
| 75 |
+
text_model_dict.pop("text_model.embeddings.transformer.text_model.embeddings.token_embedding.weight")
|
| 76 |
+
|
| 77 |
+
text_model.load_state_dict(text_model_dict)
|
| 78 |
+
|
| 79 |
+
return text_model
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def convert_gligen_vae_checkpoint(checkpoint, config):
|
| 83 |
+
checkpoint = checkpoint["autoencoder"]
|
| 84 |
+
vae_state_dict = {}
|
| 85 |
+
vae_key = "first_stage_model."
|
| 86 |
+
keys = list(checkpoint.keys())
|
| 87 |
+
for key in keys:
|
| 88 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
| 89 |
+
|
| 90 |
+
new_checkpoint = {}
|
| 91 |
+
|
| 92 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
| 93 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
| 94 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
| 95 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
| 96 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
| 97 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
| 98 |
+
|
| 99 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
| 100 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
| 101 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
| 102 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
| 103 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
| 104 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
| 105 |
+
|
| 106 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
| 107 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
| 108 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
| 109 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
| 110 |
+
|
| 111 |
+
# Retrieves the keys for the encoder down blocks only
|
| 112 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
| 113 |
+
down_blocks = {
|
| 114 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
# Retrieves the keys for the decoder up blocks only
|
| 118 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
| 119 |
+
up_blocks = {
|
| 120 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
for i in range(num_down_blocks):
|
| 124 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
| 125 |
+
|
| 126 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
| 127 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
| 128 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
| 129 |
+
)
|
| 130 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
| 131 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 135 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
| 136 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 137 |
+
|
| 138 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
| 139 |
+
num_mid_res_blocks = 2
|
| 140 |
+
for i in range(1, num_mid_res_blocks + 1):
|
| 141 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
| 142 |
+
|
| 143 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 144 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
| 145 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 146 |
+
|
| 147 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
| 148 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
| 149 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
| 150 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 151 |
+
conv_attn_to_linear(new_checkpoint)
|
| 152 |
+
|
| 153 |
+
for i in range(num_up_blocks):
|
| 154 |
+
block_id = num_up_blocks - 1 - i
|
| 155 |
+
resnets = [
|
| 156 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
| 157 |
+
]
|
| 158 |
+
|
| 159 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
| 160 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
| 161 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
| 162 |
+
]
|
| 163 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
| 164 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
| 165 |
+
]
|
| 166 |
+
|
| 167 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 168 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
| 169 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 170 |
+
|
| 171 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
| 172 |
+
num_mid_res_blocks = 2
|
| 173 |
+
for i in range(1, num_mid_res_blocks + 1):
|
| 174 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
| 175 |
+
|
| 176 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 177 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
| 178 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 179 |
+
|
| 180 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
| 181 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
| 182 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
| 183 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 184 |
+
conv_attn_to_linear(new_checkpoint)
|
| 185 |
+
|
| 186 |
+
for key in new_checkpoint.keys():
|
| 187 |
+
if "encoder.mid_block.attentions.0" in key or "decoder.mid_block.attentions.0" in key:
|
| 188 |
+
if "query" in key:
|
| 189 |
+
new_checkpoint[key.replace("query", "to_q")] = new_checkpoint.pop(key)
|
| 190 |
+
if "value" in key:
|
| 191 |
+
new_checkpoint[key.replace("value", "to_v")] = new_checkpoint.pop(key)
|
| 192 |
+
if "key" in key:
|
| 193 |
+
new_checkpoint[key.replace("key", "to_k")] = new_checkpoint.pop(key)
|
| 194 |
+
if "proj_attn" in key:
|
| 195 |
+
new_checkpoint[key.replace("proj_attn", "to_out.0")] = new_checkpoint.pop(key)
|
| 196 |
+
|
| 197 |
+
return new_checkpoint
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
|
| 201 |
+
unet_state_dict = {}
|
| 202 |
+
checkpoint = checkpoint["model"]
|
| 203 |
+
keys = list(checkpoint.keys())
|
| 204 |
+
|
| 205 |
+
unet_key = "model.diffusion_model."
|
| 206 |
+
|
| 207 |
+
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
| 208 |
+
print(f"Checkpoint {path} has bot EMA and non-EMA weights.")
|
| 209 |
+
print(
|
| 210 |
+
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
| 211 |
+
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
| 212 |
+
)
|
| 213 |
+
for key in keys:
|
| 214 |
+
if key.startswith("model.diffusion_model"):
|
| 215 |
+
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
| 216 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
| 217 |
+
else:
|
| 218 |
+
if sum(k.startswith("model_ema") for k in keys) > 100:
|
| 219 |
+
print(
|
| 220 |
+
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
| 221 |
+
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
| 222 |
+
)
|
| 223 |
+
for key in keys:
|
| 224 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
| 225 |
+
|
| 226 |
+
new_checkpoint = {}
|
| 227 |
+
|
| 228 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
| 229 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
| 230 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
| 231 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
| 232 |
+
|
| 233 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
| 234 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
| 235 |
+
|
| 236 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
| 237 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
| 238 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
| 239 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
| 240 |
+
|
| 241 |
+
# Retrieves the keys for the input blocks only
|
| 242 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
| 243 |
+
input_blocks = {
|
| 244 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
| 245 |
+
for layer_id in range(num_input_blocks)
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
# Retrieves the keys for the middle blocks only
|
| 249 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
| 250 |
+
middle_blocks = {
|
| 251 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
| 252 |
+
for layer_id in range(num_middle_blocks)
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
# Retrieves the keys for the output blocks only
|
| 256 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
| 257 |
+
output_blocks = {
|
| 258 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
| 259 |
+
for layer_id in range(num_output_blocks)
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
for i in range(1, num_input_blocks):
|
| 263 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
| 264 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
| 265 |
+
|
| 266 |
+
resnets = [
|
| 267 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
| 268 |
+
]
|
| 269 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
| 270 |
+
|
| 271 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
| 272 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
| 273 |
+
f"input_blocks.{i}.0.op.weight"
|
| 274 |
+
)
|
| 275 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
| 276 |
+
f"input_blocks.{i}.0.op.bias"
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
paths = renew_resnet_paths(resnets)
|
| 280 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 281 |
+
assign_to_checkpoint(
|
| 282 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
if len(attentions):
|
| 286 |
+
paths = renew_attention_paths(attentions)
|
| 287 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
| 288 |
+
assign_to_checkpoint(
|
| 289 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
resnet_0 = middle_blocks[0]
|
| 293 |
+
attentions = middle_blocks[1]
|
| 294 |
+
resnet_1 = middle_blocks[2]
|
| 295 |
+
|
| 296 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
| 297 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
| 298 |
+
|
| 299 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
| 300 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
| 301 |
+
|
| 302 |
+
attentions_paths = renew_attention_paths(attentions)
|
| 303 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
| 304 |
+
assign_to_checkpoint(
|
| 305 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
for i in range(num_output_blocks):
|
| 309 |
+
block_id = i // (config["layers_per_block"] + 1)
|
| 310 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
| 311 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
| 312 |
+
output_block_list = {}
|
| 313 |
+
|
| 314 |
+
for layer in output_block_layers:
|
| 315 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
| 316 |
+
if layer_id in output_block_list:
|
| 317 |
+
output_block_list[layer_id].append(layer_name)
|
| 318 |
+
else:
|
| 319 |
+
output_block_list[layer_id] = [layer_name]
|
| 320 |
+
|
| 321 |
+
if len(output_block_list) > 1:
|
| 322 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
| 323 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
| 324 |
+
|
| 325 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
| 326 |
+
paths = renew_resnet_paths(resnets)
|
| 327 |
+
|
| 328 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 329 |
+
assign_to_checkpoint(
|
| 330 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
| 334 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
| 335 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
| 336 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
| 337 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
| 338 |
+
]
|
| 339 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
| 340 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
| 341 |
+
]
|
| 342 |
+
|
| 343 |
+
# Clear attentions as they have been attributed above.
|
| 344 |
+
if len(attentions) == 2:
|
| 345 |
+
attentions = []
|
| 346 |
+
|
| 347 |
+
if len(attentions):
|
| 348 |
+
paths = renew_attention_paths(attentions)
|
| 349 |
+
meta_path = {
|
| 350 |
+
"old": f"output_blocks.{i}.1",
|
| 351 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
| 352 |
+
}
|
| 353 |
+
assign_to_checkpoint(
|
| 354 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 355 |
+
)
|
| 356 |
+
else:
|
| 357 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
| 358 |
+
for path in resnet_0_paths:
|
| 359 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
| 360 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
| 361 |
+
|
| 362 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
| 363 |
+
|
| 364 |
+
for key in keys:
|
| 365 |
+
if "position_net" in key:
|
| 366 |
+
new_checkpoint[key] = unet_state_dict[key]
|
| 367 |
+
|
| 368 |
+
return new_checkpoint
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def create_vae_config(original_config, image_size: int):
|
| 372 |
+
vae_params = original_config["autoencoder"]["params"]["ddconfig"]
|
| 373 |
+
_ = original_config["autoencoder"]["params"]["embed_dim"]
|
| 374 |
+
|
| 375 |
+
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
| 376 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
| 377 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
| 378 |
+
|
| 379 |
+
config = {
|
| 380 |
+
"sample_size": image_size,
|
| 381 |
+
"in_channels": vae_params["in_channels"],
|
| 382 |
+
"out_channels": vae_params["out_ch"],
|
| 383 |
+
"down_block_types": tuple(down_block_types),
|
| 384 |
+
"up_block_types": tuple(up_block_types),
|
| 385 |
+
"block_out_channels": tuple(block_out_channels),
|
| 386 |
+
"latent_channels": vae_params["z_channels"],
|
| 387 |
+
"layers_per_block": vae_params["num_res_blocks"],
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
return config
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def create_unet_config(original_config, image_size: int, attention_type):
|
| 394 |
+
unet_params = original_config["model"]["params"]
|
| 395 |
+
vae_params = original_config["autoencoder"]["params"]["ddconfig"]
|
| 396 |
+
|
| 397 |
+
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
|
| 398 |
+
|
| 399 |
+
down_block_types = []
|
| 400 |
+
resolution = 1
|
| 401 |
+
for i in range(len(block_out_channels)):
|
| 402 |
+
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
|
| 403 |
+
down_block_types.append(block_type)
|
| 404 |
+
if i != len(block_out_channels) - 1:
|
| 405 |
+
resolution *= 2
|
| 406 |
+
|
| 407 |
+
up_block_types = []
|
| 408 |
+
for i in range(len(block_out_channels)):
|
| 409 |
+
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
|
| 410 |
+
up_block_types.append(block_type)
|
| 411 |
+
resolution //= 2
|
| 412 |
+
|
| 413 |
+
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
|
| 414 |
+
|
| 415 |
+
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
|
| 416 |
+
use_linear_projection = (
|
| 417 |
+
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
|
| 418 |
+
)
|
| 419 |
+
if use_linear_projection:
|
| 420 |
+
if head_dim is None:
|
| 421 |
+
head_dim = [5, 10, 20, 20]
|
| 422 |
+
|
| 423 |
+
config = {
|
| 424 |
+
"sample_size": image_size // vae_scale_factor,
|
| 425 |
+
"in_channels": unet_params["in_channels"],
|
| 426 |
+
"down_block_types": tuple(down_block_types),
|
| 427 |
+
"block_out_channels": tuple(block_out_channels),
|
| 428 |
+
"layers_per_block": unet_params["num_res_blocks"],
|
| 429 |
+
"cross_attention_dim": unet_params["context_dim"],
|
| 430 |
+
"attention_head_dim": head_dim,
|
| 431 |
+
"use_linear_projection": use_linear_projection,
|
| 432 |
+
"attention_type": attention_type,
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
return config
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def convert_gligen_to_diffusers(
|
| 439 |
+
checkpoint_path: str,
|
| 440 |
+
original_config_file: str,
|
| 441 |
+
attention_type: str,
|
| 442 |
+
image_size: int = 512,
|
| 443 |
+
extract_ema: bool = False,
|
| 444 |
+
num_in_channels: int = None,
|
| 445 |
+
device: str = None,
|
| 446 |
+
):
|
| 447 |
+
if device is None:
|
| 448 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 449 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 450 |
+
else:
|
| 451 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 452 |
+
|
| 453 |
+
if "global_step" in checkpoint:
|
| 454 |
+
checkpoint["global_step"]
|
| 455 |
+
else:
|
| 456 |
+
print("global_step key not found in model")
|
| 457 |
+
|
| 458 |
+
original_config = yaml.safe_load(original_config_file)
|
| 459 |
+
|
| 460 |
+
if num_in_channels is not None:
|
| 461 |
+
original_config["model"]["params"]["in_channels"] = num_in_channels
|
| 462 |
+
|
| 463 |
+
num_train_timesteps = original_config["diffusion"]["params"]["timesteps"]
|
| 464 |
+
beta_start = original_config["diffusion"]["params"]["linear_start"]
|
| 465 |
+
beta_end = original_config["diffusion"]["params"]["linear_end"]
|
| 466 |
+
|
| 467 |
+
scheduler = DDIMScheduler(
|
| 468 |
+
beta_end=beta_end,
|
| 469 |
+
beta_schedule="scaled_linear",
|
| 470 |
+
beta_start=beta_start,
|
| 471 |
+
num_train_timesteps=num_train_timesteps,
|
| 472 |
+
steps_offset=1,
|
| 473 |
+
clip_sample=False,
|
| 474 |
+
set_alpha_to_one=False,
|
| 475 |
+
prediction_type="epsilon",
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
# Convert the UNet2DConditionalModel model
|
| 479 |
+
unet_config = create_unet_config(original_config, image_size, attention_type)
|
| 480 |
+
unet = UNet2DConditionModel(**unet_config)
|
| 481 |
+
|
| 482 |
+
converted_unet_checkpoint = convert_gligen_unet_checkpoint(
|
| 483 |
+
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
unet.load_state_dict(converted_unet_checkpoint)
|
| 487 |
+
|
| 488 |
+
# Convert the VAE model
|
| 489 |
+
vae_config = create_vae_config(original_config, image_size)
|
| 490 |
+
converted_vae_checkpoint = convert_gligen_vae_checkpoint(checkpoint, vae_config)
|
| 491 |
+
|
| 492 |
+
vae = AutoencoderKL(**vae_config)
|
| 493 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
| 494 |
+
|
| 495 |
+
# Convert the text model
|
| 496 |
+
text_encoder = convert_open_clip_checkpoint(checkpoint)
|
| 497 |
+
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
| 498 |
+
|
| 499 |
+
if attention_type == "gated-text-image":
|
| 500 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
| 501 |
+
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
| 502 |
+
|
| 503 |
+
pipe = StableDiffusionGLIGENTextImagePipeline(
|
| 504 |
+
vae=vae,
|
| 505 |
+
text_encoder=text_encoder,
|
| 506 |
+
tokenizer=tokenizer,
|
| 507 |
+
image_encoder=image_encoder,
|
| 508 |
+
processor=processor,
|
| 509 |
+
unet=unet,
|
| 510 |
+
scheduler=scheduler,
|
| 511 |
+
safety_checker=None,
|
| 512 |
+
feature_extractor=None,
|
| 513 |
+
)
|
| 514 |
+
elif attention_type == "gated":
|
| 515 |
+
pipe = StableDiffusionGLIGENPipeline(
|
| 516 |
+
vae=vae,
|
| 517 |
+
text_encoder=text_encoder,
|
| 518 |
+
tokenizer=tokenizer,
|
| 519 |
+
unet=unet,
|
| 520 |
+
scheduler=scheduler,
|
| 521 |
+
safety_checker=None,
|
| 522 |
+
feature_extractor=None,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
return pipe
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
if __name__ == "__main__":
|
| 529 |
+
parser = argparse.ArgumentParser()
|
| 530 |
+
|
| 531 |
+
parser.add_argument(
|
| 532 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
| 533 |
+
)
|
| 534 |
+
parser.add_argument(
|
| 535 |
+
"--original_config_file",
|
| 536 |
+
default=None,
|
| 537 |
+
type=str,
|
| 538 |
+
required=True,
|
| 539 |
+
help="The YAML config file corresponding to the gligen architecture.",
|
| 540 |
+
)
|
| 541 |
+
parser.add_argument(
|
| 542 |
+
"--num_in_channels",
|
| 543 |
+
default=None,
|
| 544 |
+
type=int,
|
| 545 |
+
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
|
| 546 |
+
)
|
| 547 |
+
parser.add_argument(
|
| 548 |
+
"--extract_ema",
|
| 549 |
+
action="store_true",
|
| 550 |
+
help=(
|
| 551 |
+
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
|
| 552 |
+
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
|
| 553 |
+
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
|
| 554 |
+
),
|
| 555 |
+
)
|
| 556 |
+
parser.add_argument(
|
| 557 |
+
"--attention_type",
|
| 558 |
+
default=None,
|
| 559 |
+
type=str,
|
| 560 |
+
required=True,
|
| 561 |
+
help="Type of attention ex: gated or gated-text-image",
|
| 562 |
+
)
|
| 563 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 564 |
+
parser.add_argument("--device", type=str, help="Device to use.")
|
| 565 |
+
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
| 566 |
+
|
| 567 |
+
args = parser.parse_args()
|
| 568 |
+
|
| 569 |
+
pipe = convert_gligen_to_diffusers(
|
| 570 |
+
checkpoint_path=args.checkpoint_path,
|
| 571 |
+
original_config_file=args.original_config_file,
|
| 572 |
+
attention_type=args.attention_type,
|
| 573 |
+
extract_ema=args.extract_ema,
|
| 574 |
+
num_in_channels=args.num_in_channels,
|
| 575 |
+
device=args.device,
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
if args.half:
|
| 579 |
+
pipe.to(dtype=torch.float16)
|
| 580 |
+
|
| 581 |
+
pipe.save_pretrained(args.dump_path)
|
diffusers/scripts/convert_hunyuan_video_to_diffusers.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from accelerate import init_empty_weights
|
| 6 |
+
from transformers import (
|
| 7 |
+
AutoModel,
|
| 8 |
+
AutoTokenizer,
|
| 9 |
+
CLIPImageProcessor,
|
| 10 |
+
CLIPTextModel,
|
| 11 |
+
CLIPTokenizer,
|
| 12 |
+
LlavaForConditionalGeneration,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from diffusers import (
|
| 16 |
+
AutoencoderKLHunyuanVideo,
|
| 17 |
+
FlowMatchEulerDiscreteScheduler,
|
| 18 |
+
HunyuanVideoImageToVideoPipeline,
|
| 19 |
+
HunyuanVideoPipeline,
|
| 20 |
+
HunyuanVideoTransformer3DModel,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def remap_norm_scale_shift_(key, state_dict):
|
| 25 |
+
weight = state_dict.pop(key)
|
| 26 |
+
shift, scale = weight.chunk(2, dim=0)
|
| 27 |
+
new_weight = torch.cat([scale, shift], dim=0)
|
| 28 |
+
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def remap_txt_in_(key, state_dict):
|
| 32 |
+
def rename_key(key):
|
| 33 |
+
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
|
| 34 |
+
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
|
| 35 |
+
new_key = new_key.replace("txt_in", "context_embedder")
|
| 36 |
+
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
|
| 37 |
+
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
|
| 38 |
+
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
|
| 39 |
+
new_key = new_key.replace("mlp", "ff")
|
| 40 |
+
return new_key
|
| 41 |
+
|
| 42 |
+
if "self_attn_qkv" in key:
|
| 43 |
+
weight = state_dict.pop(key)
|
| 44 |
+
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
| 45 |
+
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
|
| 46 |
+
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
|
| 47 |
+
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
|
| 48 |
+
else:
|
| 49 |
+
state_dict[rename_key(key)] = state_dict.pop(key)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def remap_img_attn_qkv_(key, state_dict):
|
| 53 |
+
weight = state_dict.pop(key)
|
| 54 |
+
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
| 55 |
+
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
|
| 56 |
+
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
|
| 57 |
+
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def remap_txt_attn_qkv_(key, state_dict):
|
| 61 |
+
weight = state_dict.pop(key)
|
| 62 |
+
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
| 63 |
+
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
|
| 64 |
+
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
|
| 65 |
+
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def remap_single_transformer_blocks_(key, state_dict):
|
| 69 |
+
hidden_size = 3072
|
| 70 |
+
|
| 71 |
+
if "linear1.weight" in key:
|
| 72 |
+
linear1_weight = state_dict.pop(key)
|
| 73 |
+
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
|
| 74 |
+
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
|
| 75 |
+
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
|
| 76 |
+
state_dict[f"{new_key}.attn.to_q.weight"] = q
|
| 77 |
+
state_dict[f"{new_key}.attn.to_k.weight"] = k
|
| 78 |
+
state_dict[f"{new_key}.attn.to_v.weight"] = v
|
| 79 |
+
state_dict[f"{new_key}.proj_mlp.weight"] = mlp
|
| 80 |
+
|
| 81 |
+
elif "linear1.bias" in key:
|
| 82 |
+
linear1_bias = state_dict.pop(key)
|
| 83 |
+
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
|
| 84 |
+
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
|
| 85 |
+
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
|
| 86 |
+
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
|
| 87 |
+
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
|
| 88 |
+
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
|
| 89 |
+
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
|
| 90 |
+
|
| 91 |
+
else:
|
| 92 |
+
new_key = key.replace("single_blocks", "single_transformer_blocks")
|
| 93 |
+
new_key = new_key.replace("linear2", "proj_out")
|
| 94 |
+
new_key = new_key.replace("q_norm", "attn.norm_q")
|
| 95 |
+
new_key = new_key.replace("k_norm", "attn.norm_k")
|
| 96 |
+
state_dict[new_key] = state_dict.pop(key)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
TRANSFORMER_KEYS_RENAME_DICT = {
|
| 100 |
+
"img_in": "x_embedder",
|
| 101 |
+
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
|
| 102 |
+
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
|
| 103 |
+
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
|
| 104 |
+
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
|
| 105 |
+
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
|
| 106 |
+
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
|
| 107 |
+
"double_blocks": "transformer_blocks",
|
| 108 |
+
"img_attn_q_norm": "attn.norm_q",
|
| 109 |
+
"img_attn_k_norm": "attn.norm_k",
|
| 110 |
+
"img_attn_proj": "attn.to_out.0",
|
| 111 |
+
"txt_attn_q_norm": "attn.norm_added_q",
|
| 112 |
+
"txt_attn_k_norm": "attn.norm_added_k",
|
| 113 |
+
"txt_attn_proj": "attn.to_add_out",
|
| 114 |
+
"img_mod.linear": "norm1.linear",
|
| 115 |
+
"img_norm1": "norm1.norm",
|
| 116 |
+
"img_norm2": "norm2",
|
| 117 |
+
"img_mlp": "ff",
|
| 118 |
+
"txt_mod.linear": "norm1_context.linear",
|
| 119 |
+
"txt_norm1": "norm1.norm",
|
| 120 |
+
"txt_norm2": "norm2_context",
|
| 121 |
+
"txt_mlp": "ff_context",
|
| 122 |
+
"self_attn_proj": "attn.to_out.0",
|
| 123 |
+
"modulation.linear": "norm.linear",
|
| 124 |
+
"pre_norm": "norm.norm",
|
| 125 |
+
"final_layer.norm_final": "norm_out.norm",
|
| 126 |
+
"final_layer.linear": "proj_out",
|
| 127 |
+
"fc1": "net.0.proj",
|
| 128 |
+
"fc2": "net.2",
|
| 129 |
+
"input_embedder": "proj_in",
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
| 133 |
+
"txt_in": remap_txt_in_,
|
| 134 |
+
"img_attn_qkv": remap_img_attn_qkv_,
|
| 135 |
+
"txt_attn_qkv": remap_txt_attn_qkv_,
|
| 136 |
+
"single_blocks": remap_single_transformer_blocks_,
|
| 137 |
+
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
VAE_KEYS_RENAME_DICT = {}
|
| 141 |
+
|
| 142 |
+
VAE_SPECIAL_KEYS_REMAP = {}
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
TRANSFORMER_CONFIGS = {
|
| 146 |
+
"HYVideo-T/2-cfgdistill": {
|
| 147 |
+
"in_channels": 16,
|
| 148 |
+
"out_channels": 16,
|
| 149 |
+
"num_attention_heads": 24,
|
| 150 |
+
"attention_head_dim": 128,
|
| 151 |
+
"num_layers": 20,
|
| 152 |
+
"num_single_layers": 40,
|
| 153 |
+
"num_refiner_layers": 2,
|
| 154 |
+
"mlp_ratio": 4.0,
|
| 155 |
+
"patch_size": 2,
|
| 156 |
+
"patch_size_t": 1,
|
| 157 |
+
"qk_norm": "rms_norm",
|
| 158 |
+
"guidance_embeds": True,
|
| 159 |
+
"text_embed_dim": 4096,
|
| 160 |
+
"pooled_projection_dim": 768,
|
| 161 |
+
"rope_theta": 256.0,
|
| 162 |
+
"rope_axes_dim": (16, 56, 56),
|
| 163 |
+
"image_condition_type": None,
|
| 164 |
+
},
|
| 165 |
+
"HYVideo-T/2-I2V-33ch": {
|
| 166 |
+
"in_channels": 16 * 2 + 1,
|
| 167 |
+
"out_channels": 16,
|
| 168 |
+
"num_attention_heads": 24,
|
| 169 |
+
"attention_head_dim": 128,
|
| 170 |
+
"num_layers": 20,
|
| 171 |
+
"num_single_layers": 40,
|
| 172 |
+
"num_refiner_layers": 2,
|
| 173 |
+
"mlp_ratio": 4.0,
|
| 174 |
+
"patch_size": 2,
|
| 175 |
+
"patch_size_t": 1,
|
| 176 |
+
"qk_norm": "rms_norm",
|
| 177 |
+
"guidance_embeds": False,
|
| 178 |
+
"text_embed_dim": 4096,
|
| 179 |
+
"pooled_projection_dim": 768,
|
| 180 |
+
"rope_theta": 256.0,
|
| 181 |
+
"rope_axes_dim": (16, 56, 56),
|
| 182 |
+
"image_condition_type": "latent_concat",
|
| 183 |
+
},
|
| 184 |
+
"HYVideo-T/2-I2V-16ch": {
|
| 185 |
+
"in_channels": 16,
|
| 186 |
+
"out_channels": 16,
|
| 187 |
+
"num_attention_heads": 24,
|
| 188 |
+
"attention_head_dim": 128,
|
| 189 |
+
"num_layers": 20,
|
| 190 |
+
"num_single_layers": 40,
|
| 191 |
+
"num_refiner_layers": 2,
|
| 192 |
+
"mlp_ratio": 4.0,
|
| 193 |
+
"patch_size": 2,
|
| 194 |
+
"patch_size_t": 1,
|
| 195 |
+
"qk_norm": "rms_norm",
|
| 196 |
+
"guidance_embeds": True,
|
| 197 |
+
"text_embed_dim": 4096,
|
| 198 |
+
"pooled_projection_dim": 768,
|
| 199 |
+
"rope_theta": 256.0,
|
| 200 |
+
"rope_axes_dim": (16, 56, 56),
|
| 201 |
+
"image_condition_type": "token_replace",
|
| 202 |
+
},
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
| 207 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
| 211 |
+
state_dict = saved_dict
|
| 212 |
+
if "model" in saved_dict.keys():
|
| 213 |
+
state_dict = state_dict["model"]
|
| 214 |
+
if "module" in saved_dict.keys():
|
| 215 |
+
state_dict = state_dict["module"]
|
| 216 |
+
if "state_dict" in saved_dict.keys():
|
| 217 |
+
state_dict = state_dict["state_dict"]
|
| 218 |
+
return state_dict
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def convert_transformer(ckpt_path: str, transformer_type: str):
|
| 222 |
+
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
|
| 223 |
+
config = TRANSFORMER_CONFIGS[transformer_type]
|
| 224 |
+
|
| 225 |
+
with init_empty_weights():
|
| 226 |
+
transformer = HunyuanVideoTransformer3DModel(**config)
|
| 227 |
+
|
| 228 |
+
for key in list(original_state_dict.keys()):
|
| 229 |
+
new_key = key[:]
|
| 230 |
+
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
| 231 |
+
new_key = new_key.replace(replace_key, rename_key)
|
| 232 |
+
update_state_dict_(original_state_dict, key, new_key)
|
| 233 |
+
|
| 234 |
+
for key in list(original_state_dict.keys()):
|
| 235 |
+
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
| 236 |
+
if special_key not in key:
|
| 237 |
+
continue
|
| 238 |
+
handler_fn_inplace(key, original_state_dict)
|
| 239 |
+
|
| 240 |
+
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
| 241 |
+
return transformer
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def convert_vae(ckpt_path: str):
|
| 245 |
+
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
|
| 246 |
+
|
| 247 |
+
with init_empty_weights():
|
| 248 |
+
vae = AutoencoderKLHunyuanVideo()
|
| 249 |
+
|
| 250 |
+
for key in list(original_state_dict.keys()):
|
| 251 |
+
new_key = key[:]
|
| 252 |
+
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
| 253 |
+
new_key = new_key.replace(replace_key, rename_key)
|
| 254 |
+
update_state_dict_(original_state_dict, key, new_key)
|
| 255 |
+
|
| 256 |
+
for key in list(original_state_dict.keys()):
|
| 257 |
+
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
|
| 258 |
+
if special_key not in key:
|
| 259 |
+
continue
|
| 260 |
+
handler_fn_inplace(key, original_state_dict)
|
| 261 |
+
|
| 262 |
+
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
| 263 |
+
return vae
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def get_args():
|
| 267 |
+
parser = argparse.ArgumentParser()
|
| 268 |
+
parser.add_argument(
|
| 269 |
+
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
| 270 |
+
)
|
| 271 |
+
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
|
| 272 |
+
parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
|
| 273 |
+
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
|
| 274 |
+
parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
|
| 275 |
+
parser.add_argument("--save_pipeline", action="store_true")
|
| 276 |
+
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
| 277 |
+
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
| 278 |
+
parser.add_argument(
|
| 279 |
+
"--transformer_type", type=str, default="HYVideo-T/2-cfgdistill", choices=list(TRANSFORMER_CONFIGS.keys())
|
| 280 |
+
)
|
| 281 |
+
parser.add_argument("--flow_shift", type=float, default=7.0)
|
| 282 |
+
return parser.parse_args()
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
DTYPE_MAPPING = {
|
| 286 |
+
"fp32": torch.float32,
|
| 287 |
+
"fp16": torch.float16,
|
| 288 |
+
"bf16": torch.bfloat16,
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
if __name__ == "__main__":
|
| 293 |
+
args = get_args()
|
| 294 |
+
|
| 295 |
+
transformer = None
|
| 296 |
+
dtype = DTYPE_MAPPING[args.dtype]
|
| 297 |
+
|
| 298 |
+
if args.save_pipeline:
|
| 299 |
+
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
|
| 300 |
+
assert args.text_encoder_path is not None
|
| 301 |
+
assert args.tokenizer_path is not None
|
| 302 |
+
assert args.text_encoder_2_path is not None
|
| 303 |
+
|
| 304 |
+
if args.transformer_ckpt_path is not None:
|
| 305 |
+
transformer = convert_transformer(args.transformer_ckpt_path, args.transformer_type)
|
| 306 |
+
transformer = transformer.to(dtype=dtype)
|
| 307 |
+
if not args.save_pipeline:
|
| 308 |
+
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
| 309 |
+
|
| 310 |
+
if args.vae_ckpt_path is not None:
|
| 311 |
+
vae = convert_vae(args.vae_ckpt_path)
|
| 312 |
+
if not args.save_pipeline:
|
| 313 |
+
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
| 314 |
+
|
| 315 |
+
if args.save_pipeline:
|
| 316 |
+
if args.transformer_type == "HYVideo-T/2-cfgdistill":
|
| 317 |
+
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
|
| 318 |
+
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
|
| 319 |
+
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
|
| 320 |
+
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
|
| 321 |
+
scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
|
| 322 |
+
|
| 323 |
+
pipe = HunyuanVideoPipeline(
|
| 324 |
+
transformer=transformer,
|
| 325 |
+
vae=vae,
|
| 326 |
+
text_encoder=text_encoder,
|
| 327 |
+
tokenizer=tokenizer,
|
| 328 |
+
text_encoder_2=text_encoder_2,
|
| 329 |
+
tokenizer_2=tokenizer_2,
|
| 330 |
+
scheduler=scheduler,
|
| 331 |
+
)
|
| 332 |
+
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
| 333 |
+
else:
|
| 334 |
+
text_encoder = LlavaForConditionalGeneration.from_pretrained(
|
| 335 |
+
args.text_encoder_path, torch_dtype=torch.float16
|
| 336 |
+
)
|
| 337 |
+
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
|
| 338 |
+
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
|
| 339 |
+
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
|
| 340 |
+
scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
|
| 341 |
+
image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_path)
|
| 342 |
+
|
| 343 |
+
pipe = HunyuanVideoImageToVideoPipeline(
|
| 344 |
+
transformer=transformer,
|
| 345 |
+
vae=vae,
|
| 346 |
+
text_encoder=text_encoder,
|
| 347 |
+
tokenizer=tokenizer,
|
| 348 |
+
text_encoder_2=text_encoder_2,
|
| 349 |
+
tokenizer_2=tokenizer_2,
|
| 350 |
+
scheduler=scheduler,
|
| 351 |
+
image_processor=image_processor,
|
| 352 |
+
)
|
| 353 |
+
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
diffusers/scripts/convert_hunyuandit_to_diffusers.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from diffusers import HunyuanDiT2DModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main(args):
|
| 9 |
+
state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
|
| 10 |
+
|
| 11 |
+
if args.load_key != "none":
|
| 12 |
+
try:
|
| 13 |
+
state_dict = state_dict[args.load_key]
|
| 14 |
+
except KeyError:
|
| 15 |
+
raise KeyError(
|
| 16 |
+
f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
device = "cuda"
|
| 20 |
+
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
|
| 21 |
+
model_config["use_style_cond_and_image_meta_size"] = (
|
| 22 |
+
args.use_style_cond_and_image_meta_size
|
| 23 |
+
) ### version <= v1.1: True; version >= v1.2: False
|
| 24 |
+
|
| 25 |
+
# input_size -> sample_size, text_dim -> cross_attention_dim
|
| 26 |
+
for key in state_dict:
|
| 27 |
+
print("local:", key)
|
| 28 |
+
|
| 29 |
+
model = HunyuanDiT2DModel.from_config(model_config).to(device)
|
| 30 |
+
|
| 31 |
+
for key in model.state_dict():
|
| 32 |
+
print("diffusers:", key)
|
| 33 |
+
|
| 34 |
+
num_layers = 40
|
| 35 |
+
for i in range(num_layers):
|
| 36 |
+
# attn1
|
| 37 |
+
# Wkqv -> to_q, to_k, to_v
|
| 38 |
+
q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
|
| 39 |
+
q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
|
| 40 |
+
state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
|
| 41 |
+
state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
|
| 42 |
+
state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
|
| 43 |
+
state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
|
| 44 |
+
state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
|
| 45 |
+
state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
|
| 46 |
+
state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
|
| 47 |
+
state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
|
| 48 |
+
|
| 49 |
+
# q_norm, k_norm -> norm_q, norm_k
|
| 50 |
+
state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
|
| 51 |
+
state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
|
| 52 |
+
state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
|
| 53 |
+
state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
|
| 54 |
+
|
| 55 |
+
state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
|
| 56 |
+
state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
|
| 57 |
+
state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
|
| 58 |
+
state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
|
| 59 |
+
|
| 60 |
+
# out_proj -> to_out
|
| 61 |
+
state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
|
| 62 |
+
state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
|
| 63 |
+
state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
|
| 64 |
+
state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
|
| 65 |
+
|
| 66 |
+
# attn2
|
| 67 |
+
# kq_proj -> to_k, to_v
|
| 68 |
+
k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
|
| 69 |
+
k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
|
| 70 |
+
state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
|
| 71 |
+
state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
|
| 72 |
+
state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
|
| 73 |
+
state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
|
| 74 |
+
state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
|
| 75 |
+
state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
|
| 76 |
+
|
| 77 |
+
# q_proj -> to_q
|
| 78 |
+
state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
|
| 79 |
+
state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
|
| 80 |
+
state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
|
| 81 |
+
state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
|
| 82 |
+
|
| 83 |
+
# q_norm, k_norm -> norm_q, norm_k
|
| 84 |
+
state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
|
| 85 |
+
state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
|
| 86 |
+
state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
|
| 87 |
+
state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
|
| 88 |
+
|
| 89 |
+
state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
|
| 90 |
+
state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
|
| 91 |
+
state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
|
| 92 |
+
state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
|
| 93 |
+
|
| 94 |
+
# out_proj -> to_out
|
| 95 |
+
state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
|
| 96 |
+
state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
|
| 97 |
+
state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
|
| 98 |
+
state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
|
| 99 |
+
|
| 100 |
+
# switch norm 2 and norm 3
|
| 101 |
+
norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
|
| 102 |
+
norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
|
| 103 |
+
state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
|
| 104 |
+
state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
|
| 105 |
+
state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
|
| 106 |
+
state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
|
| 107 |
+
|
| 108 |
+
# norm1 -> norm1.norm
|
| 109 |
+
# default_modulation.1 -> norm1.linear
|
| 110 |
+
state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
|
| 111 |
+
state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
|
| 112 |
+
state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
|
| 113 |
+
state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
|
| 114 |
+
state_dict.pop(f"blocks.{i}.norm1.weight")
|
| 115 |
+
state_dict.pop(f"blocks.{i}.norm1.bias")
|
| 116 |
+
state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
|
| 117 |
+
state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
|
| 118 |
+
|
| 119 |
+
# mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
|
| 120 |
+
state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
|
| 121 |
+
state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
|
| 122 |
+
state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
|
| 123 |
+
state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
|
| 124 |
+
state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
|
| 125 |
+
state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
|
| 126 |
+
state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
|
| 127 |
+
state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
|
| 128 |
+
|
| 129 |
+
# pooler -> time_extra_emb
|
| 130 |
+
state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
|
| 131 |
+
state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
|
| 132 |
+
state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
|
| 133 |
+
state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
|
| 134 |
+
state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
|
| 135 |
+
state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
|
| 136 |
+
state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
|
| 137 |
+
state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
|
| 138 |
+
state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
|
| 139 |
+
state_dict.pop("pooler.k_proj.weight")
|
| 140 |
+
state_dict.pop("pooler.k_proj.bias")
|
| 141 |
+
state_dict.pop("pooler.q_proj.weight")
|
| 142 |
+
state_dict.pop("pooler.q_proj.bias")
|
| 143 |
+
state_dict.pop("pooler.v_proj.weight")
|
| 144 |
+
state_dict.pop("pooler.v_proj.bias")
|
| 145 |
+
state_dict.pop("pooler.c_proj.weight")
|
| 146 |
+
state_dict.pop("pooler.c_proj.bias")
|
| 147 |
+
state_dict.pop("pooler.positional_embedding")
|
| 148 |
+
|
| 149 |
+
# t_embedder -> time_embedding (`TimestepEmbedding`)
|
| 150 |
+
state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
|
| 151 |
+
state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
|
| 152 |
+
state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
|
| 153 |
+
state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
|
| 154 |
+
|
| 155 |
+
state_dict.pop("t_embedder.mlp.0.bias")
|
| 156 |
+
state_dict.pop("t_embedder.mlp.0.weight")
|
| 157 |
+
state_dict.pop("t_embedder.mlp.2.bias")
|
| 158 |
+
state_dict.pop("t_embedder.mlp.2.weight")
|
| 159 |
+
|
| 160 |
+
# x_embedder -> pos_embd (`PatchEmbed`)
|
| 161 |
+
state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
|
| 162 |
+
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
|
| 163 |
+
state_dict.pop("x_embedder.proj.weight")
|
| 164 |
+
state_dict.pop("x_embedder.proj.bias")
|
| 165 |
+
|
| 166 |
+
# mlp_t5 -> text_embedder
|
| 167 |
+
state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
|
| 168 |
+
state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
|
| 169 |
+
state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
|
| 170 |
+
state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
|
| 171 |
+
state_dict.pop("mlp_t5.0.bias")
|
| 172 |
+
state_dict.pop("mlp_t5.0.weight")
|
| 173 |
+
state_dict.pop("mlp_t5.2.bias")
|
| 174 |
+
state_dict.pop("mlp_t5.2.weight")
|
| 175 |
+
|
| 176 |
+
# extra_embedder -> extra_embedder
|
| 177 |
+
state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
|
| 178 |
+
state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
|
| 179 |
+
state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
|
| 180 |
+
state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
|
| 181 |
+
state_dict.pop("extra_embedder.0.bias")
|
| 182 |
+
state_dict.pop("extra_embedder.0.weight")
|
| 183 |
+
state_dict.pop("extra_embedder.2.bias")
|
| 184 |
+
state_dict.pop("extra_embedder.2.weight")
|
| 185 |
+
|
| 186 |
+
# model.final_adaLN_modulation.1 -> norm_out.linear
|
| 187 |
+
def swap_scale_shift(weight):
|
| 188 |
+
shift, scale = weight.chunk(2, dim=0)
|
| 189 |
+
new_weight = torch.cat([scale, shift], dim=0)
|
| 190 |
+
return new_weight
|
| 191 |
+
|
| 192 |
+
state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"])
|
| 193 |
+
state_dict["norm_out.linear.bias"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"])
|
| 194 |
+
state_dict.pop("final_layer.adaLN_modulation.1.weight")
|
| 195 |
+
state_dict.pop("final_layer.adaLN_modulation.1.bias")
|
| 196 |
+
|
| 197 |
+
# final_linear -> proj_out
|
| 198 |
+
state_dict["proj_out.weight"] = state_dict["final_layer.linear.weight"]
|
| 199 |
+
state_dict["proj_out.bias"] = state_dict["final_layer.linear.bias"]
|
| 200 |
+
state_dict.pop("final_layer.linear.weight")
|
| 201 |
+
state_dict.pop("final_layer.linear.bias")
|
| 202 |
+
|
| 203 |
+
# style_embedder
|
| 204 |
+
if model_config["use_style_cond_and_image_meta_size"]:
|
| 205 |
+
print(state_dict["style_embedder.weight"])
|
| 206 |
+
print(state_dict["style_embedder.weight"].shape)
|
| 207 |
+
state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
|
| 208 |
+
state_dict.pop("style_embedder.weight")
|
| 209 |
+
|
| 210 |
+
model.load_state_dict(state_dict)
|
| 211 |
+
|
| 212 |
+
from diffusers import HunyuanDiTPipeline
|
| 213 |
+
|
| 214 |
+
if args.use_style_cond_and_image_meta_size:
|
| 215 |
+
pipe = HunyuanDiTPipeline.from_pretrained(
|
| 216 |
+
"Tencent-Hunyuan/HunyuanDiT-Diffusers", transformer=model, torch_dtype=torch.float32
|
| 217 |
+
)
|
| 218 |
+
else:
|
| 219 |
+
pipe = HunyuanDiTPipeline.from_pretrained(
|
| 220 |
+
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", transformer=model, torch_dtype=torch.float32
|
| 221 |
+
)
|
| 222 |
+
pipe.to("cuda")
|
| 223 |
+
pipe.to(dtype=torch.float32)
|
| 224 |
+
|
| 225 |
+
if args.save:
|
| 226 |
+
pipe.save_pretrained(args.output_checkpoint_path)
|
| 227 |
+
|
| 228 |
+
# ### NOTE: HunyuanDiT supports both Chinese and English inputs
|
| 229 |
+
prompt = "一个宇航员在骑马"
|
| 230 |
+
# prompt = "An astronaut riding a horse"
|
| 231 |
+
generator = torch.Generator(device="cuda").manual_seed(0)
|
| 232 |
+
image = pipe(
|
| 233 |
+
height=1024, width=1024, prompt=prompt, generator=generator, num_inference_steps=25, guidance_scale=5.0
|
| 234 |
+
).images[0]
|
| 235 |
+
|
| 236 |
+
image.save("img.png")
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
if __name__ == "__main__":
|
| 240 |
+
parser = argparse.ArgumentParser()
|
| 241 |
+
|
| 242 |
+
parser.add_argument(
|
| 243 |
+
"--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
|
| 244 |
+
)
|
| 245 |
+
parser.add_argument(
|
| 246 |
+
"--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
|
| 247 |
+
)
|
| 248 |
+
parser.add_argument(
|
| 249 |
+
"--output_checkpoint_path",
|
| 250 |
+
default=None,
|
| 251 |
+
type=str,
|
| 252 |
+
required=False,
|
| 253 |
+
help="Path to the output converted diffusers pipeline.",
|
| 254 |
+
)
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
"--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
|
| 257 |
+
)
|
| 258 |
+
parser.add_argument(
|
| 259 |
+
"--use_style_cond_and_image_meta_size",
|
| 260 |
+
type=bool,
|
| 261 |
+
default=False,
|
| 262 |
+
help="version <= v1.1: True; version >= v1.2: False",
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
args = parser.parse_args()
|
| 266 |
+
main(args)
|
diffusers/scripts/convert_k_upscaler_to_diffusers.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import huggingface_hub
|
| 4 |
+
import k_diffusion as K
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from diffusers import UNet2DConditionModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
UPSCALER_REPO = "pcuenq/k-upscaler"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
|
| 14 |
+
rv = {
|
| 15 |
+
# norm1
|
| 16 |
+
f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
|
| 17 |
+
f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
|
| 18 |
+
# conv1
|
| 19 |
+
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
|
| 20 |
+
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
|
| 21 |
+
# norm2
|
| 22 |
+
f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
|
| 23 |
+
f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
|
| 24 |
+
# conv2
|
| 25 |
+
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
|
| 26 |
+
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
if resnet.conv_shortcut is not None:
|
| 30 |
+
rv.update(
|
| 31 |
+
{
|
| 32 |
+
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
|
| 33 |
+
}
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
return rv
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
|
| 40 |
+
weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
|
| 41 |
+
bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)
|
| 42 |
+
rv = {
|
| 43 |
+
# norm
|
| 44 |
+
f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
|
| 45 |
+
f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
|
| 46 |
+
# to_q
|
| 47 |
+
f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
|
| 48 |
+
f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
|
| 49 |
+
# to_k
|
| 50 |
+
f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
|
| 51 |
+
f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
|
| 52 |
+
# to_v
|
| 53 |
+
f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
|
| 54 |
+
f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
|
| 55 |
+
# to_out
|
| 56 |
+
f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
|
| 57 |
+
.squeeze(-1)
|
| 58 |
+
.squeeze(-1),
|
| 59 |
+
f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
return rv
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def cross_attn_to_diffusers_checkpoint(
|
| 66 |
+
checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
|
| 67 |
+
):
|
| 68 |
+
weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
|
| 69 |
+
bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)
|
| 70 |
+
|
| 71 |
+
rv = {
|
| 72 |
+
# norm2 (ada groupnorm)
|
| 73 |
+
f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
|
| 74 |
+
f"{attention_prefix}.norm_dec.mapper.weight"
|
| 75 |
+
],
|
| 76 |
+
f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
|
| 77 |
+
f"{attention_prefix}.norm_dec.mapper.bias"
|
| 78 |
+
],
|
| 79 |
+
# layernorm on encoder_hidden_state
|
| 80 |
+
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
|
| 81 |
+
f"{attention_prefix}.norm_enc.weight"
|
| 82 |
+
],
|
| 83 |
+
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
|
| 84 |
+
f"{attention_prefix}.norm_enc.bias"
|
| 85 |
+
],
|
| 86 |
+
# to_q
|
| 87 |
+
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
|
| 88 |
+
f"{attention_prefix}.q_proj.weight"
|
| 89 |
+
]
|
| 90 |
+
.squeeze(-1)
|
| 91 |
+
.squeeze(-1),
|
| 92 |
+
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
|
| 93 |
+
f"{attention_prefix}.q_proj.bias"
|
| 94 |
+
],
|
| 95 |
+
# to_k
|
| 96 |
+
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
|
| 97 |
+
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
|
| 98 |
+
# to_v
|
| 99 |
+
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
|
| 100 |
+
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
|
| 101 |
+
# to_out
|
| 102 |
+
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
|
| 103 |
+
f"{attention_prefix}.out_proj.weight"
|
| 104 |
+
]
|
| 105 |
+
.squeeze(-1)
|
| 106 |
+
.squeeze(-1),
|
| 107 |
+
f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
|
| 108 |
+
f"{attention_prefix}.out_proj.bias"
|
| 109 |
+
],
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
return rv
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
|
| 116 |
+
block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
|
| 117 |
+
block_prefix = f"{block_prefix}.{block_idx}"
|
| 118 |
+
|
| 119 |
+
diffusers_checkpoint = {}
|
| 120 |
+
|
| 121 |
+
if not hasattr(block, "attentions"):
|
| 122 |
+
n = 1 # resnet only
|
| 123 |
+
elif not block.attentions[0].add_self_attention:
|
| 124 |
+
n = 2 # resnet -> cross-attention
|
| 125 |
+
else:
|
| 126 |
+
n = 3 # resnet -> self-attention -> cross-attention)
|
| 127 |
+
|
| 128 |
+
for resnet_idx, resnet in enumerate(block.resnets):
|
| 129 |
+
# diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
|
| 130 |
+
diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
|
| 131 |
+
idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
|
| 132 |
+
resnet_prefix = f"{block_prefix}.{idx}" if block_type == "up" else f"{block_prefix}.{idx}"
|
| 133 |
+
|
| 134 |
+
diffusers_checkpoint.update(
|
| 135 |
+
resnet_to_diffusers_checkpoint(
|
| 136 |
+
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
| 137 |
+
)
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
if hasattr(block, "attentions"):
|
| 141 |
+
for attention_idx, attention in enumerate(block.attentions):
|
| 142 |
+
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
|
| 143 |
+
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
|
| 144 |
+
self_attention_prefix = f"{block_prefix}.{idx}"
|
| 145 |
+
cross_attention_prefix = f"{block_prefix}.{idx}"
|
| 146 |
+
cross_attention_index = 1 if not attention.add_self_attention else 2
|
| 147 |
+
idx = (
|
| 148 |
+
n * attention_idx + cross_attention_index
|
| 149 |
+
if block_type == "up"
|
| 150 |
+
else n * attention_idx + cross_attention_index + 1
|
| 151 |
+
)
|
| 152 |
+
cross_attention_prefix = f"{block_prefix}.{idx}"
|
| 153 |
+
|
| 154 |
+
diffusers_checkpoint.update(
|
| 155 |
+
cross_attn_to_diffusers_checkpoint(
|
| 156 |
+
checkpoint,
|
| 157 |
+
diffusers_attention_prefix=diffusers_attention_prefix,
|
| 158 |
+
diffusers_attention_index=2,
|
| 159 |
+
attention_prefix=cross_attention_prefix,
|
| 160 |
+
)
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
if attention.add_self_attention is True:
|
| 164 |
+
diffusers_checkpoint.update(
|
| 165 |
+
self_attn_to_diffusers_checkpoint(
|
| 166 |
+
checkpoint,
|
| 167 |
+
diffusers_attention_prefix=diffusers_attention_prefix,
|
| 168 |
+
attention_prefix=self_attention_prefix,
|
| 169 |
+
)
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
return diffusers_checkpoint
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def unet_to_diffusers_checkpoint(model, checkpoint):
|
| 176 |
+
diffusers_checkpoint = {}
|
| 177 |
+
|
| 178 |
+
# pre-processing
|
| 179 |
+
diffusers_checkpoint.update(
|
| 180 |
+
{
|
| 181 |
+
"conv_in.weight": checkpoint["inner_model.proj_in.weight"],
|
| 182 |
+
"conv_in.bias": checkpoint["inner_model.proj_in.bias"],
|
| 183 |
+
}
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# timestep and class embedding
|
| 187 |
+
diffusers_checkpoint.update(
|
| 188 |
+
{
|
| 189 |
+
"time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
|
| 190 |
+
"time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
|
| 191 |
+
"time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
|
| 192 |
+
"time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
|
| 193 |
+
"time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
|
| 194 |
+
"time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
|
| 195 |
+
}
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# down_blocks
|
| 199 |
+
for down_block_idx, down_block in enumerate(model.down_blocks):
|
| 200 |
+
diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))
|
| 201 |
+
|
| 202 |
+
# up_blocks
|
| 203 |
+
for up_block_idx, up_block in enumerate(model.up_blocks):
|
| 204 |
+
diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))
|
| 205 |
+
|
| 206 |
+
# post-processing
|
| 207 |
+
diffusers_checkpoint.update(
|
| 208 |
+
{
|
| 209 |
+
"conv_out.weight": checkpoint["inner_model.proj_out.weight"],
|
| 210 |
+
"conv_out.bias": checkpoint["inner_model.proj_out.bias"],
|
| 211 |
+
}
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
return diffusers_checkpoint
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def unet_model_from_original_config(original_config):
|
| 218 |
+
in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
|
| 219 |
+
out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)
|
| 220 |
+
|
| 221 |
+
block_out_channels = original_config["channels"]
|
| 222 |
+
|
| 223 |
+
assert len(set(original_config["depths"])) == 1, (
|
| 224 |
+
"UNet2DConditionModel currently do not support blocks with different number of layers"
|
| 225 |
+
)
|
| 226 |
+
layers_per_block = original_config["depths"][0]
|
| 227 |
+
|
| 228 |
+
class_labels_dim = original_config["mapping_cond_dim"]
|
| 229 |
+
cross_attention_dim = original_config["cross_cond_dim"]
|
| 230 |
+
|
| 231 |
+
attn1_types = []
|
| 232 |
+
attn2_types = []
|
| 233 |
+
for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
|
| 234 |
+
if s:
|
| 235 |
+
a1 = "self"
|
| 236 |
+
a2 = "cross" if c else None
|
| 237 |
+
elif c:
|
| 238 |
+
a1 = "cross"
|
| 239 |
+
a2 = None
|
| 240 |
+
else:
|
| 241 |
+
a1 = None
|
| 242 |
+
a2 = None
|
| 243 |
+
attn1_types.append(a1)
|
| 244 |
+
attn2_types.append(a2)
|
| 245 |
+
|
| 246 |
+
unet = UNet2DConditionModel(
|
| 247 |
+
in_channels=in_channels,
|
| 248 |
+
out_channels=out_channels,
|
| 249 |
+
down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
|
| 250 |
+
mid_block_type=None,
|
| 251 |
+
up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
|
| 252 |
+
block_out_channels=block_out_channels,
|
| 253 |
+
layers_per_block=layers_per_block,
|
| 254 |
+
act_fn="gelu",
|
| 255 |
+
norm_num_groups=None,
|
| 256 |
+
cross_attention_dim=cross_attention_dim,
|
| 257 |
+
attention_head_dim=64,
|
| 258 |
+
time_cond_proj_dim=class_labels_dim,
|
| 259 |
+
resnet_time_scale_shift="scale_shift",
|
| 260 |
+
time_embedding_type="fourier",
|
| 261 |
+
timestep_post_act="gelu",
|
| 262 |
+
conv_in_kernel=1,
|
| 263 |
+
conv_out_kernel=1,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
return unet
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def main(args):
|
| 270 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 271 |
+
|
| 272 |
+
orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
|
| 273 |
+
orig_weights_path = huggingface_hub.hf_hub_download(
|
| 274 |
+
UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
|
| 275 |
+
)
|
| 276 |
+
print(f"loading original model configuration from {orig_config_path}")
|
| 277 |
+
print(f"loading original model checkpoint from {orig_weights_path}")
|
| 278 |
+
|
| 279 |
+
print("converting to diffusers unet")
|
| 280 |
+
orig_config = K.config.load_config(open(orig_config_path))["model"]
|
| 281 |
+
model = unet_model_from_original_config(orig_config)
|
| 282 |
+
|
| 283 |
+
orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
|
| 284 |
+
converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)
|
| 285 |
+
|
| 286 |
+
model.load_state_dict(converted_checkpoint, strict=True)
|
| 287 |
+
model.save_pretrained(args.dump_path)
|
| 288 |
+
print(f"saving converted unet model in {args.dump_path}")
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
if __name__ == "__main__":
|
| 292 |
+
parser = argparse.ArgumentParser()
|
| 293 |
+
|
| 294 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 295 |
+
args = parser.parse_args()
|
| 296 |
+
|
| 297 |
+
main(args)
|
diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py
ADDED
|
@@ -0,0 +1,1159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import tempfile
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from accelerate import load_checkpoint_and_dispatch
|
| 6 |
+
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
| 7 |
+
|
| 8 |
+
from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
|
| 9 |
+
from diffusers.models.transformers.prior_transformer import PriorTransformer
|
| 10 |
+
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
|
| 11 |
+
from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
r"""
|
| 15 |
+
Example - From the diffusers root directory:
|
| 16 |
+
|
| 17 |
+
Download weights:
|
| 18 |
+
```sh
|
| 19 |
+
$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt
|
| 20 |
+
$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt
|
| 21 |
+
$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
|
| 22 |
+
$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
Convert the model:
|
| 26 |
+
```sh
|
| 27 |
+
$ python scripts/convert_kakao_brain_unclip_to_diffusers.py \
|
| 28 |
+
--decoder_checkpoint_path ./decoder-ckpt-step\=01000000-of-01000000.ckpt \
|
| 29 |
+
--super_res_unet_checkpoint_path ./improved-sr-ckpt-step\=1.2M.ckpt \
|
| 30 |
+
--prior_checkpoint_path ./prior-ckpt-step\=01000000-of-01000000.ckpt \
|
| 31 |
+
--clip_stat_path ./ViT-L-14_stats.th \
|
| 32 |
+
--dump_path <path where to save model>
|
| 33 |
+
```
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# prior
|
| 38 |
+
|
| 39 |
+
PRIOR_ORIGINAL_PREFIX = "model"
|
| 40 |
+
|
| 41 |
+
# Uses default arguments
|
| 42 |
+
PRIOR_CONFIG = {}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def prior_model_from_original_config():
|
| 46 |
+
model = PriorTransformer(**PRIOR_CONFIG)
|
| 47 |
+
|
| 48 |
+
return model
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint):
|
| 52 |
+
diffusers_checkpoint = {}
|
| 53 |
+
|
| 54 |
+
# <original>.time_embed.0 -> <diffusers>.time_embedding.linear_1
|
| 55 |
+
diffusers_checkpoint.update(
|
| 56 |
+
{
|
| 57 |
+
"time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"],
|
| 58 |
+
"time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"],
|
| 59 |
+
}
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# <original>.clip_img_proj -> <diffusers>.proj_in
|
| 63 |
+
diffusers_checkpoint.update(
|
| 64 |
+
{
|
| 65 |
+
"proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"],
|
| 66 |
+
"proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"],
|
| 67 |
+
}
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# <original>.text_emb_proj -> <diffusers>.embedding_proj
|
| 71 |
+
diffusers_checkpoint.update(
|
| 72 |
+
{
|
| 73 |
+
"embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"],
|
| 74 |
+
"embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"],
|
| 75 |
+
}
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# <original>.text_enc_proj -> <diffusers>.encoder_hidden_states_proj
|
| 79 |
+
diffusers_checkpoint.update(
|
| 80 |
+
{
|
| 81 |
+
"encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"],
|
| 82 |
+
"encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"],
|
| 83 |
+
}
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# <original>.positional_embedding -> <diffusers>.positional_embedding
|
| 87 |
+
diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]})
|
| 88 |
+
|
| 89 |
+
# <original>.prd_emb -> <diffusers>.prd_embedding
|
| 90 |
+
diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]})
|
| 91 |
+
|
| 92 |
+
# <original>.time_embed.2 -> <diffusers>.time_embedding.linear_2
|
| 93 |
+
diffusers_checkpoint.update(
|
| 94 |
+
{
|
| 95 |
+
"time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"],
|
| 96 |
+
"time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"],
|
| 97 |
+
}
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# <original>.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
|
| 101 |
+
for idx in range(len(model.transformer_blocks)):
|
| 102 |
+
diffusers_transformer_prefix = f"transformer_blocks.{idx}"
|
| 103 |
+
original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}"
|
| 104 |
+
|
| 105 |
+
# <original>.attn -> <diffusers>.attn1
|
| 106 |
+
diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
|
| 107 |
+
original_attention_prefix = f"{original_transformer_prefix}.attn"
|
| 108 |
+
diffusers_checkpoint.update(
|
| 109 |
+
prior_attention_to_diffusers(
|
| 110 |
+
checkpoint,
|
| 111 |
+
diffusers_attention_prefix=diffusers_attention_prefix,
|
| 112 |
+
original_attention_prefix=original_attention_prefix,
|
| 113 |
+
attention_head_dim=model.attention_head_dim,
|
| 114 |
+
)
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# <original>.mlp -> <diffusers>.ff
|
| 118 |
+
diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
|
| 119 |
+
original_ff_prefix = f"{original_transformer_prefix}.mlp"
|
| 120 |
+
diffusers_checkpoint.update(
|
| 121 |
+
prior_ff_to_diffusers(
|
| 122 |
+
checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
|
| 123 |
+
)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# <original>.ln_1 -> <diffusers>.norm1
|
| 127 |
+
diffusers_checkpoint.update(
|
| 128 |
+
{
|
| 129 |
+
f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
|
| 130 |
+
f"{original_transformer_prefix}.ln_1.weight"
|
| 131 |
+
],
|
| 132 |
+
f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
|
| 133 |
+
}
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# <original>.ln_2 -> <diffusers>.norm3
|
| 137 |
+
diffusers_checkpoint.update(
|
| 138 |
+
{
|
| 139 |
+
f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
|
| 140 |
+
f"{original_transformer_prefix}.ln_2.weight"
|
| 141 |
+
],
|
| 142 |
+
f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
|
| 143 |
+
}
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# <original>.final_ln -> <diffusers>.norm_out
|
| 147 |
+
diffusers_checkpoint.update(
|
| 148 |
+
{
|
| 149 |
+
"norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"],
|
| 150 |
+
"norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"],
|
| 151 |
+
}
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# <original>.out_proj -> <diffusers>.proj_to_clip_embeddings
|
| 155 |
+
diffusers_checkpoint.update(
|
| 156 |
+
{
|
| 157 |
+
"proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"],
|
| 158 |
+
"proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"],
|
| 159 |
+
}
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
# clip stats
|
| 163 |
+
clip_mean, clip_std = clip_stats_checkpoint
|
| 164 |
+
clip_mean = clip_mean[None, :]
|
| 165 |
+
clip_std = clip_std[None, :]
|
| 166 |
+
|
| 167 |
+
diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std})
|
| 168 |
+
|
| 169 |
+
return diffusers_checkpoint
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def prior_attention_to_diffusers(
|
| 173 |
+
checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
|
| 174 |
+
):
|
| 175 |
+
diffusers_checkpoint = {}
|
| 176 |
+
|
| 177 |
+
# <original>.c_qkv -> <diffusers>.{to_q, to_k, to_v}
|
| 178 |
+
[q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
|
| 179 |
+
weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
|
| 180 |
+
bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
|
| 181 |
+
split=3,
|
| 182 |
+
chunk_size=attention_head_dim,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
diffusers_checkpoint.update(
|
| 186 |
+
{
|
| 187 |
+
f"{diffusers_attention_prefix}.to_q.weight": q_weight,
|
| 188 |
+
f"{diffusers_attention_prefix}.to_q.bias": q_bias,
|
| 189 |
+
f"{diffusers_attention_prefix}.to_k.weight": k_weight,
|
| 190 |
+
f"{diffusers_attention_prefix}.to_k.bias": k_bias,
|
| 191 |
+
f"{diffusers_attention_prefix}.to_v.weight": v_weight,
|
| 192 |
+
f"{diffusers_attention_prefix}.to_v.bias": v_bias,
|
| 193 |
+
}
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# <original>.c_proj -> <diffusers>.to_out.0
|
| 197 |
+
diffusers_checkpoint.update(
|
| 198 |
+
{
|
| 199 |
+
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
|
| 200 |
+
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
|
| 201 |
+
}
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
return diffusers_checkpoint
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
|
| 208 |
+
diffusers_checkpoint = {
|
| 209 |
+
# <original>.c_fc -> <diffusers>.net.0.proj
|
| 210 |
+
f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
|
| 211 |
+
f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
|
| 212 |
+
# <original>.c_proj -> <diffusers>.net.2
|
| 213 |
+
f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
|
| 214 |
+
f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
return diffusers_checkpoint
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# done prior
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# decoder
|
| 224 |
+
|
| 225 |
+
DECODER_ORIGINAL_PREFIX = "model"
|
| 226 |
+
|
| 227 |
+
# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can
|
| 228 |
+
# update then.
|
| 229 |
+
DECODER_CONFIG = {
|
| 230 |
+
"sample_size": 64,
|
| 231 |
+
"layers_per_block": 3,
|
| 232 |
+
"down_block_types": (
|
| 233 |
+
"ResnetDownsampleBlock2D",
|
| 234 |
+
"SimpleCrossAttnDownBlock2D",
|
| 235 |
+
"SimpleCrossAttnDownBlock2D",
|
| 236 |
+
"SimpleCrossAttnDownBlock2D",
|
| 237 |
+
),
|
| 238 |
+
"up_block_types": (
|
| 239 |
+
"SimpleCrossAttnUpBlock2D",
|
| 240 |
+
"SimpleCrossAttnUpBlock2D",
|
| 241 |
+
"SimpleCrossAttnUpBlock2D",
|
| 242 |
+
"ResnetUpsampleBlock2D",
|
| 243 |
+
),
|
| 244 |
+
"mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
|
| 245 |
+
"block_out_channels": (320, 640, 960, 1280),
|
| 246 |
+
"in_channels": 3,
|
| 247 |
+
"out_channels": 6,
|
| 248 |
+
"cross_attention_dim": 1536,
|
| 249 |
+
"class_embed_type": "identity",
|
| 250 |
+
"attention_head_dim": 64,
|
| 251 |
+
"resnet_time_scale_shift": "scale_shift",
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def decoder_model_from_original_config():
|
| 256 |
+
model = UNet2DConditionModel(**DECODER_CONFIG)
|
| 257 |
+
|
| 258 |
+
return model
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def decoder_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
| 262 |
+
diffusers_checkpoint = {}
|
| 263 |
+
|
| 264 |
+
original_unet_prefix = DECODER_ORIGINAL_PREFIX
|
| 265 |
+
num_head_channels = DECODER_CONFIG["attention_head_dim"]
|
| 266 |
+
|
| 267 |
+
diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
|
| 268 |
+
diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))
|
| 269 |
+
|
| 270 |
+
# <original>.input_blocks -> <diffusers>.down_blocks
|
| 271 |
+
|
| 272 |
+
original_down_block_idx = 1
|
| 273 |
+
|
| 274 |
+
for diffusers_down_block_idx in range(len(model.down_blocks)):
|
| 275 |
+
checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
|
| 276 |
+
model,
|
| 277 |
+
checkpoint,
|
| 278 |
+
diffusers_down_block_idx=diffusers_down_block_idx,
|
| 279 |
+
original_down_block_idx=original_down_block_idx,
|
| 280 |
+
original_unet_prefix=original_unet_prefix,
|
| 281 |
+
num_head_channels=num_head_channels,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
original_down_block_idx += num_original_down_blocks
|
| 285 |
+
|
| 286 |
+
diffusers_checkpoint.update(checkpoint_update)
|
| 287 |
+
|
| 288 |
+
# done <original>.input_blocks -> <diffusers>.down_blocks
|
| 289 |
+
|
| 290 |
+
diffusers_checkpoint.update(
|
| 291 |
+
unet_midblock_to_diffusers_checkpoint(
|
| 292 |
+
model,
|
| 293 |
+
checkpoint,
|
| 294 |
+
original_unet_prefix=original_unet_prefix,
|
| 295 |
+
num_head_channels=num_head_channels,
|
| 296 |
+
)
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# <original>.output_blocks -> <diffusers>.up_blocks
|
| 300 |
+
|
| 301 |
+
original_up_block_idx = 0
|
| 302 |
+
|
| 303 |
+
for diffusers_up_block_idx in range(len(model.up_blocks)):
|
| 304 |
+
checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
|
| 305 |
+
model,
|
| 306 |
+
checkpoint,
|
| 307 |
+
diffusers_up_block_idx=diffusers_up_block_idx,
|
| 308 |
+
original_up_block_idx=original_up_block_idx,
|
| 309 |
+
original_unet_prefix=original_unet_prefix,
|
| 310 |
+
num_head_channels=num_head_channels,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
original_up_block_idx += num_original_up_blocks
|
| 314 |
+
|
| 315 |
+
diffusers_checkpoint.update(checkpoint_update)
|
| 316 |
+
|
| 317 |
+
# done <original>.output_blocks -> <diffusers>.up_blocks
|
| 318 |
+
|
| 319 |
+
diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
|
| 320 |
+
diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))
|
| 321 |
+
|
| 322 |
+
return diffusers_checkpoint
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# done decoder
|
| 326 |
+
|
| 327 |
+
# text proj
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def text_proj_from_original_config():
|
| 331 |
+
# From the conditional unet constructor where the dimension of the projected time embeddings is
|
| 332 |
+
# constructed
|
| 333 |
+
time_embed_dim = DECODER_CONFIG["block_out_channels"][0] * 4
|
| 334 |
+
|
| 335 |
+
cross_attention_dim = DECODER_CONFIG["cross_attention_dim"]
|
| 336 |
+
|
| 337 |
+
model = UnCLIPTextProjModel(time_embed_dim=time_embed_dim, cross_attention_dim=cross_attention_dim)
|
| 338 |
+
|
| 339 |
+
return model
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# Note that the input checkpoint is the original decoder checkpoint
|
| 343 |
+
def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
|
| 344 |
+
diffusers_checkpoint = {
|
| 345 |
+
# <original>.text_seq_proj.0 -> <diffusers>.encoder_hidden_states_proj
|
| 346 |
+
"encoder_hidden_states_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.weight"],
|
| 347 |
+
"encoder_hidden_states_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.0.bias"],
|
| 348 |
+
# <original>.text_seq_proj.1 -> <diffusers>.text_encoder_hidden_states_norm
|
| 349 |
+
"text_encoder_hidden_states_norm.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.weight"],
|
| 350 |
+
"text_encoder_hidden_states_norm.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_seq_proj.1.bias"],
|
| 351 |
+
# <original>.clip_tok_proj -> <diffusers>.clip_extra_context_tokens_proj
|
| 352 |
+
"clip_extra_context_tokens_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.weight"],
|
| 353 |
+
"clip_extra_context_tokens_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.clip_tok_proj.bias"],
|
| 354 |
+
# <original>.text_feat_proj -> <diffusers>.embedding_proj
|
| 355 |
+
"embedding_proj.weight": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.weight"],
|
| 356 |
+
"embedding_proj.bias": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.text_feat_proj.bias"],
|
| 357 |
+
# <original>.cf_param -> <diffusers>.learned_classifier_free_guidance_embeddings
|
| 358 |
+
"learned_classifier_free_guidance_embeddings": checkpoint[f"{DECODER_ORIGINAL_PREFIX}.cf_param"],
|
| 359 |
+
# <original>.clip_emb -> <diffusers>.clip_image_embeddings_project_to_time_embeddings
|
| 360 |
+
"clip_image_embeddings_project_to_time_embeddings.weight": checkpoint[
|
| 361 |
+
f"{DECODER_ORIGINAL_PREFIX}.clip_emb.weight"
|
| 362 |
+
],
|
| 363 |
+
"clip_image_embeddings_project_to_time_embeddings.bias": checkpoint[
|
| 364 |
+
f"{DECODER_ORIGINAL_PREFIX}.clip_emb.bias"
|
| 365 |
+
],
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
return diffusers_checkpoint
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
# done text proj
|
| 372 |
+
|
| 373 |
+
# super res unet first steps
|
| 374 |
+
|
| 375 |
+
SUPER_RES_UNET_FIRST_STEPS_PREFIX = "model_first_steps"
|
| 376 |
+
|
| 377 |
+
SUPER_RES_UNET_FIRST_STEPS_CONFIG = {
|
| 378 |
+
"sample_size": 256,
|
| 379 |
+
"layers_per_block": 3,
|
| 380 |
+
"down_block_types": (
|
| 381 |
+
"ResnetDownsampleBlock2D",
|
| 382 |
+
"ResnetDownsampleBlock2D",
|
| 383 |
+
"ResnetDownsampleBlock2D",
|
| 384 |
+
"ResnetDownsampleBlock2D",
|
| 385 |
+
),
|
| 386 |
+
"up_block_types": (
|
| 387 |
+
"ResnetUpsampleBlock2D",
|
| 388 |
+
"ResnetUpsampleBlock2D",
|
| 389 |
+
"ResnetUpsampleBlock2D",
|
| 390 |
+
"ResnetUpsampleBlock2D",
|
| 391 |
+
),
|
| 392 |
+
"block_out_channels": (320, 640, 960, 1280),
|
| 393 |
+
"in_channels": 6,
|
| 394 |
+
"out_channels": 3,
|
| 395 |
+
"add_attention": False,
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def super_res_unet_first_steps_model_from_original_config():
|
| 400 |
+
model = UNet2DModel(**SUPER_RES_UNET_FIRST_STEPS_CONFIG)
|
| 401 |
+
|
| 402 |
+
return model
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
| 406 |
+
diffusers_checkpoint = {}
|
| 407 |
+
|
| 408 |
+
original_unet_prefix = SUPER_RES_UNET_FIRST_STEPS_PREFIX
|
| 409 |
+
|
| 410 |
+
diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
|
| 411 |
+
diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))
|
| 412 |
+
|
| 413 |
+
# <original>.input_blocks -> <diffusers>.down_blocks
|
| 414 |
+
|
| 415 |
+
original_down_block_idx = 1
|
| 416 |
+
|
| 417 |
+
for diffusers_down_block_idx in range(len(model.down_blocks)):
|
| 418 |
+
checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
|
| 419 |
+
model,
|
| 420 |
+
checkpoint,
|
| 421 |
+
diffusers_down_block_idx=diffusers_down_block_idx,
|
| 422 |
+
original_down_block_idx=original_down_block_idx,
|
| 423 |
+
original_unet_prefix=original_unet_prefix,
|
| 424 |
+
num_head_channels=None,
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
original_down_block_idx += num_original_down_blocks
|
| 428 |
+
|
| 429 |
+
diffusers_checkpoint.update(checkpoint_update)
|
| 430 |
+
|
| 431 |
+
diffusers_checkpoint.update(
|
| 432 |
+
unet_midblock_to_diffusers_checkpoint(
|
| 433 |
+
model,
|
| 434 |
+
checkpoint,
|
| 435 |
+
original_unet_prefix=original_unet_prefix,
|
| 436 |
+
num_head_channels=None,
|
| 437 |
+
)
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
# <original>.output_blocks -> <diffusers>.up_blocks
|
| 441 |
+
|
| 442 |
+
original_up_block_idx = 0
|
| 443 |
+
|
| 444 |
+
for diffusers_up_block_idx in range(len(model.up_blocks)):
|
| 445 |
+
checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
|
| 446 |
+
model,
|
| 447 |
+
checkpoint,
|
| 448 |
+
diffusers_up_block_idx=diffusers_up_block_idx,
|
| 449 |
+
original_up_block_idx=original_up_block_idx,
|
| 450 |
+
original_unet_prefix=original_unet_prefix,
|
| 451 |
+
num_head_channels=None,
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
original_up_block_idx += num_original_up_blocks
|
| 455 |
+
|
| 456 |
+
diffusers_checkpoint.update(checkpoint_update)
|
| 457 |
+
|
| 458 |
+
# done <original>.output_blocks -> <diffusers>.up_blocks
|
| 459 |
+
|
| 460 |
+
diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
|
| 461 |
+
diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))
|
| 462 |
+
|
| 463 |
+
return diffusers_checkpoint
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
# done super res unet first steps
|
| 467 |
+
|
| 468 |
+
# super res unet last step
|
| 469 |
+
|
| 470 |
+
SUPER_RES_UNET_LAST_STEP_PREFIX = "model_last_step"
|
| 471 |
+
|
| 472 |
+
SUPER_RES_UNET_LAST_STEP_CONFIG = {
|
| 473 |
+
"sample_size": 256,
|
| 474 |
+
"layers_per_block": 3,
|
| 475 |
+
"down_block_types": (
|
| 476 |
+
"ResnetDownsampleBlock2D",
|
| 477 |
+
"ResnetDownsampleBlock2D",
|
| 478 |
+
"ResnetDownsampleBlock2D",
|
| 479 |
+
"ResnetDownsampleBlock2D",
|
| 480 |
+
),
|
| 481 |
+
"up_block_types": (
|
| 482 |
+
"ResnetUpsampleBlock2D",
|
| 483 |
+
"ResnetUpsampleBlock2D",
|
| 484 |
+
"ResnetUpsampleBlock2D",
|
| 485 |
+
"ResnetUpsampleBlock2D",
|
| 486 |
+
),
|
| 487 |
+
"block_out_channels": (320, 640, 960, 1280),
|
| 488 |
+
"in_channels": 6,
|
| 489 |
+
"out_channels": 3,
|
| 490 |
+
"add_attention": False,
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def super_res_unet_last_step_model_from_original_config():
|
| 495 |
+
model = UNet2DModel(**SUPER_RES_UNET_LAST_STEP_CONFIG)
|
| 496 |
+
|
| 497 |
+
return model
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
| 501 |
+
diffusers_checkpoint = {}
|
| 502 |
+
|
| 503 |
+
original_unet_prefix = SUPER_RES_UNET_LAST_STEP_PREFIX
|
| 504 |
+
|
| 505 |
+
diffusers_checkpoint.update(unet_time_embeddings(checkpoint, original_unet_prefix))
|
| 506 |
+
diffusers_checkpoint.update(unet_conv_in(checkpoint, original_unet_prefix))
|
| 507 |
+
|
| 508 |
+
# <original>.input_blocks -> <diffusers>.down_blocks
|
| 509 |
+
|
| 510 |
+
original_down_block_idx = 1
|
| 511 |
+
|
| 512 |
+
for diffusers_down_block_idx in range(len(model.down_blocks)):
|
| 513 |
+
checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
|
| 514 |
+
model,
|
| 515 |
+
checkpoint,
|
| 516 |
+
diffusers_down_block_idx=diffusers_down_block_idx,
|
| 517 |
+
original_down_block_idx=original_down_block_idx,
|
| 518 |
+
original_unet_prefix=original_unet_prefix,
|
| 519 |
+
num_head_channels=None,
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
original_down_block_idx += num_original_down_blocks
|
| 523 |
+
|
| 524 |
+
diffusers_checkpoint.update(checkpoint_update)
|
| 525 |
+
|
| 526 |
+
diffusers_checkpoint.update(
|
| 527 |
+
unet_midblock_to_diffusers_checkpoint(
|
| 528 |
+
model,
|
| 529 |
+
checkpoint,
|
| 530 |
+
original_unet_prefix=original_unet_prefix,
|
| 531 |
+
num_head_channels=None,
|
| 532 |
+
)
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
# <original>.output_blocks -> <diffusers>.up_blocks
|
| 536 |
+
|
| 537 |
+
original_up_block_idx = 0
|
| 538 |
+
|
| 539 |
+
for diffusers_up_block_idx in range(len(model.up_blocks)):
|
| 540 |
+
checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
|
| 541 |
+
model,
|
| 542 |
+
checkpoint,
|
| 543 |
+
diffusers_up_block_idx=diffusers_up_block_idx,
|
| 544 |
+
original_up_block_idx=original_up_block_idx,
|
| 545 |
+
original_unet_prefix=original_unet_prefix,
|
| 546 |
+
num_head_channels=None,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
original_up_block_idx += num_original_up_blocks
|
| 550 |
+
|
| 551 |
+
diffusers_checkpoint.update(checkpoint_update)
|
| 552 |
+
|
| 553 |
+
# done <original>.output_blocks -> <diffusers>.up_blocks
|
| 554 |
+
|
| 555 |
+
diffusers_checkpoint.update(unet_conv_norm_out(checkpoint, original_unet_prefix))
|
| 556 |
+
diffusers_checkpoint.update(unet_conv_out(checkpoint, original_unet_prefix))
|
| 557 |
+
|
| 558 |
+
return diffusers_checkpoint
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
# done super res unet last step
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
# unet utils
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
# <original>.time_embed -> <diffusers>.time_embedding
|
| 568 |
+
def unet_time_embeddings(checkpoint, original_unet_prefix):
|
| 569 |
+
diffusers_checkpoint = {}
|
| 570 |
+
|
| 571 |
+
diffusers_checkpoint.update(
|
| 572 |
+
{
|
| 573 |
+
"time_embedding.linear_1.weight": checkpoint[f"{original_unet_prefix}.time_embed.0.weight"],
|
| 574 |
+
"time_embedding.linear_1.bias": checkpoint[f"{original_unet_prefix}.time_embed.0.bias"],
|
| 575 |
+
"time_embedding.linear_2.weight": checkpoint[f"{original_unet_prefix}.time_embed.2.weight"],
|
| 576 |
+
"time_embedding.linear_2.bias": checkpoint[f"{original_unet_prefix}.time_embed.2.bias"],
|
| 577 |
+
}
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
return diffusers_checkpoint
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
# <original>.input_blocks.0 -> <diffusers>.conv_in
|
| 584 |
+
def unet_conv_in(checkpoint, original_unet_prefix):
|
| 585 |
+
diffusers_checkpoint = {}
|
| 586 |
+
|
| 587 |
+
diffusers_checkpoint.update(
|
| 588 |
+
{
|
| 589 |
+
"conv_in.weight": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.weight"],
|
| 590 |
+
"conv_in.bias": checkpoint[f"{original_unet_prefix}.input_blocks.0.0.bias"],
|
| 591 |
+
}
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
return diffusers_checkpoint
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
# <original>.out.0 -> <diffusers>.conv_norm_out
|
| 598 |
+
def unet_conv_norm_out(checkpoint, original_unet_prefix):
|
| 599 |
+
diffusers_checkpoint = {}
|
| 600 |
+
|
| 601 |
+
diffusers_checkpoint.update(
|
| 602 |
+
{
|
| 603 |
+
"conv_norm_out.weight": checkpoint[f"{original_unet_prefix}.out.0.weight"],
|
| 604 |
+
"conv_norm_out.bias": checkpoint[f"{original_unet_prefix}.out.0.bias"],
|
| 605 |
+
}
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
return diffusers_checkpoint
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# <original>.out.2 -> <diffusers>.conv_out
|
| 612 |
+
def unet_conv_out(checkpoint, original_unet_prefix):
|
| 613 |
+
diffusers_checkpoint = {}
|
| 614 |
+
|
| 615 |
+
diffusers_checkpoint.update(
|
| 616 |
+
{
|
| 617 |
+
"conv_out.weight": checkpoint[f"{original_unet_prefix}.out.2.weight"],
|
| 618 |
+
"conv_out.bias": checkpoint[f"{original_unet_prefix}.out.2.bias"],
|
| 619 |
+
}
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
return diffusers_checkpoint
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
# <original>.input_blocks -> <diffusers>.down_blocks
|
| 626 |
+
def unet_downblock_to_diffusers_checkpoint(
|
| 627 |
+
model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, original_unet_prefix, num_head_channels
|
| 628 |
+
):
|
| 629 |
+
diffusers_checkpoint = {}
|
| 630 |
+
|
| 631 |
+
diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets"
|
| 632 |
+
original_down_block_prefix = f"{original_unet_prefix}.input_blocks"
|
| 633 |
+
|
| 634 |
+
down_block = model.down_blocks[diffusers_down_block_idx]
|
| 635 |
+
|
| 636 |
+
num_resnets = len(down_block.resnets)
|
| 637 |
+
|
| 638 |
+
if down_block.downsamplers is None:
|
| 639 |
+
downsampler = False
|
| 640 |
+
else:
|
| 641 |
+
assert len(down_block.downsamplers) == 1
|
| 642 |
+
downsampler = True
|
| 643 |
+
# The downsample block is also a resnet
|
| 644 |
+
num_resnets += 1
|
| 645 |
+
|
| 646 |
+
for resnet_idx_inc in range(num_resnets):
|
| 647 |
+
full_resnet_prefix = f"{original_down_block_prefix}.{original_down_block_idx + resnet_idx_inc}.0"
|
| 648 |
+
|
| 649 |
+
if downsampler and resnet_idx_inc == num_resnets - 1:
|
| 650 |
+
# this is a downsample block
|
| 651 |
+
full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0"
|
| 652 |
+
else:
|
| 653 |
+
# this is a regular resnet block
|
| 654 |
+
full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"
|
| 655 |
+
|
| 656 |
+
diffusers_checkpoint.update(
|
| 657 |
+
resnet_to_diffusers_checkpoint(
|
| 658 |
+
checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
|
| 659 |
+
)
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
if hasattr(down_block, "attentions"):
|
| 663 |
+
num_attentions = len(down_block.attentions)
|
| 664 |
+
diffusers_attention_prefix = f"down_blocks.{diffusers_down_block_idx}.attentions"
|
| 665 |
+
|
| 666 |
+
for attention_idx_inc in range(num_attentions):
|
| 667 |
+
full_attention_prefix = f"{original_down_block_prefix}.{original_down_block_idx + attention_idx_inc}.1"
|
| 668 |
+
full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"
|
| 669 |
+
|
| 670 |
+
diffusers_checkpoint.update(
|
| 671 |
+
attention_to_diffusers_checkpoint(
|
| 672 |
+
checkpoint,
|
| 673 |
+
attention_prefix=full_attention_prefix,
|
| 674 |
+
diffusers_attention_prefix=full_diffusers_attention_prefix,
|
| 675 |
+
num_head_channels=num_head_channels,
|
| 676 |
+
)
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
num_original_down_blocks = num_resnets
|
| 680 |
+
|
| 681 |
+
return diffusers_checkpoint, num_original_down_blocks
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
# <original>.middle_block -> <diffusers>.mid_block
|
| 685 |
+
def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, original_unet_prefix, num_head_channels):
|
| 686 |
+
diffusers_checkpoint = {}
|
| 687 |
+
|
| 688 |
+
# block 0
|
| 689 |
+
|
| 690 |
+
original_block_idx = 0
|
| 691 |
+
|
| 692 |
+
diffusers_checkpoint.update(
|
| 693 |
+
resnet_to_diffusers_checkpoint(
|
| 694 |
+
checkpoint,
|
| 695 |
+
diffusers_resnet_prefix="mid_block.resnets.0",
|
| 696 |
+
resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
|
| 697 |
+
)
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
original_block_idx += 1
|
| 701 |
+
|
| 702 |
+
# optional block 1
|
| 703 |
+
|
| 704 |
+
if hasattr(model.mid_block, "attentions") and model.mid_block.attentions[0] is not None:
|
| 705 |
+
diffusers_checkpoint.update(
|
| 706 |
+
attention_to_diffusers_checkpoint(
|
| 707 |
+
checkpoint,
|
| 708 |
+
diffusers_attention_prefix="mid_block.attentions.0",
|
| 709 |
+
attention_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
|
| 710 |
+
num_head_channels=num_head_channels,
|
| 711 |
+
)
|
| 712 |
+
)
|
| 713 |
+
original_block_idx += 1
|
| 714 |
+
|
| 715 |
+
# block 1 or block 2
|
| 716 |
+
|
| 717 |
+
diffusers_checkpoint.update(
|
| 718 |
+
resnet_to_diffusers_checkpoint(
|
| 719 |
+
checkpoint,
|
| 720 |
+
diffusers_resnet_prefix="mid_block.resnets.1",
|
| 721 |
+
resnet_prefix=f"{original_unet_prefix}.middle_block.{original_block_idx}",
|
| 722 |
+
)
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
return diffusers_checkpoint
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
# <original>.output_blocks -> <diffusers>.up_blocks
|
| 729 |
+
def unet_upblock_to_diffusers_checkpoint(
|
| 730 |
+
model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, original_unet_prefix, num_head_channels
|
| 731 |
+
):
|
| 732 |
+
diffusers_checkpoint = {}
|
| 733 |
+
|
| 734 |
+
diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.resnets"
|
| 735 |
+
original_up_block_prefix = f"{original_unet_prefix}.output_blocks"
|
| 736 |
+
|
| 737 |
+
up_block = model.up_blocks[diffusers_up_block_idx]
|
| 738 |
+
|
| 739 |
+
num_resnets = len(up_block.resnets)
|
| 740 |
+
|
| 741 |
+
if up_block.upsamplers is None:
|
| 742 |
+
upsampler = False
|
| 743 |
+
else:
|
| 744 |
+
assert len(up_block.upsamplers) == 1
|
| 745 |
+
upsampler = True
|
| 746 |
+
# The upsample block is also a resnet
|
| 747 |
+
num_resnets += 1
|
| 748 |
+
|
| 749 |
+
has_attentions = hasattr(up_block, "attentions")
|
| 750 |
+
|
| 751 |
+
for resnet_idx_inc in range(num_resnets):
|
| 752 |
+
if upsampler and resnet_idx_inc == num_resnets - 1:
|
| 753 |
+
# this is an upsample block
|
| 754 |
+
if has_attentions:
|
| 755 |
+
# There is a middle attention block that we skip
|
| 756 |
+
original_resnet_block_idx = 2
|
| 757 |
+
else:
|
| 758 |
+
original_resnet_block_idx = 1
|
| 759 |
+
|
| 760 |
+
# we add the `minus 1` because the last two resnets are stuck together in the same output block
|
| 761 |
+
full_resnet_prefix = (
|
| 762 |
+
f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc - 1}.{original_resnet_block_idx}"
|
| 763 |
+
)
|
| 764 |
+
|
| 765 |
+
full_diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.upsamplers.0"
|
| 766 |
+
else:
|
| 767 |
+
# this is a regular resnet block
|
| 768 |
+
full_resnet_prefix = f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc}.0"
|
| 769 |
+
full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"
|
| 770 |
+
|
| 771 |
+
diffusers_checkpoint.update(
|
| 772 |
+
resnet_to_diffusers_checkpoint(
|
| 773 |
+
checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
|
| 774 |
+
)
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
if has_attentions:
|
| 778 |
+
num_attentions = len(up_block.attentions)
|
| 779 |
+
diffusers_attention_prefix = f"up_blocks.{diffusers_up_block_idx}.attentions"
|
| 780 |
+
|
| 781 |
+
for attention_idx_inc in range(num_attentions):
|
| 782 |
+
full_attention_prefix = f"{original_up_block_prefix}.{original_up_block_idx + attention_idx_inc}.1"
|
| 783 |
+
full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"
|
| 784 |
+
|
| 785 |
+
diffusers_checkpoint.update(
|
| 786 |
+
attention_to_diffusers_checkpoint(
|
| 787 |
+
checkpoint,
|
| 788 |
+
attention_prefix=full_attention_prefix,
|
| 789 |
+
diffusers_attention_prefix=full_diffusers_attention_prefix,
|
| 790 |
+
num_head_channels=num_head_channels,
|
| 791 |
+
)
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
num_original_down_blocks = num_resnets - 1 if upsampler else num_resnets
|
| 795 |
+
|
| 796 |
+
return diffusers_checkpoint, num_original_down_blocks
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
|
| 800 |
+
diffusers_checkpoint = {
|
| 801 |
+
f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"],
|
| 802 |
+
f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"],
|
| 803 |
+
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"],
|
| 804 |
+
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"],
|
| 805 |
+
f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"],
|
| 806 |
+
f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"],
|
| 807 |
+
f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"],
|
| 808 |
+
f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"],
|
| 809 |
+
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"],
|
| 810 |
+
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"],
|
| 811 |
+
}
|
| 812 |
+
|
| 813 |
+
skip_connection_prefix = f"{resnet_prefix}.skip_connection"
|
| 814 |
+
|
| 815 |
+
if f"{skip_connection_prefix}.weight" in checkpoint:
|
| 816 |
+
diffusers_checkpoint.update(
|
| 817 |
+
{
|
| 818 |
+
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{skip_connection_prefix}.weight"],
|
| 819 |
+
f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{skip_connection_prefix}.bias"],
|
| 820 |
+
}
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
return diffusers_checkpoint
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels):
|
| 827 |
+
diffusers_checkpoint = {}
|
| 828 |
+
|
| 829 |
+
# <original>.norm -> <diffusers>.group_norm
|
| 830 |
+
diffusers_checkpoint.update(
|
| 831 |
+
{
|
| 832 |
+
f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
|
| 833 |
+
f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
|
| 834 |
+
}
|
| 835 |
+
)
|
| 836 |
+
|
| 837 |
+
# <original>.qkv -> <diffusers>.{query, key, value}
|
| 838 |
+
[q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
|
| 839 |
+
weight=checkpoint[f"{attention_prefix}.qkv.weight"][:, :, 0],
|
| 840 |
+
bias=checkpoint[f"{attention_prefix}.qkv.bias"],
|
| 841 |
+
split=3,
|
| 842 |
+
chunk_size=num_head_channels,
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
diffusers_checkpoint.update(
|
| 846 |
+
{
|
| 847 |
+
f"{diffusers_attention_prefix}.to_q.weight": q_weight,
|
| 848 |
+
f"{diffusers_attention_prefix}.to_q.bias": q_bias,
|
| 849 |
+
f"{diffusers_attention_prefix}.to_k.weight": k_weight,
|
| 850 |
+
f"{diffusers_attention_prefix}.to_k.bias": k_bias,
|
| 851 |
+
f"{diffusers_attention_prefix}.to_v.weight": v_weight,
|
| 852 |
+
f"{diffusers_attention_prefix}.to_v.bias": v_bias,
|
| 853 |
+
}
|
| 854 |
+
)
|
| 855 |
+
|
| 856 |
+
# <original>.encoder_kv -> <diffusers>.{context_key, context_value}
|
| 857 |
+
[encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
|
| 858 |
+
weight=checkpoint[f"{attention_prefix}.encoder_kv.weight"][:, :, 0],
|
| 859 |
+
bias=checkpoint[f"{attention_prefix}.encoder_kv.bias"],
|
| 860 |
+
split=2,
|
| 861 |
+
chunk_size=num_head_channels,
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
diffusers_checkpoint.update(
|
| 865 |
+
{
|
| 866 |
+
f"{diffusers_attention_prefix}.add_k_proj.weight": encoder_k_weight,
|
| 867 |
+
f"{diffusers_attention_prefix}.add_k_proj.bias": encoder_k_bias,
|
| 868 |
+
f"{diffusers_attention_prefix}.add_v_proj.weight": encoder_v_weight,
|
| 869 |
+
f"{diffusers_attention_prefix}.add_v_proj.bias": encoder_v_bias,
|
| 870 |
+
}
|
| 871 |
+
)
|
| 872 |
+
|
| 873 |
+
# <original>.proj_out (1d conv) -> <diffusers>.proj_attn (linear)
|
| 874 |
+
diffusers_checkpoint.update(
|
| 875 |
+
{
|
| 876 |
+
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
|
| 877 |
+
:, :, 0
|
| 878 |
+
],
|
| 879 |
+
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
|
| 880 |
+
}
|
| 881 |
+
)
|
| 882 |
+
|
| 883 |
+
return diffusers_checkpoint
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
|
| 887 |
+
def split_attentions(*, weight, bias, split, chunk_size):
|
| 888 |
+
weights = [None] * split
|
| 889 |
+
biases = [None] * split
|
| 890 |
+
|
| 891 |
+
weights_biases_idx = 0
|
| 892 |
+
|
| 893 |
+
for starting_row_index in range(0, weight.shape[0], chunk_size):
|
| 894 |
+
row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)
|
| 895 |
+
|
| 896 |
+
weight_rows = weight[row_indices, :]
|
| 897 |
+
bias_rows = bias[row_indices]
|
| 898 |
+
|
| 899 |
+
if weights[weights_biases_idx] is None:
|
| 900 |
+
assert weights[weights_biases_idx] is None
|
| 901 |
+
weights[weights_biases_idx] = weight_rows
|
| 902 |
+
biases[weights_biases_idx] = bias_rows
|
| 903 |
+
else:
|
| 904 |
+
assert weights[weights_biases_idx] is not None
|
| 905 |
+
weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
|
| 906 |
+
biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])
|
| 907 |
+
|
| 908 |
+
weights_biases_idx = (weights_biases_idx + 1) % split
|
| 909 |
+
|
| 910 |
+
return weights, biases
|
| 911 |
+
|
| 912 |
+
|
| 913 |
+
# done unet utils
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
# Driver functions
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
def text_encoder():
|
| 920 |
+
print("loading CLIP text encoder")
|
| 921 |
+
|
| 922 |
+
clip_name = "openai/clip-vit-large-patch14"
|
| 923 |
+
|
| 924 |
+
# sets pad_value to 0
|
| 925 |
+
pad_token = "!"
|
| 926 |
+
|
| 927 |
+
tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")
|
| 928 |
+
|
| 929 |
+
assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0
|
| 930 |
+
|
| 931 |
+
text_encoder_model = CLIPTextModelWithProjection.from_pretrained(
|
| 932 |
+
clip_name,
|
| 933 |
+
# `CLIPTextModel` does not support device_map="auto"
|
| 934 |
+
# device_map="auto"
|
| 935 |
+
)
|
| 936 |
+
|
| 937 |
+
print("done loading CLIP text encoder")
|
| 938 |
+
|
| 939 |
+
return text_encoder_model, tokenizer_model
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
def prior(*, args, checkpoint_map_location):
|
| 943 |
+
print("loading prior")
|
| 944 |
+
|
| 945 |
+
prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)
|
| 946 |
+
prior_checkpoint = prior_checkpoint["state_dict"]
|
| 947 |
+
|
| 948 |
+
clip_stats_checkpoint = torch.load(args.clip_stat_path, map_location=checkpoint_map_location)
|
| 949 |
+
|
| 950 |
+
prior_model = prior_model_from_original_config()
|
| 951 |
+
|
| 952 |
+
prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(
|
| 953 |
+
prior_model, prior_checkpoint, clip_stats_checkpoint
|
| 954 |
+
)
|
| 955 |
+
|
| 956 |
+
del prior_checkpoint
|
| 957 |
+
del clip_stats_checkpoint
|
| 958 |
+
|
| 959 |
+
load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True)
|
| 960 |
+
|
| 961 |
+
print("done loading prior")
|
| 962 |
+
|
| 963 |
+
return prior_model
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
def decoder(*, args, checkpoint_map_location):
|
| 967 |
+
print("loading decoder")
|
| 968 |
+
|
| 969 |
+
decoder_checkpoint = torch.load(args.decoder_checkpoint_path, map_location=checkpoint_map_location)
|
| 970 |
+
decoder_checkpoint = decoder_checkpoint["state_dict"]
|
| 971 |
+
|
| 972 |
+
decoder_model = decoder_model_from_original_config()
|
| 973 |
+
|
| 974 |
+
decoder_diffusers_checkpoint = decoder_original_checkpoint_to_diffusers_checkpoint(
|
| 975 |
+
decoder_model, decoder_checkpoint
|
| 976 |
+
)
|
| 977 |
+
|
| 978 |
+
# text proj interlude
|
| 979 |
+
|
| 980 |
+
# The original decoder implementation includes a set of parameters that are used
|
| 981 |
+
# for creating the `encoder_hidden_states` which are what the U-net is conditioned
|
| 982 |
+
# on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
|
| 983 |
+
# the parameters into the UnCLIPTextProjModel class
|
| 984 |
+
text_proj_model = text_proj_from_original_config()
|
| 985 |
+
|
| 986 |
+
text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(decoder_checkpoint)
|
| 987 |
+
|
| 988 |
+
load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
|
| 989 |
+
|
| 990 |
+
# done text proj interlude
|
| 991 |
+
|
| 992 |
+
del decoder_checkpoint
|
| 993 |
+
|
| 994 |
+
load_checkpoint_to_model(decoder_diffusers_checkpoint, decoder_model, strict=True)
|
| 995 |
+
|
| 996 |
+
print("done loading decoder")
|
| 997 |
+
|
| 998 |
+
return decoder_model, text_proj_model
|
| 999 |
+
|
| 1000 |
+
|
| 1001 |
+
def super_res_unet(*, args, checkpoint_map_location):
|
| 1002 |
+
print("loading super resolution unet")
|
| 1003 |
+
|
| 1004 |
+
super_res_checkpoint = torch.load(args.super_res_unet_checkpoint_path, map_location=checkpoint_map_location)
|
| 1005 |
+
super_res_checkpoint = super_res_checkpoint["state_dict"]
|
| 1006 |
+
|
| 1007 |
+
# model_first_steps
|
| 1008 |
+
|
| 1009 |
+
super_res_first_model = super_res_unet_first_steps_model_from_original_config()
|
| 1010 |
+
|
| 1011 |
+
super_res_first_steps_checkpoint = super_res_unet_first_steps_original_checkpoint_to_diffusers_checkpoint(
|
| 1012 |
+
super_res_first_model, super_res_checkpoint
|
| 1013 |
+
)
|
| 1014 |
+
|
| 1015 |
+
# model_last_step
|
| 1016 |
+
super_res_last_model = super_res_unet_last_step_model_from_original_config()
|
| 1017 |
+
|
| 1018 |
+
super_res_last_step_checkpoint = super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(
|
| 1019 |
+
super_res_last_model, super_res_checkpoint
|
| 1020 |
+
)
|
| 1021 |
+
|
| 1022 |
+
del super_res_checkpoint
|
| 1023 |
+
|
| 1024 |
+
load_checkpoint_to_model(super_res_first_steps_checkpoint, super_res_first_model, strict=True)
|
| 1025 |
+
|
| 1026 |
+
load_checkpoint_to_model(super_res_last_step_checkpoint, super_res_last_model, strict=True)
|
| 1027 |
+
|
| 1028 |
+
print("done loading super resolution unet")
|
| 1029 |
+
|
| 1030 |
+
return super_res_first_model, super_res_last_model
|
| 1031 |
+
|
| 1032 |
+
|
| 1033 |
+
def load_checkpoint_to_model(checkpoint, model, strict=False):
|
| 1034 |
+
with tempfile.NamedTemporaryFile() as file:
|
| 1035 |
+
torch.save(checkpoint, file.name)
|
| 1036 |
+
del checkpoint
|
| 1037 |
+
if strict:
|
| 1038 |
+
model.load_state_dict(torch.load(file.name), strict=True)
|
| 1039 |
+
else:
|
| 1040 |
+
load_checkpoint_and_dispatch(model, file.name, device_map="auto")
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
if __name__ == "__main__":
|
| 1044 |
+
parser = argparse.ArgumentParser()
|
| 1045 |
+
|
| 1046 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 1047 |
+
|
| 1048 |
+
parser.add_argument(
|
| 1049 |
+
"--prior_checkpoint_path",
|
| 1050 |
+
default=None,
|
| 1051 |
+
type=str,
|
| 1052 |
+
required=True,
|
| 1053 |
+
help="Path to the prior checkpoint to convert.",
|
| 1054 |
+
)
|
| 1055 |
+
|
| 1056 |
+
parser.add_argument(
|
| 1057 |
+
"--decoder_checkpoint_path",
|
| 1058 |
+
default=None,
|
| 1059 |
+
type=str,
|
| 1060 |
+
required=True,
|
| 1061 |
+
help="Path to the decoder checkpoint to convert.",
|
| 1062 |
+
)
|
| 1063 |
+
|
| 1064 |
+
parser.add_argument(
|
| 1065 |
+
"--super_res_unet_checkpoint_path",
|
| 1066 |
+
default=None,
|
| 1067 |
+
type=str,
|
| 1068 |
+
required=True,
|
| 1069 |
+
help="Path to the super resolution checkpoint to convert.",
|
| 1070 |
+
)
|
| 1071 |
+
|
| 1072 |
+
parser.add_argument(
|
| 1073 |
+
"--clip_stat_path", default=None, type=str, required=True, help="Path to the clip stats checkpoint to convert."
|
| 1074 |
+
)
|
| 1075 |
+
|
| 1076 |
+
parser.add_argument(
|
| 1077 |
+
"--checkpoint_load_device",
|
| 1078 |
+
default="cpu",
|
| 1079 |
+
type=str,
|
| 1080 |
+
required=False,
|
| 1081 |
+
help="The device passed to `map_location` when loading checkpoints.",
|
| 1082 |
+
)
|
| 1083 |
+
|
| 1084 |
+
parser.add_argument(
|
| 1085 |
+
"--debug",
|
| 1086 |
+
default=None,
|
| 1087 |
+
type=str,
|
| 1088 |
+
required=False,
|
| 1089 |
+
help="Only run a specific stage of the convert script. Used for debugging",
|
| 1090 |
+
)
|
| 1091 |
+
|
| 1092 |
+
args = parser.parse_args()
|
| 1093 |
+
|
| 1094 |
+
print(f"loading checkpoints to {args.checkpoint_load_device}")
|
| 1095 |
+
|
| 1096 |
+
checkpoint_map_location = torch.device(args.checkpoint_load_device)
|
| 1097 |
+
|
| 1098 |
+
if args.debug is not None:
|
| 1099 |
+
print(f"debug: only executing {args.debug}")
|
| 1100 |
+
|
| 1101 |
+
if args.debug is None:
|
| 1102 |
+
text_encoder_model, tokenizer_model = text_encoder()
|
| 1103 |
+
|
| 1104 |
+
prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
|
| 1105 |
+
|
| 1106 |
+
decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location)
|
| 1107 |
+
|
| 1108 |
+
super_res_first_model, super_res_last_model = super_res_unet(
|
| 1109 |
+
args=args, checkpoint_map_location=checkpoint_map_location
|
| 1110 |
+
)
|
| 1111 |
+
|
| 1112 |
+
prior_scheduler = UnCLIPScheduler(
|
| 1113 |
+
variance_type="fixed_small_log",
|
| 1114 |
+
prediction_type="sample",
|
| 1115 |
+
num_train_timesteps=1000,
|
| 1116 |
+
clip_sample_range=5.0,
|
| 1117 |
+
)
|
| 1118 |
+
|
| 1119 |
+
decoder_scheduler = UnCLIPScheduler(
|
| 1120 |
+
variance_type="learned_range",
|
| 1121 |
+
prediction_type="epsilon",
|
| 1122 |
+
num_train_timesteps=1000,
|
| 1123 |
+
)
|
| 1124 |
+
|
| 1125 |
+
super_res_scheduler = UnCLIPScheduler(
|
| 1126 |
+
variance_type="fixed_small_log",
|
| 1127 |
+
prediction_type="epsilon",
|
| 1128 |
+
num_train_timesteps=1000,
|
| 1129 |
+
)
|
| 1130 |
+
|
| 1131 |
+
print(f"saving Kakao Brain unCLIP to {args.dump_path}")
|
| 1132 |
+
|
| 1133 |
+
pipe = UnCLIPPipeline(
|
| 1134 |
+
prior=prior_model,
|
| 1135 |
+
decoder=decoder_model,
|
| 1136 |
+
text_proj=text_proj_model,
|
| 1137 |
+
tokenizer=tokenizer_model,
|
| 1138 |
+
text_encoder=text_encoder_model,
|
| 1139 |
+
super_res_first=super_res_first_model,
|
| 1140 |
+
super_res_last=super_res_last_model,
|
| 1141 |
+
prior_scheduler=prior_scheduler,
|
| 1142 |
+
decoder_scheduler=decoder_scheduler,
|
| 1143 |
+
super_res_scheduler=super_res_scheduler,
|
| 1144 |
+
)
|
| 1145 |
+
pipe.save_pretrained(args.dump_path)
|
| 1146 |
+
|
| 1147 |
+
print("done writing Kakao Brain unCLIP")
|
| 1148 |
+
elif args.debug == "text_encoder":
|
| 1149 |
+
text_encoder_model, tokenizer_model = text_encoder()
|
| 1150 |
+
elif args.debug == "prior":
|
| 1151 |
+
prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
|
| 1152 |
+
elif args.debug == "decoder":
|
| 1153 |
+
decoder_model, text_proj_model = decoder(args=args, checkpoint_map_location=checkpoint_map_location)
|
| 1154 |
+
elif args.debug == "super_res_unet":
|
| 1155 |
+
super_res_first_model, super_res_last_model = super_res_unet(
|
| 1156 |
+
args=args, checkpoint_map_location=checkpoint_map_location
|
| 1157 |
+
)
|
| 1158 |
+
else:
|
| 1159 |
+
raise ValueError(f"unknown debug value : {args.debug}")
|
diffusers/scripts/convert_kandinsky3_unet.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import fnmatch
|
| 4 |
+
|
| 5 |
+
from safetensors.torch import load_file
|
| 6 |
+
|
| 7 |
+
from diffusers import Kandinsky3UNet
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
MAPPING = {
|
| 11 |
+
"to_time_embed.1": "time_embedding.linear_1",
|
| 12 |
+
"to_time_embed.3": "time_embedding.linear_2",
|
| 13 |
+
"in_layer": "conv_in",
|
| 14 |
+
"out_layer.0": "conv_norm_out",
|
| 15 |
+
"out_layer.2": "conv_out",
|
| 16 |
+
"down_samples": "down_blocks",
|
| 17 |
+
"up_samples": "up_blocks",
|
| 18 |
+
"projection_lin": "encoder_hid_proj.projection_linear",
|
| 19 |
+
"projection_ln": "encoder_hid_proj.projection_norm",
|
| 20 |
+
"feature_pooling": "add_time_condition",
|
| 21 |
+
"to_query": "to_q",
|
| 22 |
+
"to_key": "to_k",
|
| 23 |
+
"to_value": "to_v",
|
| 24 |
+
"output_layer": "to_out.0",
|
| 25 |
+
"self_attention_block": "attentions.0",
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
DYNAMIC_MAP = {
|
| 29 |
+
"resnet_attn_blocks.*.0": "resnets_in.*",
|
| 30 |
+
"resnet_attn_blocks.*.1": ("attentions.*", 1),
|
| 31 |
+
"resnet_attn_blocks.*.2": "resnets_out.*",
|
| 32 |
+
}
|
| 33 |
+
# MAPPING = {}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def convert_state_dict(unet_state_dict):
|
| 37 |
+
"""
|
| 38 |
+
Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model.
|
| 39 |
+
Args:
|
| 40 |
+
unet_model (torch.nn.Module): The original U-Net model.
|
| 41 |
+
unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
OrderedDict: The converted state dictionary.
|
| 45 |
+
"""
|
| 46 |
+
# Example of renaming logic (this will vary based on your model's architecture)
|
| 47 |
+
converted_state_dict = {}
|
| 48 |
+
for key in unet_state_dict:
|
| 49 |
+
new_key = key
|
| 50 |
+
for pattern, new_pattern in MAPPING.items():
|
| 51 |
+
new_key = new_key.replace(pattern, new_pattern)
|
| 52 |
+
|
| 53 |
+
for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items():
|
| 54 |
+
has_matched = False
|
| 55 |
+
if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched:
|
| 56 |
+
star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])
|
| 57 |
+
|
| 58 |
+
if isinstance(dyn_new_pattern, tuple):
|
| 59 |
+
new_star = star + dyn_new_pattern[-1]
|
| 60 |
+
dyn_new_pattern = dyn_new_pattern[0]
|
| 61 |
+
else:
|
| 62 |
+
new_star = star
|
| 63 |
+
|
| 64 |
+
pattern = dyn_pattern.replace("*", str(star))
|
| 65 |
+
new_pattern = dyn_new_pattern.replace("*", str(new_star))
|
| 66 |
+
|
| 67 |
+
new_key = new_key.replace(pattern, new_pattern)
|
| 68 |
+
has_matched = True
|
| 69 |
+
|
| 70 |
+
converted_state_dict[new_key] = unet_state_dict[key]
|
| 71 |
+
|
| 72 |
+
return converted_state_dict
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def main(model_path, output_path):
|
| 76 |
+
# Load your original U-Net model
|
| 77 |
+
unet_state_dict = load_file(model_path)
|
| 78 |
+
|
| 79 |
+
# Initialize your Kandinsky3UNet model
|
| 80 |
+
config = {}
|
| 81 |
+
|
| 82 |
+
# Convert the state dict
|
| 83 |
+
converted_state_dict = convert_state_dict(unet_state_dict)
|
| 84 |
+
|
| 85 |
+
unet = Kandinsky3UNet(config)
|
| 86 |
+
unet.load_state_dict(converted_state_dict)
|
| 87 |
+
|
| 88 |
+
unet.save_pretrained(output_path)
|
| 89 |
+
print(f"Converted model saved to {output_path}")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format")
|
| 94 |
+
parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model")
|
| 95 |
+
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
|
| 96 |
+
|
| 97 |
+
args = parser.parse_args()
|
| 98 |
+
main(args.model_path, args.output_path)
|
diffusers/scripts/convert_kandinsky_to_diffusers.py
ADDED
|
@@ -0,0 +1,1411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import tempfile
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from accelerate import load_checkpoint_and_dispatch
|
| 7 |
+
|
| 8 |
+
from diffusers import UNet2DConditionModel
|
| 9 |
+
from diffusers.models.transformers.prior_transformer import PriorTransformer
|
| 10 |
+
from diffusers.models.vq_model import VQModel
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
"""
|
| 14 |
+
Example - From the diffusers root directory:
|
| 15 |
+
|
| 16 |
+
Download weights:
|
| 17 |
+
```sh
|
| 18 |
+
$ wget https://huggingface.co/ai-forever/Kandinsky_2.1/blob/main/prior_fp16.ckpt
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
Convert the model:
|
| 22 |
+
```sh
|
| 23 |
+
python scripts/convert_kandinsky_to_diffusers.py \
|
| 24 |
+
--prior_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/prior_fp16.ckpt \
|
| 25 |
+
--clip_stat_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/ViT-L-14_stats.th \
|
| 26 |
+
--text2img_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/decoder_fp16.ckpt \
|
| 27 |
+
--inpaint_text2img_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/inpainting_fp16.ckpt \
|
| 28 |
+
--movq_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/movq_final.ckpt \
|
| 29 |
+
--dump_path /home/yiyi_huggingface_co/dump \
|
| 30 |
+
--debug decoder
|
| 31 |
+
```
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# prior
|
| 36 |
+
|
| 37 |
+
PRIOR_ORIGINAL_PREFIX = "model"
|
| 38 |
+
|
| 39 |
+
# Uses default arguments
|
| 40 |
+
PRIOR_CONFIG = {}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def prior_model_from_original_config():
|
| 44 |
+
model = PriorTransformer(**PRIOR_CONFIG)
|
| 45 |
+
|
| 46 |
+
return model
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint):
|
| 50 |
+
diffusers_checkpoint = {}
|
| 51 |
+
|
| 52 |
+
# <original>.time_embed.0 -> <diffusers>.time_embedding.linear_1
|
| 53 |
+
diffusers_checkpoint.update(
|
| 54 |
+
{
|
| 55 |
+
"time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"],
|
| 56 |
+
"time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"],
|
| 57 |
+
}
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# <original>.clip_img_proj -> <diffusers>.proj_in
|
| 61 |
+
diffusers_checkpoint.update(
|
| 62 |
+
{
|
| 63 |
+
"proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"],
|
| 64 |
+
"proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"],
|
| 65 |
+
}
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# <original>.text_emb_proj -> <diffusers>.embedding_proj
|
| 69 |
+
diffusers_checkpoint.update(
|
| 70 |
+
{
|
| 71 |
+
"embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"],
|
| 72 |
+
"embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"],
|
| 73 |
+
}
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# <original>.text_enc_proj -> <diffusers>.encoder_hidden_states_proj
|
| 77 |
+
diffusers_checkpoint.update(
|
| 78 |
+
{
|
| 79 |
+
"encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"],
|
| 80 |
+
"encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"],
|
| 81 |
+
}
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# <original>.positional_embedding -> <diffusers>.positional_embedding
|
| 85 |
+
diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]})
|
| 86 |
+
|
| 87 |
+
# <original>.prd_emb -> <diffusers>.prd_embedding
|
| 88 |
+
diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]})
|
| 89 |
+
|
| 90 |
+
# <original>.time_embed.2 -> <diffusers>.time_embedding.linear_2
|
| 91 |
+
diffusers_checkpoint.update(
|
| 92 |
+
{
|
| 93 |
+
"time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"],
|
| 94 |
+
"time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"],
|
| 95 |
+
}
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# <original>.resblocks.<x> -> <diffusers>.transformer_blocks.<x>
|
| 99 |
+
for idx in range(len(model.transformer_blocks)):
|
| 100 |
+
diffusers_transformer_prefix = f"transformer_blocks.{idx}"
|
| 101 |
+
original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}"
|
| 102 |
+
|
| 103 |
+
# <original>.attn -> <diffusers>.attn1
|
| 104 |
+
diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1"
|
| 105 |
+
original_attention_prefix = f"{original_transformer_prefix}.attn"
|
| 106 |
+
diffusers_checkpoint.update(
|
| 107 |
+
prior_attention_to_diffusers(
|
| 108 |
+
checkpoint,
|
| 109 |
+
diffusers_attention_prefix=diffusers_attention_prefix,
|
| 110 |
+
original_attention_prefix=original_attention_prefix,
|
| 111 |
+
attention_head_dim=model.attention_head_dim,
|
| 112 |
+
)
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# <original>.mlp -> <diffusers>.ff
|
| 116 |
+
diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff"
|
| 117 |
+
original_ff_prefix = f"{original_transformer_prefix}.mlp"
|
| 118 |
+
diffusers_checkpoint.update(
|
| 119 |
+
prior_ff_to_diffusers(
|
| 120 |
+
checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix
|
| 121 |
+
)
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# <original>.ln_1 -> <diffusers>.norm1
|
| 125 |
+
diffusers_checkpoint.update(
|
| 126 |
+
{
|
| 127 |
+
f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[
|
| 128 |
+
f"{original_transformer_prefix}.ln_1.weight"
|
| 129 |
+
],
|
| 130 |
+
f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"],
|
| 131 |
+
}
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# <original>.ln_2 -> <diffusers>.norm3
|
| 135 |
+
diffusers_checkpoint.update(
|
| 136 |
+
{
|
| 137 |
+
f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[
|
| 138 |
+
f"{original_transformer_prefix}.ln_2.weight"
|
| 139 |
+
],
|
| 140 |
+
f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"],
|
| 141 |
+
}
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# <original>.final_ln -> <diffusers>.norm_out
|
| 145 |
+
diffusers_checkpoint.update(
|
| 146 |
+
{
|
| 147 |
+
"norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"],
|
| 148 |
+
"norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"],
|
| 149 |
+
}
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# <original>.out_proj -> <diffusers>.proj_to_clip_embeddings
|
| 153 |
+
diffusers_checkpoint.update(
|
| 154 |
+
{
|
| 155 |
+
"proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"],
|
| 156 |
+
"proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"],
|
| 157 |
+
}
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# clip stats
|
| 161 |
+
clip_mean, clip_std = clip_stats_checkpoint
|
| 162 |
+
clip_mean = clip_mean[None, :]
|
| 163 |
+
clip_std = clip_std[None, :]
|
| 164 |
+
|
| 165 |
+
diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std})
|
| 166 |
+
|
| 167 |
+
return diffusers_checkpoint
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def prior_attention_to_diffusers(
|
| 171 |
+
checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim
|
| 172 |
+
):
|
| 173 |
+
diffusers_checkpoint = {}
|
| 174 |
+
|
| 175 |
+
# <original>.c_qkv -> <diffusers>.{to_q, to_k, to_v}
|
| 176 |
+
[q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
|
| 177 |
+
weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"],
|
| 178 |
+
bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"],
|
| 179 |
+
split=3,
|
| 180 |
+
chunk_size=attention_head_dim,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
diffusers_checkpoint.update(
|
| 184 |
+
{
|
| 185 |
+
f"{diffusers_attention_prefix}.to_q.weight": q_weight,
|
| 186 |
+
f"{diffusers_attention_prefix}.to_q.bias": q_bias,
|
| 187 |
+
f"{diffusers_attention_prefix}.to_k.weight": k_weight,
|
| 188 |
+
f"{diffusers_attention_prefix}.to_k.bias": k_bias,
|
| 189 |
+
f"{diffusers_attention_prefix}.to_v.weight": v_weight,
|
| 190 |
+
f"{diffusers_attention_prefix}.to_v.bias": v_bias,
|
| 191 |
+
}
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# <original>.c_proj -> <diffusers>.to_out.0
|
| 195 |
+
diffusers_checkpoint.update(
|
| 196 |
+
{
|
| 197 |
+
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"],
|
| 198 |
+
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"],
|
| 199 |
+
}
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
return diffusers_checkpoint
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix):
|
| 206 |
+
diffusers_checkpoint = {
|
| 207 |
+
# <original>.c_fc -> <diffusers>.net.0.proj
|
| 208 |
+
f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"],
|
| 209 |
+
f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"],
|
| 210 |
+
# <original>.c_proj -> <diffusers>.net.2
|
| 211 |
+
f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"],
|
| 212 |
+
f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"],
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
return diffusers_checkpoint
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# done prior
|
| 219 |
+
|
| 220 |
+
# unet
|
| 221 |
+
|
| 222 |
+
# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can
|
| 223 |
+
# update then.
|
| 224 |
+
|
| 225 |
+
UNET_CONFIG = {
|
| 226 |
+
"act_fn": "silu",
|
| 227 |
+
"addition_embed_type": "text_image",
|
| 228 |
+
"addition_embed_type_num_heads": 64,
|
| 229 |
+
"attention_head_dim": 64,
|
| 230 |
+
"block_out_channels": [384, 768, 1152, 1536],
|
| 231 |
+
"center_input_sample": False,
|
| 232 |
+
"class_embed_type": None,
|
| 233 |
+
"class_embeddings_concat": False,
|
| 234 |
+
"conv_in_kernel": 3,
|
| 235 |
+
"conv_out_kernel": 3,
|
| 236 |
+
"cross_attention_dim": 768,
|
| 237 |
+
"cross_attention_norm": None,
|
| 238 |
+
"down_block_types": [
|
| 239 |
+
"ResnetDownsampleBlock2D",
|
| 240 |
+
"SimpleCrossAttnDownBlock2D",
|
| 241 |
+
"SimpleCrossAttnDownBlock2D",
|
| 242 |
+
"SimpleCrossAttnDownBlock2D",
|
| 243 |
+
],
|
| 244 |
+
"downsample_padding": 1,
|
| 245 |
+
"dual_cross_attention": False,
|
| 246 |
+
"encoder_hid_dim": 1024,
|
| 247 |
+
"encoder_hid_dim_type": "text_image_proj",
|
| 248 |
+
"flip_sin_to_cos": True,
|
| 249 |
+
"freq_shift": 0,
|
| 250 |
+
"in_channels": 4,
|
| 251 |
+
"layers_per_block": 3,
|
| 252 |
+
"mid_block_only_cross_attention": None,
|
| 253 |
+
"mid_block_scale_factor": 1,
|
| 254 |
+
"mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
|
| 255 |
+
"norm_eps": 1e-05,
|
| 256 |
+
"norm_num_groups": 32,
|
| 257 |
+
"num_class_embeds": None,
|
| 258 |
+
"only_cross_attention": False,
|
| 259 |
+
"out_channels": 8,
|
| 260 |
+
"projection_class_embeddings_input_dim": None,
|
| 261 |
+
"resnet_out_scale_factor": 1.0,
|
| 262 |
+
"resnet_skip_time_act": False,
|
| 263 |
+
"resnet_time_scale_shift": "scale_shift",
|
| 264 |
+
"sample_size": 64,
|
| 265 |
+
"time_cond_proj_dim": None,
|
| 266 |
+
"time_embedding_act_fn": None,
|
| 267 |
+
"time_embedding_dim": None,
|
| 268 |
+
"time_embedding_type": "positional",
|
| 269 |
+
"timestep_post_act": None,
|
| 270 |
+
"up_block_types": [
|
| 271 |
+
"SimpleCrossAttnUpBlock2D",
|
| 272 |
+
"SimpleCrossAttnUpBlock2D",
|
| 273 |
+
"SimpleCrossAttnUpBlock2D",
|
| 274 |
+
"ResnetUpsampleBlock2D",
|
| 275 |
+
],
|
| 276 |
+
"upcast_attention": False,
|
| 277 |
+
"use_linear_projection": False,
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def unet_model_from_original_config():
|
| 282 |
+
model = UNet2DConditionModel(**UNET_CONFIG)
|
| 283 |
+
|
| 284 |
+
return model
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
| 288 |
+
diffusers_checkpoint = {}
|
| 289 |
+
|
| 290 |
+
num_head_channels = UNET_CONFIG["attention_head_dim"]
|
| 291 |
+
|
| 292 |
+
diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
|
| 293 |
+
diffusers_checkpoint.update(unet_conv_in(checkpoint))
|
| 294 |
+
diffusers_checkpoint.update(unet_add_embedding(checkpoint))
|
| 295 |
+
diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
|
| 296 |
+
|
| 297 |
+
# <original>.input_blocks -> <diffusers>.down_blocks
|
| 298 |
+
|
| 299 |
+
original_down_block_idx = 1
|
| 300 |
+
|
| 301 |
+
for diffusers_down_block_idx in range(len(model.down_blocks)):
|
| 302 |
+
checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
|
| 303 |
+
model,
|
| 304 |
+
checkpoint,
|
| 305 |
+
diffusers_down_block_idx=diffusers_down_block_idx,
|
| 306 |
+
original_down_block_idx=original_down_block_idx,
|
| 307 |
+
num_head_channels=num_head_channels,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
original_down_block_idx += num_original_down_blocks
|
| 311 |
+
|
| 312 |
+
diffusers_checkpoint.update(checkpoint_update)
|
| 313 |
+
|
| 314 |
+
# done <original>.input_blocks -> <diffusers>.down_blocks
|
| 315 |
+
|
| 316 |
+
diffusers_checkpoint.update(
|
| 317 |
+
unet_midblock_to_diffusers_checkpoint(
|
| 318 |
+
model,
|
| 319 |
+
checkpoint,
|
| 320 |
+
num_head_channels=num_head_channels,
|
| 321 |
+
)
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
# <original>.output_blocks -> <diffusers>.up_blocks
|
| 325 |
+
|
| 326 |
+
original_up_block_idx = 0
|
| 327 |
+
|
| 328 |
+
for diffusers_up_block_idx in range(len(model.up_blocks)):
|
| 329 |
+
checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
|
| 330 |
+
model,
|
| 331 |
+
checkpoint,
|
| 332 |
+
diffusers_up_block_idx=diffusers_up_block_idx,
|
| 333 |
+
original_up_block_idx=original_up_block_idx,
|
| 334 |
+
num_head_channels=num_head_channels,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
original_up_block_idx += num_original_up_blocks
|
| 338 |
+
|
| 339 |
+
diffusers_checkpoint.update(checkpoint_update)
|
| 340 |
+
|
| 341 |
+
# done <original>.output_blocks -> <diffusers>.up_blocks
|
| 342 |
+
|
| 343 |
+
diffusers_checkpoint.update(unet_conv_norm_out(checkpoint))
|
| 344 |
+
diffusers_checkpoint.update(unet_conv_out(checkpoint))
|
| 345 |
+
|
| 346 |
+
return diffusers_checkpoint
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
# done unet
|
| 350 |
+
|
| 351 |
+
# inpaint unet
|
| 352 |
+
|
| 353 |
+
# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can
|
| 354 |
+
# update then.
|
| 355 |
+
|
| 356 |
+
INPAINT_UNET_CONFIG = {
|
| 357 |
+
"act_fn": "silu",
|
| 358 |
+
"addition_embed_type": "text_image",
|
| 359 |
+
"addition_embed_type_num_heads": 64,
|
| 360 |
+
"attention_head_dim": 64,
|
| 361 |
+
"block_out_channels": [384, 768, 1152, 1536],
|
| 362 |
+
"center_input_sample": False,
|
| 363 |
+
"class_embed_type": None,
|
| 364 |
+
"class_embeddings_concat": None,
|
| 365 |
+
"conv_in_kernel": 3,
|
| 366 |
+
"conv_out_kernel": 3,
|
| 367 |
+
"cross_attention_dim": 768,
|
| 368 |
+
"cross_attention_norm": None,
|
| 369 |
+
"down_block_types": [
|
| 370 |
+
"ResnetDownsampleBlock2D",
|
| 371 |
+
"SimpleCrossAttnDownBlock2D",
|
| 372 |
+
"SimpleCrossAttnDownBlock2D",
|
| 373 |
+
"SimpleCrossAttnDownBlock2D",
|
| 374 |
+
],
|
| 375 |
+
"downsample_padding": 1,
|
| 376 |
+
"dual_cross_attention": False,
|
| 377 |
+
"encoder_hid_dim": 1024,
|
| 378 |
+
"encoder_hid_dim_type": "text_image_proj",
|
| 379 |
+
"flip_sin_to_cos": True,
|
| 380 |
+
"freq_shift": 0,
|
| 381 |
+
"in_channels": 9,
|
| 382 |
+
"layers_per_block": 3,
|
| 383 |
+
"mid_block_only_cross_attention": None,
|
| 384 |
+
"mid_block_scale_factor": 1,
|
| 385 |
+
"mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
|
| 386 |
+
"norm_eps": 1e-05,
|
| 387 |
+
"norm_num_groups": 32,
|
| 388 |
+
"num_class_embeds": None,
|
| 389 |
+
"only_cross_attention": False,
|
| 390 |
+
"out_channels": 8,
|
| 391 |
+
"projection_class_embeddings_input_dim": None,
|
| 392 |
+
"resnet_out_scale_factor": 1.0,
|
| 393 |
+
"resnet_skip_time_act": False,
|
| 394 |
+
"resnet_time_scale_shift": "scale_shift",
|
| 395 |
+
"sample_size": 64,
|
| 396 |
+
"time_cond_proj_dim": None,
|
| 397 |
+
"time_embedding_act_fn": None,
|
| 398 |
+
"time_embedding_dim": None,
|
| 399 |
+
"time_embedding_type": "positional",
|
| 400 |
+
"timestep_post_act": None,
|
| 401 |
+
"up_block_types": [
|
| 402 |
+
"SimpleCrossAttnUpBlock2D",
|
| 403 |
+
"SimpleCrossAttnUpBlock2D",
|
| 404 |
+
"SimpleCrossAttnUpBlock2D",
|
| 405 |
+
"ResnetUpsampleBlock2D",
|
| 406 |
+
],
|
| 407 |
+
"upcast_attention": False,
|
| 408 |
+
"use_linear_projection": False,
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def inpaint_unet_model_from_original_config():
|
| 413 |
+
model = UNet2DConditionModel(**INPAINT_UNET_CONFIG)
|
| 414 |
+
|
| 415 |
+
return model
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
| 419 |
+
diffusers_checkpoint = {}
|
| 420 |
+
|
| 421 |
+
num_head_channels = INPAINT_UNET_CONFIG["attention_head_dim"]
|
| 422 |
+
|
| 423 |
+
diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
|
| 424 |
+
diffusers_checkpoint.update(unet_conv_in(checkpoint))
|
| 425 |
+
diffusers_checkpoint.update(unet_add_embedding(checkpoint))
|
| 426 |
+
diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
|
| 427 |
+
|
| 428 |
+
# <original>.input_blocks -> <diffusers>.down_blocks
|
| 429 |
+
|
| 430 |
+
original_down_block_idx = 1
|
| 431 |
+
|
| 432 |
+
for diffusers_down_block_idx in range(len(model.down_blocks)):
|
| 433 |
+
checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint(
|
| 434 |
+
model,
|
| 435 |
+
checkpoint,
|
| 436 |
+
diffusers_down_block_idx=diffusers_down_block_idx,
|
| 437 |
+
original_down_block_idx=original_down_block_idx,
|
| 438 |
+
num_head_channels=num_head_channels,
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
original_down_block_idx += num_original_down_blocks
|
| 442 |
+
|
| 443 |
+
diffusers_checkpoint.update(checkpoint_update)
|
| 444 |
+
|
| 445 |
+
# done <original>.input_blocks -> <diffusers>.down_blocks
|
| 446 |
+
|
| 447 |
+
diffusers_checkpoint.update(
|
| 448 |
+
unet_midblock_to_diffusers_checkpoint(
|
| 449 |
+
model,
|
| 450 |
+
checkpoint,
|
| 451 |
+
num_head_channels=num_head_channels,
|
| 452 |
+
)
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
# <original>.output_blocks -> <diffusers>.up_blocks
|
| 456 |
+
|
| 457 |
+
original_up_block_idx = 0
|
| 458 |
+
|
| 459 |
+
for diffusers_up_block_idx in range(len(model.up_blocks)):
|
| 460 |
+
checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint(
|
| 461 |
+
model,
|
| 462 |
+
checkpoint,
|
| 463 |
+
diffusers_up_block_idx=diffusers_up_block_idx,
|
| 464 |
+
original_up_block_idx=original_up_block_idx,
|
| 465 |
+
num_head_channels=num_head_channels,
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
original_up_block_idx += num_original_up_blocks
|
| 469 |
+
|
| 470 |
+
diffusers_checkpoint.update(checkpoint_update)
|
| 471 |
+
|
| 472 |
+
# done <original>.output_blocks -> <diffusers>.up_blocks
|
| 473 |
+
|
| 474 |
+
diffusers_checkpoint.update(unet_conv_norm_out(checkpoint))
|
| 475 |
+
diffusers_checkpoint.update(unet_conv_out(checkpoint))
|
| 476 |
+
|
| 477 |
+
return diffusers_checkpoint
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
# done inpaint unet
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# unet utils
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
# <original>.time_embed -> <diffusers>.time_embedding
|
| 487 |
+
def unet_time_embeddings(checkpoint):
|
| 488 |
+
diffusers_checkpoint = {}
|
| 489 |
+
|
| 490 |
+
diffusers_checkpoint.update(
|
| 491 |
+
{
|
| 492 |
+
"time_embedding.linear_1.weight": checkpoint["time_embed.0.weight"],
|
| 493 |
+
"time_embedding.linear_1.bias": checkpoint["time_embed.0.bias"],
|
| 494 |
+
"time_embedding.linear_2.weight": checkpoint["time_embed.2.weight"],
|
| 495 |
+
"time_embedding.linear_2.bias": checkpoint["time_embed.2.bias"],
|
| 496 |
+
}
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
return diffusers_checkpoint
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# <original>.input_blocks.0 -> <diffusers>.conv_in
|
| 503 |
+
def unet_conv_in(checkpoint):
|
| 504 |
+
diffusers_checkpoint = {}
|
| 505 |
+
|
| 506 |
+
diffusers_checkpoint.update(
|
| 507 |
+
{
|
| 508 |
+
"conv_in.weight": checkpoint["input_blocks.0.0.weight"],
|
| 509 |
+
"conv_in.bias": checkpoint["input_blocks.0.0.bias"],
|
| 510 |
+
}
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
return diffusers_checkpoint
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def unet_add_embedding(checkpoint):
|
| 517 |
+
diffusers_checkpoint = {}
|
| 518 |
+
|
| 519 |
+
diffusers_checkpoint.update(
|
| 520 |
+
{
|
| 521 |
+
"add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"],
|
| 522 |
+
"add_embedding.text_norm.bias": checkpoint["ln_model_n.bias"],
|
| 523 |
+
"add_embedding.text_proj.weight": checkpoint["proj_n.weight"],
|
| 524 |
+
"add_embedding.text_proj.bias": checkpoint["proj_n.bias"],
|
| 525 |
+
"add_embedding.image_proj.weight": checkpoint["img_layer.weight"],
|
| 526 |
+
"add_embedding.image_proj.bias": checkpoint["img_layer.bias"],
|
| 527 |
+
}
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
return diffusers_checkpoint
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def unet_encoder_hid_proj(checkpoint):
|
| 534 |
+
diffusers_checkpoint = {}
|
| 535 |
+
|
| 536 |
+
diffusers_checkpoint.update(
|
| 537 |
+
{
|
| 538 |
+
"encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"],
|
| 539 |
+
"encoder_hid_proj.image_embeds.bias": checkpoint["clip_to_seq.bias"],
|
| 540 |
+
"encoder_hid_proj.text_proj.weight": checkpoint["to_model_dim_n.weight"],
|
| 541 |
+
"encoder_hid_proj.text_proj.bias": checkpoint["to_model_dim_n.bias"],
|
| 542 |
+
}
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
return diffusers_checkpoint
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
# <original>.out.0 -> <diffusers>.conv_norm_out
|
| 549 |
+
def unet_conv_norm_out(checkpoint):
|
| 550 |
+
diffusers_checkpoint = {}
|
| 551 |
+
|
| 552 |
+
diffusers_checkpoint.update(
|
| 553 |
+
{
|
| 554 |
+
"conv_norm_out.weight": checkpoint["out.0.weight"],
|
| 555 |
+
"conv_norm_out.bias": checkpoint["out.0.bias"],
|
| 556 |
+
}
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
return diffusers_checkpoint
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
# <original>.out.2 -> <diffusers>.conv_out
|
| 563 |
+
def unet_conv_out(checkpoint):
|
| 564 |
+
diffusers_checkpoint = {}
|
| 565 |
+
|
| 566 |
+
diffusers_checkpoint.update(
|
| 567 |
+
{
|
| 568 |
+
"conv_out.weight": checkpoint["out.2.weight"],
|
| 569 |
+
"conv_out.bias": checkpoint["out.2.bias"],
|
| 570 |
+
}
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
return diffusers_checkpoint
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
# <original>.input_blocks -> <diffusers>.down_blocks
|
| 577 |
+
def unet_downblock_to_diffusers_checkpoint(
|
| 578 |
+
model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, num_head_channels
|
| 579 |
+
):
|
| 580 |
+
diffusers_checkpoint = {}
|
| 581 |
+
|
| 582 |
+
diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets"
|
| 583 |
+
original_down_block_prefix = "input_blocks"
|
| 584 |
+
|
| 585 |
+
down_block = model.down_blocks[diffusers_down_block_idx]
|
| 586 |
+
|
| 587 |
+
num_resnets = len(down_block.resnets)
|
| 588 |
+
|
| 589 |
+
if down_block.downsamplers is None:
|
| 590 |
+
downsampler = False
|
| 591 |
+
else:
|
| 592 |
+
assert len(down_block.downsamplers) == 1
|
| 593 |
+
downsampler = True
|
| 594 |
+
# The downsample block is also a resnet
|
| 595 |
+
num_resnets += 1
|
| 596 |
+
|
| 597 |
+
for resnet_idx_inc in range(num_resnets):
|
| 598 |
+
full_resnet_prefix = f"{original_down_block_prefix}.{original_down_block_idx + resnet_idx_inc}.0"
|
| 599 |
+
|
| 600 |
+
if downsampler and resnet_idx_inc == num_resnets - 1:
|
| 601 |
+
# this is a downsample block
|
| 602 |
+
full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0"
|
| 603 |
+
else:
|
| 604 |
+
# this is a regular resnet block
|
| 605 |
+
full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"
|
| 606 |
+
|
| 607 |
+
diffusers_checkpoint.update(
|
| 608 |
+
resnet_to_diffusers_checkpoint(
|
| 609 |
+
checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
|
| 610 |
+
)
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
if hasattr(down_block, "attentions"):
|
| 614 |
+
num_attentions = len(down_block.attentions)
|
| 615 |
+
diffusers_attention_prefix = f"down_blocks.{diffusers_down_block_idx}.attentions"
|
| 616 |
+
|
| 617 |
+
for attention_idx_inc in range(num_attentions):
|
| 618 |
+
full_attention_prefix = f"{original_down_block_prefix}.{original_down_block_idx + attention_idx_inc}.1"
|
| 619 |
+
full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"
|
| 620 |
+
|
| 621 |
+
diffusers_checkpoint.update(
|
| 622 |
+
attention_to_diffusers_checkpoint(
|
| 623 |
+
checkpoint,
|
| 624 |
+
attention_prefix=full_attention_prefix,
|
| 625 |
+
diffusers_attention_prefix=full_diffusers_attention_prefix,
|
| 626 |
+
num_head_channels=num_head_channels,
|
| 627 |
+
)
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
num_original_down_blocks = num_resnets
|
| 631 |
+
|
| 632 |
+
return diffusers_checkpoint, num_original_down_blocks
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
# <original>.middle_block -> <diffusers>.mid_block
|
| 636 |
+
def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, num_head_channels):
|
| 637 |
+
diffusers_checkpoint = {}
|
| 638 |
+
|
| 639 |
+
# block 0
|
| 640 |
+
|
| 641 |
+
original_block_idx = 0
|
| 642 |
+
|
| 643 |
+
diffusers_checkpoint.update(
|
| 644 |
+
resnet_to_diffusers_checkpoint(
|
| 645 |
+
checkpoint,
|
| 646 |
+
diffusers_resnet_prefix="mid_block.resnets.0",
|
| 647 |
+
resnet_prefix=f"middle_block.{original_block_idx}",
|
| 648 |
+
)
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
original_block_idx += 1
|
| 652 |
+
|
| 653 |
+
# optional block 1
|
| 654 |
+
|
| 655 |
+
if hasattr(model.mid_block, "attentions") and model.mid_block.attentions[0] is not None:
|
| 656 |
+
diffusers_checkpoint.update(
|
| 657 |
+
attention_to_diffusers_checkpoint(
|
| 658 |
+
checkpoint,
|
| 659 |
+
diffusers_attention_prefix="mid_block.attentions.0",
|
| 660 |
+
attention_prefix=f"middle_block.{original_block_idx}",
|
| 661 |
+
num_head_channels=num_head_channels,
|
| 662 |
+
)
|
| 663 |
+
)
|
| 664 |
+
original_block_idx += 1
|
| 665 |
+
|
| 666 |
+
# block 1 or block 2
|
| 667 |
+
|
| 668 |
+
diffusers_checkpoint.update(
|
| 669 |
+
resnet_to_diffusers_checkpoint(
|
| 670 |
+
checkpoint,
|
| 671 |
+
diffusers_resnet_prefix="mid_block.resnets.1",
|
| 672 |
+
resnet_prefix=f"middle_block.{original_block_idx}",
|
| 673 |
+
)
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
return diffusers_checkpoint
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
# <original>.output_blocks -> <diffusers>.up_blocks
|
| 680 |
+
def unet_upblock_to_diffusers_checkpoint(
|
| 681 |
+
model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, num_head_channels
|
| 682 |
+
):
|
| 683 |
+
diffusers_checkpoint = {}
|
| 684 |
+
|
| 685 |
+
diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.resnets"
|
| 686 |
+
original_up_block_prefix = "output_blocks"
|
| 687 |
+
|
| 688 |
+
up_block = model.up_blocks[diffusers_up_block_idx]
|
| 689 |
+
|
| 690 |
+
num_resnets = len(up_block.resnets)
|
| 691 |
+
|
| 692 |
+
if up_block.upsamplers is None:
|
| 693 |
+
upsampler = False
|
| 694 |
+
else:
|
| 695 |
+
assert len(up_block.upsamplers) == 1
|
| 696 |
+
upsampler = True
|
| 697 |
+
# The upsample block is also a resnet
|
| 698 |
+
num_resnets += 1
|
| 699 |
+
|
| 700 |
+
has_attentions = hasattr(up_block, "attentions")
|
| 701 |
+
|
| 702 |
+
for resnet_idx_inc in range(num_resnets):
|
| 703 |
+
if upsampler and resnet_idx_inc == num_resnets - 1:
|
| 704 |
+
# this is an upsample block
|
| 705 |
+
if has_attentions:
|
| 706 |
+
# There is a middle attention block that we skip
|
| 707 |
+
original_resnet_block_idx = 2
|
| 708 |
+
else:
|
| 709 |
+
original_resnet_block_idx = 1
|
| 710 |
+
|
| 711 |
+
# we add the `minus 1` because the last two resnets are stuck together in the same output block
|
| 712 |
+
full_resnet_prefix = (
|
| 713 |
+
f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc - 1}.{original_resnet_block_idx}"
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
full_diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.upsamplers.0"
|
| 717 |
+
else:
|
| 718 |
+
# this is a regular resnet block
|
| 719 |
+
full_resnet_prefix = f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc}.0"
|
| 720 |
+
full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}"
|
| 721 |
+
|
| 722 |
+
diffusers_checkpoint.update(
|
| 723 |
+
resnet_to_diffusers_checkpoint(
|
| 724 |
+
checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix
|
| 725 |
+
)
|
| 726 |
+
)
|
| 727 |
+
|
| 728 |
+
if has_attentions:
|
| 729 |
+
num_attentions = len(up_block.attentions)
|
| 730 |
+
diffusers_attention_prefix = f"up_blocks.{diffusers_up_block_idx}.attentions"
|
| 731 |
+
|
| 732 |
+
for attention_idx_inc in range(num_attentions):
|
| 733 |
+
full_attention_prefix = f"{original_up_block_prefix}.{original_up_block_idx + attention_idx_inc}.1"
|
| 734 |
+
full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}"
|
| 735 |
+
|
| 736 |
+
diffusers_checkpoint.update(
|
| 737 |
+
attention_to_diffusers_checkpoint(
|
| 738 |
+
checkpoint,
|
| 739 |
+
attention_prefix=full_attention_prefix,
|
| 740 |
+
diffusers_attention_prefix=full_diffusers_attention_prefix,
|
| 741 |
+
num_head_channels=num_head_channels,
|
| 742 |
+
)
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
num_original_down_blocks = num_resnets - 1 if upsampler else num_resnets
|
| 746 |
+
|
| 747 |
+
return diffusers_checkpoint, num_original_down_blocks
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
|
| 751 |
+
diffusers_checkpoint = {
|
| 752 |
+
f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"],
|
| 753 |
+
f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"],
|
| 754 |
+
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"],
|
| 755 |
+
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"],
|
| 756 |
+
f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"],
|
| 757 |
+
f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"],
|
| 758 |
+
f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"],
|
| 759 |
+
f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"],
|
| 760 |
+
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"],
|
| 761 |
+
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"],
|
| 762 |
+
}
|
| 763 |
+
|
| 764 |
+
skip_connection_prefix = f"{resnet_prefix}.skip_connection"
|
| 765 |
+
|
| 766 |
+
if f"{skip_connection_prefix}.weight" in checkpoint:
|
| 767 |
+
diffusers_checkpoint.update(
|
| 768 |
+
{
|
| 769 |
+
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{skip_connection_prefix}.weight"],
|
| 770 |
+
f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{skip_connection_prefix}.bias"],
|
| 771 |
+
}
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
return diffusers_checkpoint
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels):
|
| 778 |
+
diffusers_checkpoint = {}
|
| 779 |
+
|
| 780 |
+
# <original>.norm -> <diffusers>.group_norm
|
| 781 |
+
diffusers_checkpoint.update(
|
| 782 |
+
{
|
| 783 |
+
f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
|
| 784 |
+
f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
|
| 785 |
+
}
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
# <original>.qkv -> <diffusers>.{query, key, value}
|
| 789 |
+
[q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions(
|
| 790 |
+
weight=checkpoint[f"{attention_prefix}.qkv.weight"][:, :, 0],
|
| 791 |
+
bias=checkpoint[f"{attention_prefix}.qkv.bias"],
|
| 792 |
+
split=3,
|
| 793 |
+
chunk_size=num_head_channels,
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
diffusers_checkpoint.update(
|
| 797 |
+
{
|
| 798 |
+
f"{diffusers_attention_prefix}.to_q.weight": q_weight,
|
| 799 |
+
f"{diffusers_attention_prefix}.to_q.bias": q_bias,
|
| 800 |
+
f"{diffusers_attention_prefix}.to_k.weight": k_weight,
|
| 801 |
+
f"{diffusers_attention_prefix}.to_k.bias": k_bias,
|
| 802 |
+
f"{diffusers_attention_prefix}.to_v.weight": v_weight,
|
| 803 |
+
f"{diffusers_attention_prefix}.to_v.bias": v_bias,
|
| 804 |
+
}
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
# <original>.encoder_kv -> <diffusers>.{context_key, context_value}
|
| 808 |
+
[encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
|
| 809 |
+
weight=checkpoint[f"{attention_prefix}.encoder_kv.weight"][:, :, 0],
|
| 810 |
+
bias=checkpoint[f"{attention_prefix}.encoder_kv.bias"],
|
| 811 |
+
split=2,
|
| 812 |
+
chunk_size=num_head_channels,
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
diffusers_checkpoint.update(
|
| 816 |
+
{
|
| 817 |
+
f"{diffusers_attention_prefix}.add_k_proj.weight": encoder_k_weight,
|
| 818 |
+
f"{diffusers_attention_prefix}.add_k_proj.bias": encoder_k_bias,
|
| 819 |
+
f"{diffusers_attention_prefix}.add_v_proj.weight": encoder_v_weight,
|
| 820 |
+
f"{diffusers_attention_prefix}.add_v_proj.bias": encoder_v_bias,
|
| 821 |
+
}
|
| 822 |
+
)
|
| 823 |
+
|
| 824 |
+
# <original>.proj_out (1d conv) -> <diffusers>.proj_attn (linear)
|
| 825 |
+
diffusers_checkpoint.update(
|
| 826 |
+
{
|
| 827 |
+
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
|
| 828 |
+
:, :, 0
|
| 829 |
+
],
|
| 830 |
+
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
|
| 831 |
+
}
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
return diffusers_checkpoint
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
|
| 838 |
+
def split_attentions(*, weight, bias, split, chunk_size):
|
| 839 |
+
weights = [None] * split
|
| 840 |
+
biases = [None] * split
|
| 841 |
+
|
| 842 |
+
weights_biases_idx = 0
|
| 843 |
+
|
| 844 |
+
for starting_row_index in range(0, weight.shape[0], chunk_size):
|
| 845 |
+
row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)
|
| 846 |
+
|
| 847 |
+
weight_rows = weight[row_indices, :]
|
| 848 |
+
bias_rows = bias[row_indices]
|
| 849 |
+
|
| 850 |
+
if weights[weights_biases_idx] is None:
|
| 851 |
+
assert weights[weights_biases_idx] is None
|
| 852 |
+
weights[weights_biases_idx] = weight_rows
|
| 853 |
+
biases[weights_biases_idx] = bias_rows
|
| 854 |
+
else:
|
| 855 |
+
assert weights[weights_biases_idx] is not None
|
| 856 |
+
weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
|
| 857 |
+
biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])
|
| 858 |
+
|
| 859 |
+
weights_biases_idx = (weights_biases_idx + 1) % split
|
| 860 |
+
|
| 861 |
+
return weights, biases
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
# done unet utils
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
def prior(*, args, checkpoint_map_location):
|
| 868 |
+
print("loading prior")
|
| 869 |
+
|
| 870 |
+
prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location)
|
| 871 |
+
|
| 872 |
+
clip_stats_checkpoint = torch.load(args.clip_stat_path, map_location=checkpoint_map_location)
|
| 873 |
+
|
| 874 |
+
prior_model = prior_model_from_original_config()
|
| 875 |
+
|
| 876 |
+
prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(
|
| 877 |
+
prior_model, prior_checkpoint, clip_stats_checkpoint
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
del prior_checkpoint
|
| 881 |
+
del clip_stats_checkpoint
|
| 882 |
+
|
| 883 |
+
load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True)
|
| 884 |
+
|
| 885 |
+
print("done loading prior")
|
| 886 |
+
|
| 887 |
+
return prior_model
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
def text2img(*, args, checkpoint_map_location):
|
| 891 |
+
print("loading text2img")
|
| 892 |
+
|
| 893 |
+
text2img_checkpoint = torch.load(args.text2img_checkpoint_path, map_location=checkpoint_map_location)
|
| 894 |
+
|
| 895 |
+
unet_model = unet_model_from_original_config()
|
| 896 |
+
|
| 897 |
+
unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint)
|
| 898 |
+
|
| 899 |
+
del text2img_checkpoint
|
| 900 |
+
|
| 901 |
+
load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True)
|
| 902 |
+
|
| 903 |
+
print("done loading text2img")
|
| 904 |
+
|
| 905 |
+
return unet_model
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
def inpaint_text2img(*, args, checkpoint_map_location):
|
| 909 |
+
print("loading inpaint text2img")
|
| 910 |
+
|
| 911 |
+
inpaint_text2img_checkpoint = torch.load(
|
| 912 |
+
args.inpaint_text2img_checkpoint_path, map_location=checkpoint_map_location
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
inpaint_unet_model = inpaint_unet_model_from_original_config()
|
| 916 |
+
|
| 917 |
+
inpaint_unet_diffusers_checkpoint = inpaint_unet_original_checkpoint_to_diffusers_checkpoint(
|
| 918 |
+
inpaint_unet_model, inpaint_text2img_checkpoint
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
del inpaint_text2img_checkpoint
|
| 922 |
+
|
| 923 |
+
load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True)
|
| 924 |
+
|
| 925 |
+
print("done loading inpaint text2img")
|
| 926 |
+
|
| 927 |
+
return inpaint_unet_model
|
| 928 |
+
|
| 929 |
+
|
| 930 |
+
# movq
|
| 931 |
+
|
| 932 |
+
MOVQ_CONFIG = {
|
| 933 |
+
"in_channels": 3,
|
| 934 |
+
"out_channels": 3,
|
| 935 |
+
"latent_channels": 4,
|
| 936 |
+
"down_block_types": ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"),
|
| 937 |
+
"up_block_types": ("AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
|
| 938 |
+
"num_vq_embeddings": 16384,
|
| 939 |
+
"block_out_channels": (128, 256, 256, 512),
|
| 940 |
+
"vq_embed_dim": 4,
|
| 941 |
+
"layers_per_block": 2,
|
| 942 |
+
"norm_type": "spatial",
|
| 943 |
+
}
|
| 944 |
+
|
| 945 |
+
|
| 946 |
+
def movq_model_from_original_config():
|
| 947 |
+
movq = VQModel(**MOVQ_CONFIG)
|
| 948 |
+
return movq
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
def movq_encoder_to_diffusers_checkpoint(model, checkpoint):
|
| 952 |
+
diffusers_checkpoint = {}
|
| 953 |
+
|
| 954 |
+
# conv_in
|
| 955 |
+
diffusers_checkpoint.update(
|
| 956 |
+
{
|
| 957 |
+
"encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
|
| 958 |
+
"encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
|
| 959 |
+
}
|
| 960 |
+
)
|
| 961 |
+
|
| 962 |
+
# down_blocks
|
| 963 |
+
for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
|
| 964 |
+
diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
|
| 965 |
+
down_block_prefix = f"encoder.down.{down_block_idx}"
|
| 966 |
+
|
| 967 |
+
# resnets
|
| 968 |
+
for resnet_idx, resnet in enumerate(down_block.resnets):
|
| 969 |
+
diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
|
| 970 |
+
resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"
|
| 971 |
+
|
| 972 |
+
diffusers_checkpoint.update(
|
| 973 |
+
movq_resnet_to_diffusers_checkpoint(
|
| 974 |
+
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
| 975 |
+
)
|
| 976 |
+
)
|
| 977 |
+
|
| 978 |
+
# downsample
|
| 979 |
+
|
| 980 |
+
# do not include the downsample when on the last down block
|
| 981 |
+
# There is no downsample on the last down block
|
| 982 |
+
if down_block_idx != len(model.encoder.down_blocks) - 1:
|
| 983 |
+
# There's a single downsample in the original checkpoint but a list of downsamples
|
| 984 |
+
# in the diffusers model.
|
| 985 |
+
diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
|
| 986 |
+
downsample_prefix = f"{down_block_prefix}.downsample.conv"
|
| 987 |
+
diffusers_checkpoint.update(
|
| 988 |
+
{
|
| 989 |
+
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
|
| 990 |
+
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
|
| 991 |
+
}
|
| 992 |
+
)
|
| 993 |
+
|
| 994 |
+
# attentions
|
| 995 |
+
|
| 996 |
+
if hasattr(down_block, "attentions"):
|
| 997 |
+
for attention_idx, _ in enumerate(down_block.attentions):
|
| 998 |
+
diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
|
| 999 |
+
attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
|
| 1000 |
+
diffusers_checkpoint.update(
|
| 1001 |
+
movq_attention_to_diffusers_checkpoint(
|
| 1002 |
+
checkpoint,
|
| 1003 |
+
diffusers_attention_prefix=diffusers_attention_prefix,
|
| 1004 |
+
attention_prefix=attention_prefix,
|
| 1005 |
+
)
|
| 1006 |
+
)
|
| 1007 |
+
|
| 1008 |
+
# mid block
|
| 1009 |
+
|
| 1010 |
+
# mid block attentions
|
| 1011 |
+
|
| 1012 |
+
# There is a single hardcoded attention block in the middle of the VQ-diffusion encoder
|
| 1013 |
+
diffusers_attention_prefix = "encoder.mid_block.attentions.0"
|
| 1014 |
+
attention_prefix = "encoder.mid.attn_1"
|
| 1015 |
+
diffusers_checkpoint.update(
|
| 1016 |
+
movq_attention_to_diffusers_checkpoint(
|
| 1017 |
+
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
| 1018 |
+
)
|
| 1019 |
+
)
|
| 1020 |
+
|
| 1021 |
+
# mid block resnets
|
| 1022 |
+
|
| 1023 |
+
for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
|
| 1024 |
+
diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"
|
| 1025 |
+
|
| 1026 |
+
# the hardcoded prefixes to `block_` are 1 and 2
|
| 1027 |
+
orig_resnet_idx = diffusers_resnet_idx + 1
|
| 1028 |
+
# There are two hardcoded resnets in the middle of the VQ-diffusion encoder
|
| 1029 |
+
resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"
|
| 1030 |
+
|
| 1031 |
+
diffusers_checkpoint.update(
|
| 1032 |
+
movq_resnet_to_diffusers_checkpoint(
|
| 1033 |
+
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
| 1034 |
+
)
|
| 1035 |
+
)
|
| 1036 |
+
|
| 1037 |
+
diffusers_checkpoint.update(
|
| 1038 |
+
{
|
| 1039 |
+
# conv_norm_out
|
| 1040 |
+
"encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
|
| 1041 |
+
"encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
|
| 1042 |
+
# conv_out
|
| 1043 |
+
"encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
|
| 1044 |
+
"encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
|
| 1045 |
+
}
|
| 1046 |
+
)
|
| 1047 |
+
|
| 1048 |
+
return diffusers_checkpoint
|
| 1049 |
+
|
| 1050 |
+
|
| 1051 |
+
def movq_decoder_to_diffusers_checkpoint(model, checkpoint):
|
| 1052 |
+
diffusers_checkpoint = {}
|
| 1053 |
+
|
| 1054 |
+
# conv in
|
| 1055 |
+
diffusers_checkpoint.update(
|
| 1056 |
+
{
|
| 1057 |
+
"decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
|
| 1058 |
+
"decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
|
| 1059 |
+
}
|
| 1060 |
+
)
|
| 1061 |
+
|
| 1062 |
+
# up_blocks
|
| 1063 |
+
|
| 1064 |
+
for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
|
| 1065 |
+
# up_blocks are stored in reverse order in the VQ-diffusion checkpoint
|
| 1066 |
+
orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx
|
| 1067 |
+
|
| 1068 |
+
diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
|
| 1069 |
+
up_block_prefix = f"decoder.up.{orig_up_block_idx}"
|
| 1070 |
+
|
| 1071 |
+
# resnets
|
| 1072 |
+
for resnet_idx, resnet in enumerate(up_block.resnets):
|
| 1073 |
+
diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
|
| 1074 |
+
resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"
|
| 1075 |
+
|
| 1076 |
+
diffusers_checkpoint.update(
|
| 1077 |
+
movq_resnet_to_diffusers_checkpoint_spatial_norm(
|
| 1078 |
+
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
| 1079 |
+
)
|
| 1080 |
+
)
|
| 1081 |
+
|
| 1082 |
+
# upsample
|
| 1083 |
+
|
| 1084 |
+
# there is no up sample on the last up block
|
| 1085 |
+
if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
|
| 1086 |
+
# There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples
|
| 1087 |
+
# in the diffusers model.
|
| 1088 |
+
diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
|
| 1089 |
+
downsample_prefix = f"{up_block_prefix}.upsample.conv"
|
| 1090 |
+
diffusers_checkpoint.update(
|
| 1091 |
+
{
|
| 1092 |
+
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
|
| 1093 |
+
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
|
| 1094 |
+
}
|
| 1095 |
+
)
|
| 1096 |
+
|
| 1097 |
+
# attentions
|
| 1098 |
+
|
| 1099 |
+
if hasattr(up_block, "attentions"):
|
| 1100 |
+
for attention_idx, _ in enumerate(up_block.attentions):
|
| 1101 |
+
diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
|
| 1102 |
+
attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
|
| 1103 |
+
diffusers_checkpoint.update(
|
| 1104 |
+
movq_attention_to_diffusers_checkpoint_spatial_norm(
|
| 1105 |
+
checkpoint,
|
| 1106 |
+
diffusers_attention_prefix=diffusers_attention_prefix,
|
| 1107 |
+
attention_prefix=attention_prefix,
|
| 1108 |
+
)
|
| 1109 |
+
)
|
| 1110 |
+
|
| 1111 |
+
# mid block
|
| 1112 |
+
|
| 1113 |
+
# mid block attentions
|
| 1114 |
+
|
| 1115 |
+
# There is a single hardcoded attention block in the middle of the VQ-diffusion decoder
|
| 1116 |
+
diffusers_attention_prefix = "decoder.mid_block.attentions.0"
|
| 1117 |
+
attention_prefix = "decoder.mid.attn_1"
|
| 1118 |
+
diffusers_checkpoint.update(
|
| 1119 |
+
movq_attention_to_diffusers_checkpoint_spatial_norm(
|
| 1120 |
+
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
| 1121 |
+
)
|
| 1122 |
+
)
|
| 1123 |
+
|
| 1124 |
+
# mid block resnets
|
| 1125 |
+
|
| 1126 |
+
for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
|
| 1127 |
+
diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"
|
| 1128 |
+
|
| 1129 |
+
# the hardcoded prefixes to `block_` are 1 and 2
|
| 1130 |
+
orig_resnet_idx = diffusers_resnet_idx + 1
|
| 1131 |
+
# There are two hardcoded resnets in the middle of the VQ-diffusion decoder
|
| 1132 |
+
resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"
|
| 1133 |
+
|
| 1134 |
+
diffusers_checkpoint.update(
|
| 1135 |
+
movq_resnet_to_diffusers_checkpoint_spatial_norm(
|
| 1136 |
+
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
| 1137 |
+
)
|
| 1138 |
+
)
|
| 1139 |
+
|
| 1140 |
+
diffusers_checkpoint.update(
|
| 1141 |
+
{
|
| 1142 |
+
# conv_norm_out
|
| 1143 |
+
"decoder.conv_norm_out.norm_layer.weight": checkpoint["decoder.norm_out.norm_layer.weight"],
|
| 1144 |
+
"decoder.conv_norm_out.norm_layer.bias": checkpoint["decoder.norm_out.norm_layer.bias"],
|
| 1145 |
+
"decoder.conv_norm_out.conv_y.weight": checkpoint["decoder.norm_out.conv_y.weight"],
|
| 1146 |
+
"decoder.conv_norm_out.conv_y.bias": checkpoint["decoder.norm_out.conv_y.bias"],
|
| 1147 |
+
"decoder.conv_norm_out.conv_b.weight": checkpoint["decoder.norm_out.conv_b.weight"],
|
| 1148 |
+
"decoder.conv_norm_out.conv_b.bias": checkpoint["decoder.norm_out.conv_b.bias"],
|
| 1149 |
+
# conv_out
|
| 1150 |
+
"decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
|
| 1151 |
+
"decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
|
| 1152 |
+
}
|
| 1153 |
+
)
|
| 1154 |
+
|
| 1155 |
+
return diffusers_checkpoint
|
| 1156 |
+
|
| 1157 |
+
|
| 1158 |
+
def movq_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
|
| 1159 |
+
rv = {
|
| 1160 |
+
# norm1
|
| 1161 |
+
f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
|
| 1162 |
+
f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
|
| 1163 |
+
# conv1
|
| 1164 |
+
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
|
| 1165 |
+
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
|
| 1166 |
+
# norm2
|
| 1167 |
+
f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
|
| 1168 |
+
f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
|
| 1169 |
+
# conv2
|
| 1170 |
+
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
|
| 1171 |
+
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
|
| 1172 |
+
}
|
| 1173 |
+
|
| 1174 |
+
if resnet.conv_shortcut is not None:
|
| 1175 |
+
rv.update(
|
| 1176 |
+
{
|
| 1177 |
+
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
|
| 1178 |
+
f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
|
| 1179 |
+
}
|
| 1180 |
+
)
|
| 1181 |
+
|
| 1182 |
+
return rv
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
def movq_resnet_to_diffusers_checkpoint_spatial_norm(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
|
| 1186 |
+
rv = {
|
| 1187 |
+
# norm1
|
| 1188 |
+
f"{diffusers_resnet_prefix}.norm1.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm1.norm_layer.weight"],
|
| 1189 |
+
f"{diffusers_resnet_prefix}.norm1.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm1.norm_layer.bias"],
|
| 1190 |
+
f"{diffusers_resnet_prefix}.norm1.conv_y.weight": checkpoint[f"{resnet_prefix}.norm1.conv_y.weight"],
|
| 1191 |
+
f"{diffusers_resnet_prefix}.norm1.conv_y.bias": checkpoint[f"{resnet_prefix}.norm1.conv_y.bias"],
|
| 1192 |
+
f"{diffusers_resnet_prefix}.norm1.conv_b.weight": checkpoint[f"{resnet_prefix}.norm1.conv_b.weight"],
|
| 1193 |
+
f"{diffusers_resnet_prefix}.norm1.conv_b.bias": checkpoint[f"{resnet_prefix}.norm1.conv_b.bias"],
|
| 1194 |
+
# conv1
|
| 1195 |
+
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
|
| 1196 |
+
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
|
| 1197 |
+
# norm2
|
| 1198 |
+
f"{diffusers_resnet_prefix}.norm2.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm2.norm_layer.weight"],
|
| 1199 |
+
f"{diffusers_resnet_prefix}.norm2.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm2.norm_layer.bias"],
|
| 1200 |
+
f"{diffusers_resnet_prefix}.norm2.conv_y.weight": checkpoint[f"{resnet_prefix}.norm2.conv_y.weight"],
|
| 1201 |
+
f"{diffusers_resnet_prefix}.norm2.conv_y.bias": checkpoint[f"{resnet_prefix}.norm2.conv_y.bias"],
|
| 1202 |
+
f"{diffusers_resnet_prefix}.norm2.conv_b.weight": checkpoint[f"{resnet_prefix}.norm2.conv_b.weight"],
|
| 1203 |
+
f"{diffusers_resnet_prefix}.norm2.conv_b.bias": checkpoint[f"{resnet_prefix}.norm2.conv_b.bias"],
|
| 1204 |
+
# conv2
|
| 1205 |
+
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
|
| 1206 |
+
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
|
| 1207 |
+
}
|
| 1208 |
+
|
| 1209 |
+
if resnet.conv_shortcut is not None:
|
| 1210 |
+
rv.update(
|
| 1211 |
+
{
|
| 1212 |
+
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
|
| 1213 |
+
f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
|
| 1214 |
+
}
|
| 1215 |
+
)
|
| 1216 |
+
|
| 1217 |
+
return rv
|
| 1218 |
+
|
| 1219 |
+
|
| 1220 |
+
def movq_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
|
| 1221 |
+
return {
|
| 1222 |
+
# norm
|
| 1223 |
+
f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
|
| 1224 |
+
f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
|
| 1225 |
+
# query
|
| 1226 |
+
f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
|
| 1227 |
+
f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.q.bias"],
|
| 1228 |
+
# key
|
| 1229 |
+
f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
|
| 1230 |
+
f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.k.bias"],
|
| 1231 |
+
# value
|
| 1232 |
+
f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
|
| 1233 |
+
f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.v.bias"],
|
| 1234 |
+
# proj_attn
|
| 1235 |
+
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][:, :, 0, 0],
|
| 1236 |
+
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
|
| 1237 |
+
}
|
| 1238 |
+
|
| 1239 |
+
|
| 1240 |
+
def movq_attention_to_diffusers_checkpoint_spatial_norm(checkpoint, *, diffusers_attention_prefix, attention_prefix):
|
| 1241 |
+
return {
|
| 1242 |
+
# norm
|
| 1243 |
+
f"{diffusers_attention_prefix}.spatial_norm.norm_layer.weight": checkpoint[
|
| 1244 |
+
f"{attention_prefix}.norm.norm_layer.weight"
|
| 1245 |
+
],
|
| 1246 |
+
f"{diffusers_attention_prefix}.spatial_norm.norm_layer.bias": checkpoint[
|
| 1247 |
+
f"{attention_prefix}.norm.norm_layer.bias"
|
| 1248 |
+
],
|
| 1249 |
+
f"{diffusers_attention_prefix}.spatial_norm.conv_y.weight": checkpoint[
|
| 1250 |
+
f"{attention_prefix}.norm.conv_y.weight"
|
| 1251 |
+
],
|
| 1252 |
+
f"{diffusers_attention_prefix}.spatial_norm.conv_y.bias": checkpoint[f"{attention_prefix}.norm.conv_y.bias"],
|
| 1253 |
+
f"{diffusers_attention_prefix}.spatial_norm.conv_b.weight": checkpoint[
|
| 1254 |
+
f"{attention_prefix}.norm.conv_b.weight"
|
| 1255 |
+
],
|
| 1256 |
+
f"{diffusers_attention_prefix}.spatial_norm.conv_b.bias": checkpoint[f"{attention_prefix}.norm.conv_b.bias"],
|
| 1257 |
+
# query
|
| 1258 |
+
f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
|
| 1259 |
+
f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.q.bias"],
|
| 1260 |
+
# key
|
| 1261 |
+
f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
|
| 1262 |
+
f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.k.bias"],
|
| 1263 |
+
# value
|
| 1264 |
+
f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
|
| 1265 |
+
f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.v.bias"],
|
| 1266 |
+
# proj_attn
|
| 1267 |
+
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][:, :, 0, 0],
|
| 1268 |
+
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
|
| 1269 |
+
}
|
| 1270 |
+
|
| 1271 |
+
|
| 1272 |
+
def movq_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
| 1273 |
+
diffusers_checkpoint = {}
|
| 1274 |
+
diffusers_checkpoint.update(movq_encoder_to_diffusers_checkpoint(model, checkpoint))
|
| 1275 |
+
|
| 1276 |
+
# quant_conv
|
| 1277 |
+
|
| 1278 |
+
diffusers_checkpoint.update(
|
| 1279 |
+
{
|
| 1280 |
+
"quant_conv.weight": checkpoint["quant_conv.weight"],
|
| 1281 |
+
"quant_conv.bias": checkpoint["quant_conv.bias"],
|
| 1282 |
+
}
|
| 1283 |
+
)
|
| 1284 |
+
|
| 1285 |
+
# quantize
|
| 1286 |
+
diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding.weight"]})
|
| 1287 |
+
|
| 1288 |
+
# post_quant_conv
|
| 1289 |
+
diffusers_checkpoint.update(
|
| 1290 |
+
{
|
| 1291 |
+
"post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
|
| 1292 |
+
"post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
|
| 1293 |
+
}
|
| 1294 |
+
)
|
| 1295 |
+
|
| 1296 |
+
# decoder
|
| 1297 |
+
diffusers_checkpoint.update(movq_decoder_to_diffusers_checkpoint(model, checkpoint))
|
| 1298 |
+
|
| 1299 |
+
return diffusers_checkpoint
|
| 1300 |
+
|
| 1301 |
+
|
| 1302 |
+
def movq(*, args, checkpoint_map_location):
|
| 1303 |
+
print("loading movq")
|
| 1304 |
+
|
| 1305 |
+
movq_checkpoint = torch.load(args.movq_checkpoint_path, map_location=checkpoint_map_location)
|
| 1306 |
+
|
| 1307 |
+
movq_model = movq_model_from_original_config()
|
| 1308 |
+
|
| 1309 |
+
movq_diffusers_checkpoint = movq_original_checkpoint_to_diffusers_checkpoint(movq_model, movq_checkpoint)
|
| 1310 |
+
|
| 1311 |
+
del movq_checkpoint
|
| 1312 |
+
|
| 1313 |
+
load_checkpoint_to_model(movq_diffusers_checkpoint, movq_model, strict=True)
|
| 1314 |
+
|
| 1315 |
+
print("done loading movq")
|
| 1316 |
+
|
| 1317 |
+
return movq_model
|
| 1318 |
+
|
| 1319 |
+
|
| 1320 |
+
def load_checkpoint_to_model(checkpoint, model, strict=False):
|
| 1321 |
+
with tempfile.NamedTemporaryFile(delete=False) as file:
|
| 1322 |
+
torch.save(checkpoint, file.name)
|
| 1323 |
+
del checkpoint
|
| 1324 |
+
if strict:
|
| 1325 |
+
model.load_state_dict(torch.load(file.name), strict=True)
|
| 1326 |
+
else:
|
| 1327 |
+
load_checkpoint_and_dispatch(model, file.name, device_map="auto")
|
| 1328 |
+
os.remove(file.name)
|
| 1329 |
+
|
| 1330 |
+
|
| 1331 |
+
if __name__ == "__main__":
|
| 1332 |
+
parser = argparse.ArgumentParser()
|
| 1333 |
+
|
| 1334 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 1335 |
+
|
| 1336 |
+
parser.add_argument(
|
| 1337 |
+
"--prior_checkpoint_path",
|
| 1338 |
+
default=None,
|
| 1339 |
+
type=str,
|
| 1340 |
+
required=False,
|
| 1341 |
+
help="Path to the prior checkpoint to convert.",
|
| 1342 |
+
)
|
| 1343 |
+
parser.add_argument(
|
| 1344 |
+
"--clip_stat_path",
|
| 1345 |
+
default=None,
|
| 1346 |
+
type=str,
|
| 1347 |
+
required=False,
|
| 1348 |
+
help="Path to the clip stats checkpoint to convert.",
|
| 1349 |
+
)
|
| 1350 |
+
parser.add_argument(
|
| 1351 |
+
"--text2img_checkpoint_path",
|
| 1352 |
+
default=None,
|
| 1353 |
+
type=str,
|
| 1354 |
+
required=False,
|
| 1355 |
+
help="Path to the text2img checkpoint to convert.",
|
| 1356 |
+
)
|
| 1357 |
+
parser.add_argument(
|
| 1358 |
+
"--movq_checkpoint_path",
|
| 1359 |
+
default=None,
|
| 1360 |
+
type=str,
|
| 1361 |
+
required=False,
|
| 1362 |
+
help="Path to the text2img checkpoint to convert.",
|
| 1363 |
+
)
|
| 1364 |
+
parser.add_argument(
|
| 1365 |
+
"--inpaint_text2img_checkpoint_path",
|
| 1366 |
+
default=None,
|
| 1367 |
+
type=str,
|
| 1368 |
+
required=False,
|
| 1369 |
+
help="Path to the inpaint text2img checkpoint to convert.",
|
| 1370 |
+
)
|
| 1371 |
+
parser.add_argument(
|
| 1372 |
+
"--checkpoint_load_device",
|
| 1373 |
+
default="cpu",
|
| 1374 |
+
type=str,
|
| 1375 |
+
required=False,
|
| 1376 |
+
help="The device passed to `map_location` when loading checkpoints.",
|
| 1377 |
+
)
|
| 1378 |
+
|
| 1379 |
+
parser.add_argument(
|
| 1380 |
+
"--debug",
|
| 1381 |
+
default=None,
|
| 1382 |
+
type=str,
|
| 1383 |
+
required=False,
|
| 1384 |
+
help="Only run a specific stage of the convert script. Used for debugging",
|
| 1385 |
+
)
|
| 1386 |
+
|
| 1387 |
+
args = parser.parse_args()
|
| 1388 |
+
|
| 1389 |
+
print(f"loading checkpoints to {args.checkpoint_load_device}")
|
| 1390 |
+
|
| 1391 |
+
checkpoint_map_location = torch.device(args.checkpoint_load_device)
|
| 1392 |
+
|
| 1393 |
+
if args.debug is not None:
|
| 1394 |
+
print(f"debug: only executing {args.debug}")
|
| 1395 |
+
|
| 1396 |
+
if args.debug is None:
|
| 1397 |
+
print("to-do")
|
| 1398 |
+
elif args.debug == "prior":
|
| 1399 |
+
prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
|
| 1400 |
+
prior_model.save_pretrained(args.dump_path)
|
| 1401 |
+
elif args.debug == "text2img":
|
| 1402 |
+
unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
|
| 1403 |
+
unet_model.save_pretrained(f"{args.dump_path}/unet")
|
| 1404 |
+
elif args.debug == "inpaint_text2img":
|
| 1405 |
+
inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location)
|
| 1406 |
+
inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet")
|
| 1407 |
+
elif args.debug == "decoder":
|
| 1408 |
+
decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)
|
| 1409 |
+
decoder.save_pretrained(f"{args.dump_path}/decoder")
|
| 1410 |
+
else:
|
| 1411 |
+
raise ValueError(f"unknown debug value : {args.debug}")
|
diffusers/scripts/convert_ldm_original_checkpoint_to_diffusers.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Conversion script for the LDM checkpoints."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
| 26 |
+
"""
|
| 27 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
| 28 |
+
"""
|
| 29 |
+
if n_shave_prefix_segments >= 0:
|
| 30 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
| 31 |
+
else:
|
| 32 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 36 |
+
"""
|
| 37 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 38 |
+
"""
|
| 39 |
+
mapping = []
|
| 40 |
+
for old_item in old_list:
|
| 41 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
| 42 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
| 43 |
+
|
| 44 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
| 45 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
| 46 |
+
|
| 47 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
| 48 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
| 49 |
+
|
| 50 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 51 |
+
|
| 52 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 53 |
+
|
| 54 |
+
return mapping
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
| 58 |
+
"""
|
| 59 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 60 |
+
"""
|
| 61 |
+
mapping = []
|
| 62 |
+
for old_item in old_list:
|
| 63 |
+
new_item = old_item
|
| 64 |
+
|
| 65 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
| 66 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
| 67 |
+
|
| 68 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
| 69 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
| 70 |
+
|
| 71 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 72 |
+
|
| 73 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 74 |
+
|
| 75 |
+
return mapping
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def assign_to_checkpoint(
|
| 79 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
| 80 |
+
):
|
| 81 |
+
"""
|
| 82 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
| 83 |
+
to them. It splits attention layers, and takes into account additional replacements
|
| 84 |
+
that may arise.
|
| 85 |
+
|
| 86 |
+
Assigns the weights to the new checkpoint.
|
| 87 |
+
"""
|
| 88 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
| 89 |
+
|
| 90 |
+
# Splits the attention layers into three variables.
|
| 91 |
+
if attention_paths_to_split is not None:
|
| 92 |
+
for path, path_map in attention_paths_to_split.items():
|
| 93 |
+
old_tensor = old_checkpoint[path]
|
| 94 |
+
channels = old_tensor.shape[0] // 3
|
| 95 |
+
|
| 96 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
| 97 |
+
|
| 98 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
| 99 |
+
|
| 100 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
| 101 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
| 102 |
+
|
| 103 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
| 104 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
| 105 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
| 106 |
+
|
| 107 |
+
for path in paths:
|
| 108 |
+
new_path = path["new"]
|
| 109 |
+
|
| 110 |
+
# These have already been assigned
|
| 111 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
| 112 |
+
continue
|
| 113 |
+
|
| 114 |
+
# Global renaming happens here
|
| 115 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
| 116 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
| 117 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
| 118 |
+
|
| 119 |
+
if additional_replacements is not None:
|
| 120 |
+
for replacement in additional_replacements:
|
| 121 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
| 122 |
+
|
| 123 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
| 124 |
+
if "proj_attn.weight" in new_path:
|
| 125 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
| 126 |
+
else:
|
| 127 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def convert_ldm_checkpoint(checkpoint, config):
|
| 131 |
+
"""
|
| 132 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
| 133 |
+
"""
|
| 134 |
+
new_checkpoint = {}
|
| 135 |
+
|
| 136 |
+
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
|
| 137 |
+
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
|
| 138 |
+
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
|
| 139 |
+
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
|
| 140 |
+
|
| 141 |
+
new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
|
| 142 |
+
new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
|
| 143 |
+
|
| 144 |
+
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
|
| 145 |
+
new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
|
| 146 |
+
new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
|
| 147 |
+
new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
|
| 148 |
+
|
| 149 |
+
# Retrieves the keys for the input blocks only
|
| 150 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
|
| 151 |
+
input_blocks = {
|
| 152 |
+
layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
|
| 153 |
+
for layer_id in range(num_input_blocks)
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
# Retrieves the keys for the middle blocks only
|
| 157 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer})
|
| 158 |
+
middle_blocks = {
|
| 159 |
+
layer_id: [key for key in checkpoint if f"middle_block.{layer_id}" in key]
|
| 160 |
+
for layer_id in range(num_middle_blocks)
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
# Retrieves the keys for the output blocks only
|
| 164 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer})
|
| 165 |
+
output_blocks = {
|
| 166 |
+
layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}" in key]
|
| 167 |
+
for layer_id in range(num_output_blocks)
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
for i in range(1, num_input_blocks):
|
| 171 |
+
block_id = (i - 1) // (config["num_res_blocks"] + 1)
|
| 172 |
+
layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1)
|
| 173 |
+
|
| 174 |
+
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key]
|
| 175 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
| 176 |
+
|
| 177 |
+
if f"input_blocks.{i}.0.op.weight" in checkpoint:
|
| 178 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
|
| 179 |
+
f"input_blocks.{i}.0.op.weight"
|
| 180 |
+
]
|
| 181 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
|
| 182 |
+
f"input_blocks.{i}.0.op.bias"
|
| 183 |
+
]
|
| 184 |
+
continue
|
| 185 |
+
|
| 186 |
+
paths = renew_resnet_paths(resnets)
|
| 187 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 188 |
+
resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
|
| 189 |
+
assign_to_checkpoint(
|
| 190 |
+
paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
if len(attentions):
|
| 194 |
+
paths = renew_attention_paths(attentions)
|
| 195 |
+
meta_path = {
|
| 196 |
+
"old": f"input_blocks.{i}.1",
|
| 197 |
+
"new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
|
| 198 |
+
}
|
| 199 |
+
to_split = {
|
| 200 |
+
f"input_blocks.{i}.1.qkv.bias": {
|
| 201 |
+
"key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
|
| 202 |
+
"query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
|
| 203 |
+
"value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
|
| 204 |
+
},
|
| 205 |
+
f"input_blocks.{i}.1.qkv.weight": {
|
| 206 |
+
"key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
|
| 207 |
+
"query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
|
| 208 |
+
"value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
|
| 209 |
+
},
|
| 210 |
+
}
|
| 211 |
+
assign_to_checkpoint(
|
| 212 |
+
paths,
|
| 213 |
+
new_checkpoint,
|
| 214 |
+
checkpoint,
|
| 215 |
+
additional_replacements=[meta_path],
|
| 216 |
+
attention_paths_to_split=to_split,
|
| 217 |
+
config=config,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
resnet_0 = middle_blocks[0]
|
| 221 |
+
attentions = middle_blocks[1]
|
| 222 |
+
resnet_1 = middle_blocks[2]
|
| 223 |
+
|
| 224 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
| 225 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config)
|
| 226 |
+
|
| 227 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
| 228 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config)
|
| 229 |
+
|
| 230 |
+
attentions_paths = renew_attention_paths(attentions)
|
| 231 |
+
to_split = {
|
| 232 |
+
"middle_block.1.qkv.bias": {
|
| 233 |
+
"key": "mid_block.attentions.0.key.bias",
|
| 234 |
+
"query": "mid_block.attentions.0.query.bias",
|
| 235 |
+
"value": "mid_block.attentions.0.value.bias",
|
| 236 |
+
},
|
| 237 |
+
"middle_block.1.qkv.weight": {
|
| 238 |
+
"key": "mid_block.attentions.0.key.weight",
|
| 239 |
+
"query": "mid_block.attentions.0.query.weight",
|
| 240 |
+
"value": "mid_block.attentions.0.value.weight",
|
| 241 |
+
},
|
| 242 |
+
}
|
| 243 |
+
assign_to_checkpoint(
|
| 244 |
+
attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
for i in range(num_output_blocks):
|
| 248 |
+
block_id = i // (config["num_res_blocks"] + 1)
|
| 249 |
+
layer_in_block_id = i % (config["num_res_blocks"] + 1)
|
| 250 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
| 251 |
+
output_block_list = {}
|
| 252 |
+
|
| 253 |
+
for layer in output_block_layers:
|
| 254 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
| 255 |
+
if layer_id in output_block_list:
|
| 256 |
+
output_block_list[layer_id].append(layer_name)
|
| 257 |
+
else:
|
| 258 |
+
output_block_list[layer_id] = [layer_name]
|
| 259 |
+
|
| 260 |
+
if len(output_block_list) > 1:
|
| 261 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
| 262 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
| 263 |
+
|
| 264 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
| 265 |
+
paths = renew_resnet_paths(resnets)
|
| 266 |
+
|
| 267 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 268 |
+
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
|
| 269 |
+
|
| 270 |
+
if ["conv.weight", "conv.bias"] in output_block_list.values():
|
| 271 |
+
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
| 272 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
|
| 273 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
| 274 |
+
]
|
| 275 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
|
| 276 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
| 277 |
+
]
|
| 278 |
+
|
| 279 |
+
# Clear attentions as they have been attributed above.
|
| 280 |
+
if len(attentions) == 2:
|
| 281 |
+
attentions = []
|
| 282 |
+
|
| 283 |
+
if len(attentions):
|
| 284 |
+
paths = renew_attention_paths(attentions)
|
| 285 |
+
meta_path = {
|
| 286 |
+
"old": f"output_blocks.{i}.1",
|
| 287 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
| 288 |
+
}
|
| 289 |
+
to_split = {
|
| 290 |
+
f"output_blocks.{i}.1.qkv.bias": {
|
| 291 |
+
"key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
|
| 292 |
+
"query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
|
| 293 |
+
"value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
|
| 294 |
+
},
|
| 295 |
+
f"output_blocks.{i}.1.qkv.weight": {
|
| 296 |
+
"key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
|
| 297 |
+
"query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
|
| 298 |
+
"value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
|
| 299 |
+
},
|
| 300 |
+
}
|
| 301 |
+
assign_to_checkpoint(
|
| 302 |
+
paths,
|
| 303 |
+
new_checkpoint,
|
| 304 |
+
checkpoint,
|
| 305 |
+
additional_replacements=[meta_path],
|
| 306 |
+
attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None,
|
| 307 |
+
config=config,
|
| 308 |
+
)
|
| 309 |
+
else:
|
| 310 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
| 311 |
+
for path in resnet_0_paths:
|
| 312 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
| 313 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
| 314 |
+
|
| 315 |
+
new_checkpoint[new_path] = checkpoint[old_path]
|
| 316 |
+
|
| 317 |
+
return new_checkpoint
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
if __name__ == "__main__":
|
| 321 |
+
parser = argparse.ArgumentParser()
|
| 322 |
+
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
parser.add_argument(
|
| 328 |
+
"--config_file",
|
| 329 |
+
default=None,
|
| 330 |
+
type=str,
|
| 331 |
+
required=True,
|
| 332 |
+
help="The config json file corresponding to the architecture.",
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 336 |
+
|
| 337 |
+
args = parser.parse_args()
|
| 338 |
+
|
| 339 |
+
checkpoint = torch.load(args.checkpoint_path)
|
| 340 |
+
|
| 341 |
+
with open(args.config_file) as f:
|
| 342 |
+
config = json.loads(f.read())
|
| 343 |
+
|
| 344 |
+
converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)
|
| 345 |
+
|
| 346 |
+
if "ldm" in config:
|
| 347 |
+
del config["ldm"]
|
| 348 |
+
|
| 349 |
+
model = UNet2DModel(**config)
|
| 350 |
+
model.load_state_dict(converted_checkpoint)
|
| 351 |
+
|
| 352 |
+
try:
|
| 353 |
+
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
|
| 354 |
+
vqvae = VQModel.from_pretrained("/".join(args.checkpoint_path.split("/")[:-1]))
|
| 355 |
+
|
| 356 |
+
pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae)
|
| 357 |
+
pipe.save_pretrained(args.dump_path)
|
| 358 |
+
except: # noqa: E722
|
| 359 |
+
model.save_pretrained(args.dump_path)
|
diffusers/scripts/convert_ltx_to_diffusers.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from accelerate import init_empty_weights
|
| 7 |
+
from safetensors.torch import load_file
|
| 8 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 9 |
+
|
| 10 |
+
from diffusers import (
|
| 11 |
+
AutoencoderKLLTXVideo,
|
| 12 |
+
FlowMatchEulerDiscreteScheduler,
|
| 13 |
+
LTXConditionPipeline,
|
| 14 |
+
LTXLatentUpsamplePipeline,
|
| 15 |
+
LTXPipeline,
|
| 16 |
+
LTXVideoTransformer3DModel,
|
| 17 |
+
)
|
| 18 |
+
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
| 22 |
+
state_dict.pop(key)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
TOKENIZER_MAX_LENGTH = 128
|
| 26 |
+
|
| 27 |
+
TRANSFORMER_KEYS_RENAME_DICT = {
|
| 28 |
+
"patchify_proj": "proj_in",
|
| 29 |
+
"adaln_single": "time_embed",
|
| 30 |
+
"q_norm": "norm_q",
|
| 31 |
+
"k_norm": "norm_k",
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
| 35 |
+
"vae": remove_keys_,
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
VAE_KEYS_RENAME_DICT = {
|
| 39 |
+
# decoder
|
| 40 |
+
"up_blocks.0": "mid_block",
|
| 41 |
+
"up_blocks.1": "up_blocks.0",
|
| 42 |
+
"up_blocks.2": "up_blocks.1.upsamplers.0",
|
| 43 |
+
"up_blocks.3": "up_blocks.1",
|
| 44 |
+
"up_blocks.4": "up_blocks.2.conv_in",
|
| 45 |
+
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
| 46 |
+
"up_blocks.6": "up_blocks.2",
|
| 47 |
+
"up_blocks.7": "up_blocks.3.conv_in",
|
| 48 |
+
"up_blocks.8": "up_blocks.3.upsamplers.0",
|
| 49 |
+
"up_blocks.9": "up_blocks.3",
|
| 50 |
+
# encoder
|
| 51 |
+
"down_blocks.0": "down_blocks.0",
|
| 52 |
+
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
| 53 |
+
"down_blocks.2": "down_blocks.0.conv_out",
|
| 54 |
+
"down_blocks.3": "down_blocks.1",
|
| 55 |
+
"down_blocks.4": "down_blocks.1.downsamplers.0",
|
| 56 |
+
"down_blocks.5": "down_blocks.1.conv_out",
|
| 57 |
+
"down_blocks.6": "down_blocks.2",
|
| 58 |
+
"down_blocks.7": "down_blocks.2.downsamplers.0",
|
| 59 |
+
"down_blocks.8": "down_blocks.3",
|
| 60 |
+
"down_blocks.9": "mid_block",
|
| 61 |
+
# common
|
| 62 |
+
"conv_shortcut": "conv_shortcut.conv",
|
| 63 |
+
"res_blocks": "resnets",
|
| 64 |
+
"norm3.norm": "norm3",
|
| 65 |
+
"per_channel_statistics.mean-of-means": "latents_mean",
|
| 66 |
+
"per_channel_statistics.std-of-means": "latents_std",
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
VAE_091_RENAME_DICT = {
|
| 70 |
+
# decoder
|
| 71 |
+
"up_blocks.0": "mid_block",
|
| 72 |
+
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
| 73 |
+
"up_blocks.2": "up_blocks.0",
|
| 74 |
+
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
| 75 |
+
"up_blocks.4": "up_blocks.1",
|
| 76 |
+
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
| 77 |
+
"up_blocks.6": "up_blocks.2",
|
| 78 |
+
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
| 79 |
+
"up_blocks.8": "up_blocks.3",
|
| 80 |
+
# common
|
| 81 |
+
"last_time_embedder": "time_embedder",
|
| 82 |
+
"last_scale_shift_table": "scale_shift_table",
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
VAE_095_RENAME_DICT = {
|
| 86 |
+
# decoder
|
| 87 |
+
"up_blocks.0": "mid_block",
|
| 88 |
+
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
| 89 |
+
"up_blocks.2": "up_blocks.0",
|
| 90 |
+
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
| 91 |
+
"up_blocks.4": "up_blocks.1",
|
| 92 |
+
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
| 93 |
+
"up_blocks.6": "up_blocks.2",
|
| 94 |
+
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
| 95 |
+
"up_blocks.8": "up_blocks.3",
|
| 96 |
+
# encoder
|
| 97 |
+
"down_blocks.0": "down_blocks.0",
|
| 98 |
+
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
| 99 |
+
"down_blocks.2": "down_blocks.1",
|
| 100 |
+
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
| 101 |
+
"down_blocks.4": "down_blocks.2",
|
| 102 |
+
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
| 103 |
+
"down_blocks.6": "down_blocks.3",
|
| 104 |
+
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
| 105 |
+
"down_blocks.8": "mid_block",
|
| 106 |
+
# common
|
| 107 |
+
"last_time_embedder": "time_embedder",
|
| 108 |
+
"last_scale_shift_table": "scale_shift_table",
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
VAE_SPECIAL_KEYS_REMAP = {
|
| 112 |
+
"per_channel_statistics.channel": remove_keys_,
|
| 113 |
+
"per_channel_statistics.mean-of-means": remove_keys_,
|
| 114 |
+
"per_channel_statistics.mean-of-stds": remove_keys_,
|
| 115 |
+
"model.diffusion_model": remove_keys_,
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
| 120 |
+
state_dict = saved_dict
|
| 121 |
+
if "model" in saved_dict.keys():
|
| 122 |
+
state_dict = state_dict["model"]
|
| 123 |
+
if "module" in saved_dict.keys():
|
| 124 |
+
state_dict = state_dict["module"]
|
| 125 |
+
if "state_dict" in saved_dict.keys():
|
| 126 |
+
state_dict = state_dict["state_dict"]
|
| 127 |
+
return state_dict
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
| 131 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def convert_transformer(ckpt_path: str, config, dtype: torch.dtype):
|
| 135 |
+
PREFIX_KEY = "model.diffusion_model."
|
| 136 |
+
|
| 137 |
+
original_state_dict = get_state_dict(load_file(ckpt_path))
|
| 138 |
+
with init_empty_weights():
|
| 139 |
+
transformer = LTXVideoTransformer3DModel(**config)
|
| 140 |
+
|
| 141 |
+
for key in list(original_state_dict.keys()):
|
| 142 |
+
new_key = key[:]
|
| 143 |
+
if new_key.startswith(PREFIX_KEY):
|
| 144 |
+
new_key = key[len(PREFIX_KEY) :]
|
| 145 |
+
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
| 146 |
+
new_key = new_key.replace(replace_key, rename_key)
|
| 147 |
+
update_state_dict_inplace(original_state_dict, key, new_key)
|
| 148 |
+
|
| 149 |
+
for key in list(original_state_dict.keys()):
|
| 150 |
+
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
| 151 |
+
if special_key not in key:
|
| 152 |
+
continue
|
| 153 |
+
handler_fn_inplace(key, original_state_dict)
|
| 154 |
+
|
| 155 |
+
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
| 156 |
+
return transformer
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
|
| 160 |
+
PREFIX_KEY = "vae."
|
| 161 |
+
|
| 162 |
+
original_state_dict = get_state_dict(load_file(ckpt_path))
|
| 163 |
+
with init_empty_weights():
|
| 164 |
+
vae = AutoencoderKLLTXVideo(**config)
|
| 165 |
+
|
| 166 |
+
for key in list(original_state_dict.keys()):
|
| 167 |
+
new_key = key[:]
|
| 168 |
+
if new_key.startswith(PREFIX_KEY):
|
| 169 |
+
new_key = key[len(PREFIX_KEY) :]
|
| 170 |
+
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
| 171 |
+
new_key = new_key.replace(replace_key, rename_key)
|
| 172 |
+
update_state_dict_inplace(original_state_dict, key, new_key)
|
| 173 |
+
|
| 174 |
+
for key in list(original_state_dict.keys()):
|
| 175 |
+
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
|
| 176 |
+
if special_key not in key:
|
| 177 |
+
continue
|
| 178 |
+
handler_fn_inplace(key, original_state_dict)
|
| 179 |
+
|
| 180 |
+
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
| 181 |
+
return vae
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def convert_spatial_latent_upsampler(ckpt_path: str, config, dtype: torch.dtype):
|
| 185 |
+
original_state_dict = get_state_dict(load_file(ckpt_path))
|
| 186 |
+
|
| 187 |
+
with init_empty_weights():
|
| 188 |
+
latent_upsampler = LTXLatentUpsamplerModel(**config)
|
| 189 |
+
|
| 190 |
+
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
|
| 191 |
+
latent_upsampler.to(dtype)
|
| 192 |
+
return latent_upsampler
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def get_transformer_config(version: str) -> Dict[str, Any]:
|
| 196 |
+
if version == "0.9.7":
|
| 197 |
+
config = {
|
| 198 |
+
"in_channels": 128,
|
| 199 |
+
"out_channels": 128,
|
| 200 |
+
"patch_size": 1,
|
| 201 |
+
"patch_size_t": 1,
|
| 202 |
+
"num_attention_heads": 32,
|
| 203 |
+
"attention_head_dim": 128,
|
| 204 |
+
"cross_attention_dim": 4096,
|
| 205 |
+
"num_layers": 48,
|
| 206 |
+
"activation_fn": "gelu-approximate",
|
| 207 |
+
"qk_norm": "rms_norm_across_heads",
|
| 208 |
+
"norm_elementwise_affine": False,
|
| 209 |
+
"norm_eps": 1e-6,
|
| 210 |
+
"caption_channels": 4096,
|
| 211 |
+
"attention_bias": True,
|
| 212 |
+
"attention_out_bias": True,
|
| 213 |
+
}
|
| 214 |
+
else:
|
| 215 |
+
config = {
|
| 216 |
+
"in_channels": 128,
|
| 217 |
+
"out_channels": 128,
|
| 218 |
+
"patch_size": 1,
|
| 219 |
+
"patch_size_t": 1,
|
| 220 |
+
"num_attention_heads": 32,
|
| 221 |
+
"attention_head_dim": 64,
|
| 222 |
+
"cross_attention_dim": 2048,
|
| 223 |
+
"num_layers": 28,
|
| 224 |
+
"activation_fn": "gelu-approximate",
|
| 225 |
+
"qk_norm": "rms_norm_across_heads",
|
| 226 |
+
"norm_elementwise_affine": False,
|
| 227 |
+
"norm_eps": 1e-6,
|
| 228 |
+
"caption_channels": 4096,
|
| 229 |
+
"attention_bias": True,
|
| 230 |
+
"attention_out_bias": True,
|
| 231 |
+
}
|
| 232 |
+
return config
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def get_vae_config(version: str) -> Dict[str, Any]:
|
| 236 |
+
if version in ["0.9.0"]:
|
| 237 |
+
config = {
|
| 238 |
+
"in_channels": 3,
|
| 239 |
+
"out_channels": 3,
|
| 240 |
+
"latent_channels": 128,
|
| 241 |
+
"block_out_channels": (128, 256, 512, 512),
|
| 242 |
+
"down_block_types": (
|
| 243 |
+
"LTXVideoDownBlock3D",
|
| 244 |
+
"LTXVideoDownBlock3D",
|
| 245 |
+
"LTXVideoDownBlock3D",
|
| 246 |
+
"LTXVideoDownBlock3D",
|
| 247 |
+
),
|
| 248 |
+
"decoder_block_out_channels": (128, 256, 512, 512),
|
| 249 |
+
"layers_per_block": (4, 3, 3, 3, 4),
|
| 250 |
+
"decoder_layers_per_block": (4, 3, 3, 3, 4),
|
| 251 |
+
"spatio_temporal_scaling": (True, True, True, False),
|
| 252 |
+
"decoder_spatio_temporal_scaling": (True, True, True, False),
|
| 253 |
+
"decoder_inject_noise": (False, False, False, False, False),
|
| 254 |
+
"downsample_type": ("conv", "conv", "conv", "conv"),
|
| 255 |
+
"upsample_residual": (False, False, False, False),
|
| 256 |
+
"upsample_factor": (1, 1, 1, 1),
|
| 257 |
+
"patch_size": 4,
|
| 258 |
+
"patch_size_t": 1,
|
| 259 |
+
"resnet_norm_eps": 1e-6,
|
| 260 |
+
"scaling_factor": 1.0,
|
| 261 |
+
"encoder_causal": True,
|
| 262 |
+
"decoder_causal": False,
|
| 263 |
+
"timestep_conditioning": False,
|
| 264 |
+
}
|
| 265 |
+
elif version in ["0.9.1"]:
|
| 266 |
+
config = {
|
| 267 |
+
"in_channels": 3,
|
| 268 |
+
"out_channels": 3,
|
| 269 |
+
"latent_channels": 128,
|
| 270 |
+
"block_out_channels": (128, 256, 512, 512),
|
| 271 |
+
"down_block_types": (
|
| 272 |
+
"LTXVideoDownBlock3D",
|
| 273 |
+
"LTXVideoDownBlock3D",
|
| 274 |
+
"LTXVideoDownBlock3D",
|
| 275 |
+
"LTXVideoDownBlock3D",
|
| 276 |
+
),
|
| 277 |
+
"decoder_block_out_channels": (256, 512, 1024),
|
| 278 |
+
"layers_per_block": (4, 3, 3, 3, 4),
|
| 279 |
+
"decoder_layers_per_block": (5, 6, 7, 8),
|
| 280 |
+
"spatio_temporal_scaling": (True, True, True, False),
|
| 281 |
+
"decoder_spatio_temporal_scaling": (True, True, True),
|
| 282 |
+
"decoder_inject_noise": (True, True, True, False),
|
| 283 |
+
"downsample_type": ("conv", "conv", "conv", "conv"),
|
| 284 |
+
"upsample_residual": (True, True, True),
|
| 285 |
+
"upsample_factor": (2, 2, 2),
|
| 286 |
+
"timestep_conditioning": True,
|
| 287 |
+
"patch_size": 4,
|
| 288 |
+
"patch_size_t": 1,
|
| 289 |
+
"resnet_norm_eps": 1e-6,
|
| 290 |
+
"scaling_factor": 1.0,
|
| 291 |
+
"encoder_causal": True,
|
| 292 |
+
"decoder_causal": False,
|
| 293 |
+
}
|
| 294 |
+
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
| 295 |
+
elif version in ["0.9.5"]:
|
| 296 |
+
config = {
|
| 297 |
+
"in_channels": 3,
|
| 298 |
+
"out_channels": 3,
|
| 299 |
+
"latent_channels": 128,
|
| 300 |
+
"block_out_channels": (128, 256, 512, 1024, 2048),
|
| 301 |
+
"down_block_types": (
|
| 302 |
+
"LTXVideo095DownBlock3D",
|
| 303 |
+
"LTXVideo095DownBlock3D",
|
| 304 |
+
"LTXVideo095DownBlock3D",
|
| 305 |
+
"LTXVideo095DownBlock3D",
|
| 306 |
+
),
|
| 307 |
+
"decoder_block_out_channels": (256, 512, 1024),
|
| 308 |
+
"layers_per_block": (4, 6, 6, 2, 2),
|
| 309 |
+
"decoder_layers_per_block": (5, 5, 5, 5),
|
| 310 |
+
"spatio_temporal_scaling": (True, True, True, True),
|
| 311 |
+
"decoder_spatio_temporal_scaling": (True, True, True),
|
| 312 |
+
"decoder_inject_noise": (False, False, False, False),
|
| 313 |
+
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
| 314 |
+
"upsample_residual": (True, True, True),
|
| 315 |
+
"upsample_factor": (2, 2, 2),
|
| 316 |
+
"timestep_conditioning": True,
|
| 317 |
+
"patch_size": 4,
|
| 318 |
+
"patch_size_t": 1,
|
| 319 |
+
"resnet_norm_eps": 1e-6,
|
| 320 |
+
"scaling_factor": 1.0,
|
| 321 |
+
"encoder_causal": True,
|
| 322 |
+
"decoder_causal": False,
|
| 323 |
+
"spatial_compression_ratio": 32,
|
| 324 |
+
"temporal_compression_ratio": 8,
|
| 325 |
+
}
|
| 326 |
+
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
|
| 327 |
+
elif version in ["0.9.7"]:
|
| 328 |
+
config = {
|
| 329 |
+
"in_channels": 3,
|
| 330 |
+
"out_channels": 3,
|
| 331 |
+
"latent_channels": 128,
|
| 332 |
+
"block_out_channels": (128, 256, 512, 1024, 2048),
|
| 333 |
+
"down_block_types": (
|
| 334 |
+
"LTXVideo095DownBlock3D",
|
| 335 |
+
"LTXVideo095DownBlock3D",
|
| 336 |
+
"LTXVideo095DownBlock3D",
|
| 337 |
+
"LTXVideo095DownBlock3D",
|
| 338 |
+
),
|
| 339 |
+
"decoder_block_out_channels": (256, 512, 1024),
|
| 340 |
+
"layers_per_block": (4, 6, 6, 2, 2),
|
| 341 |
+
"decoder_layers_per_block": (5, 5, 5, 5),
|
| 342 |
+
"spatio_temporal_scaling": (True, True, True, True),
|
| 343 |
+
"decoder_spatio_temporal_scaling": (True, True, True),
|
| 344 |
+
"decoder_inject_noise": (False, False, False, False),
|
| 345 |
+
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
| 346 |
+
"upsample_residual": (True, True, True),
|
| 347 |
+
"upsample_factor": (2, 2, 2),
|
| 348 |
+
"timestep_conditioning": True,
|
| 349 |
+
"patch_size": 4,
|
| 350 |
+
"patch_size_t": 1,
|
| 351 |
+
"resnet_norm_eps": 1e-6,
|
| 352 |
+
"scaling_factor": 1.0,
|
| 353 |
+
"encoder_causal": True,
|
| 354 |
+
"decoder_causal": False,
|
| 355 |
+
"spatial_compression_ratio": 32,
|
| 356 |
+
"temporal_compression_ratio": 8,
|
| 357 |
+
}
|
| 358 |
+
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
|
| 359 |
+
return config
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]:
|
| 363 |
+
if version == "0.9.7":
|
| 364 |
+
config = {
|
| 365 |
+
"in_channels": 128,
|
| 366 |
+
"mid_channels": 512,
|
| 367 |
+
"num_blocks_per_stage": 4,
|
| 368 |
+
"dims": 3,
|
| 369 |
+
"spatial_upsample": True,
|
| 370 |
+
"temporal_upsample": False,
|
| 371 |
+
}
|
| 372 |
+
else:
|
| 373 |
+
raise ValueError(f"Unsupported version: {version}")
|
| 374 |
+
return config
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def get_args():
|
| 378 |
+
parser = argparse.ArgumentParser()
|
| 379 |
+
parser.add_argument(
|
| 380 |
+
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
| 381 |
+
)
|
| 382 |
+
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
| 383 |
+
parser.add_argument(
|
| 384 |
+
"--spatial_latent_upsampler_path",
|
| 385 |
+
type=str,
|
| 386 |
+
default=None,
|
| 387 |
+
help="Path to original spatial latent upsampler checkpoint",
|
| 388 |
+
)
|
| 389 |
+
parser.add_argument(
|
| 390 |
+
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
|
| 391 |
+
)
|
| 392 |
+
parser.add_argument(
|
| 393 |
+
"--typecast_text_encoder",
|
| 394 |
+
action="store_true",
|
| 395 |
+
default=False,
|
| 396 |
+
help="Whether or not to apply fp16/bf16 precision to text_encoder",
|
| 397 |
+
)
|
| 398 |
+
parser.add_argument("--save_pipeline", action="store_true")
|
| 399 |
+
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
| 400 |
+
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
| 401 |
+
parser.add_argument(
|
| 402 |
+
"--version",
|
| 403 |
+
type=str,
|
| 404 |
+
default="0.9.0",
|
| 405 |
+
choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"],
|
| 406 |
+
help="Version of the LTX model",
|
| 407 |
+
)
|
| 408 |
+
return parser.parse_args()
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
DTYPE_MAPPING = {
|
| 412 |
+
"fp32": torch.float32,
|
| 413 |
+
"fp16": torch.float16,
|
| 414 |
+
"bf16": torch.bfloat16,
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
VARIANT_MAPPING = {
|
| 418 |
+
"fp32": None,
|
| 419 |
+
"fp16": "fp16",
|
| 420 |
+
"bf16": "bf16",
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
if __name__ == "__main__":
|
| 425 |
+
args = get_args()
|
| 426 |
+
|
| 427 |
+
transformer = None
|
| 428 |
+
dtype = DTYPE_MAPPING[args.dtype]
|
| 429 |
+
variant = VARIANT_MAPPING[args.dtype]
|
| 430 |
+
output_path = Path(args.output_path)
|
| 431 |
+
|
| 432 |
+
if args.transformer_ckpt_path is not None:
|
| 433 |
+
config = get_transformer_config(args.version)
|
| 434 |
+
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, config, dtype)
|
| 435 |
+
if not args.save_pipeline:
|
| 436 |
+
transformer.save_pretrained(
|
| 437 |
+
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
if args.vae_ckpt_path is not None:
|
| 441 |
+
config = get_vae_config(args.version)
|
| 442 |
+
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
|
| 443 |
+
if not args.save_pipeline:
|
| 444 |
+
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
|
| 445 |
+
|
| 446 |
+
if args.spatial_latent_upsampler_path is not None:
|
| 447 |
+
config = get_spatial_latent_upsampler_config(args.version)
|
| 448 |
+
latent_upsampler: LTXLatentUpsamplerModel = convert_spatial_latent_upsampler(
|
| 449 |
+
args.spatial_latent_upsampler_path, config, dtype
|
| 450 |
+
)
|
| 451 |
+
if not args.save_pipeline:
|
| 452 |
+
latent_upsampler.save_pretrained(
|
| 453 |
+
output_path / "latent_upsampler", safe_serialization=True, max_shard_size="5GB", variant=variant
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
if args.save_pipeline:
|
| 457 |
+
text_encoder_id = "google/t5-v1_1-xxl"
|
| 458 |
+
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
| 459 |
+
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
| 460 |
+
|
| 461 |
+
if args.typecast_text_encoder:
|
| 462 |
+
text_encoder = text_encoder.to(dtype=dtype)
|
| 463 |
+
|
| 464 |
+
# Apparently, the conversion does not work anymore without this :shrug:
|
| 465 |
+
for param in text_encoder.parameters():
|
| 466 |
+
param.data = param.data.contiguous()
|
| 467 |
+
|
| 468 |
+
if args.version in ["0.9.5", "0.9.7"]:
|
| 469 |
+
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
|
| 470 |
+
else:
|
| 471 |
+
scheduler = FlowMatchEulerDiscreteScheduler(
|
| 472 |
+
use_dynamic_shifting=True,
|
| 473 |
+
base_shift=0.95,
|
| 474 |
+
max_shift=2.05,
|
| 475 |
+
base_image_seq_len=1024,
|
| 476 |
+
max_image_seq_len=4096,
|
| 477 |
+
shift_terminal=0.1,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
if args.version in ["0.9.0", "0.9.1", "0.9.5"]:
|
| 481 |
+
pipe = LTXPipeline(
|
| 482 |
+
scheduler=scheduler,
|
| 483 |
+
vae=vae,
|
| 484 |
+
text_encoder=text_encoder,
|
| 485 |
+
tokenizer=tokenizer,
|
| 486 |
+
transformer=transformer,
|
| 487 |
+
)
|
| 488 |
+
pipe.save_pretrained(
|
| 489 |
+
output_path.as_posix(), safe_serialization=True, variant=variant, max_shard_size="5GB"
|
| 490 |
+
)
|
| 491 |
+
elif args.version in ["0.9.7"]:
|
| 492 |
+
pipe = LTXConditionPipeline(
|
| 493 |
+
scheduler=scheduler,
|
| 494 |
+
vae=vae,
|
| 495 |
+
text_encoder=text_encoder,
|
| 496 |
+
tokenizer=tokenizer,
|
| 497 |
+
transformer=transformer,
|
| 498 |
+
)
|
| 499 |
+
pipe_upsample = LTXLatentUpsamplePipeline(
|
| 500 |
+
vae=vae,
|
| 501 |
+
latent_upsampler=latent_upsampler,
|
| 502 |
+
)
|
| 503 |
+
pipe.save_pretrained(
|
| 504 |
+
(output_path / "ltx_pipeline").as_posix(),
|
| 505 |
+
safe_serialization=True,
|
| 506 |
+
variant=variant,
|
| 507 |
+
max_shard_size="5GB",
|
| 508 |
+
)
|
| 509 |
+
pipe_upsample.save_pretrained(
|
| 510 |
+
(output_path / "ltx_upsample_pipeline").as_posix(),
|
| 511 |
+
safe_serialization=True,
|
| 512 |
+
variant=variant,
|
| 513 |
+
max_shard_size="5GB",
|
| 514 |
+
)
|
| 515 |
+
else:
|
| 516 |
+
raise ValueError(f"Unsupported version: {args.version}")
|
diffusers/scripts/convert_lumina_to_diffusers.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from safetensors.torch import load_file
|
| 6 |
+
from transformers import AutoModel, AutoTokenizer
|
| 7 |
+
|
| 8 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main(args):
|
| 12 |
+
# checkpoint from https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT or https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I
|
| 13 |
+
all_sd = load_file(args.origin_ckpt_path, device="cpu")
|
| 14 |
+
converted_state_dict = {}
|
| 15 |
+
# pad token
|
| 16 |
+
converted_state_dict["pad_token"] = all_sd["pad_token"]
|
| 17 |
+
|
| 18 |
+
# patch embed
|
| 19 |
+
converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"]
|
| 20 |
+
converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"]
|
| 21 |
+
|
| 22 |
+
# time and caption embed
|
| 23 |
+
converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"]
|
| 24 |
+
converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"]
|
| 25 |
+
converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"]
|
| 26 |
+
converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"]
|
| 27 |
+
converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"]
|
| 28 |
+
converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"]
|
| 29 |
+
converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"]
|
| 30 |
+
converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"]
|
| 31 |
+
|
| 32 |
+
for i in range(24):
|
| 33 |
+
# adaln
|
| 34 |
+
converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"]
|
| 35 |
+
converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"]
|
| 36 |
+
converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"]
|
| 37 |
+
|
| 38 |
+
# qkv
|
| 39 |
+
converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
|
| 40 |
+
converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"]
|
| 41 |
+
converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"]
|
| 42 |
+
|
| 43 |
+
# cap
|
| 44 |
+
converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
|
| 45 |
+
converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"]
|
| 46 |
+
converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"]
|
| 47 |
+
|
| 48 |
+
# output
|
| 49 |
+
converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"]
|
| 50 |
+
|
| 51 |
+
# attention
|
| 52 |
+
# qk norm
|
| 53 |
+
converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
|
| 54 |
+
converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
|
| 55 |
+
|
| 56 |
+
converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"]
|
| 57 |
+
converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"]
|
| 58 |
+
|
| 59 |
+
converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
|
| 60 |
+
converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
|
| 61 |
+
|
| 62 |
+
converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"]
|
| 63 |
+
converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"]
|
| 64 |
+
|
| 65 |
+
# attention norm
|
| 66 |
+
converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"]
|
| 67 |
+
converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"]
|
| 68 |
+
converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"]
|
| 69 |
+
|
| 70 |
+
# feed forward
|
| 71 |
+
converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"]
|
| 72 |
+
converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"]
|
| 73 |
+
converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"]
|
| 74 |
+
|
| 75 |
+
# feed forward norm
|
| 76 |
+
converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"]
|
| 77 |
+
converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"]
|
| 78 |
+
|
| 79 |
+
# final layer
|
| 80 |
+
converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"]
|
| 81 |
+
converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"]
|
| 82 |
+
|
| 83 |
+
converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"]
|
| 84 |
+
converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"]
|
| 85 |
+
|
| 86 |
+
# Lumina-Next-SFT 2B
|
| 87 |
+
transformer = LuminaNextDiT2DModel(
|
| 88 |
+
sample_size=128,
|
| 89 |
+
patch_size=2,
|
| 90 |
+
in_channels=4,
|
| 91 |
+
hidden_size=2304,
|
| 92 |
+
num_layers=24,
|
| 93 |
+
num_attention_heads=32,
|
| 94 |
+
num_kv_heads=8,
|
| 95 |
+
multiple_of=256,
|
| 96 |
+
ffn_dim_multiplier=None,
|
| 97 |
+
norm_eps=1e-5,
|
| 98 |
+
learn_sigma=True,
|
| 99 |
+
qk_norm=True,
|
| 100 |
+
cross_attention_dim=2048,
|
| 101 |
+
scaling_factor=1.0,
|
| 102 |
+
)
|
| 103 |
+
transformer.load_state_dict(converted_state_dict, strict=True)
|
| 104 |
+
|
| 105 |
+
num_model_params = sum(p.numel() for p in transformer.parameters())
|
| 106 |
+
print(f"Total number of transformer parameters: {num_model_params}")
|
| 107 |
+
|
| 108 |
+
if args.only_transformer:
|
| 109 |
+
transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
|
| 110 |
+
else:
|
| 111 |
+
scheduler = FlowMatchEulerDiscreteScheduler()
|
| 112 |
+
|
| 113 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32)
|
| 114 |
+
|
| 115 |
+
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
| 116 |
+
text_encoder = AutoModel.from_pretrained("google/gemma-2b")
|
| 117 |
+
|
| 118 |
+
pipeline = LuminaPipeline(
|
| 119 |
+
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
|
| 120 |
+
)
|
| 121 |
+
pipeline.save_pretrained(args.dump_path)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
if __name__ == "__main__":
|
| 125 |
+
parser = argparse.ArgumentParser()
|
| 126 |
+
|
| 127 |
+
parser.add_argument(
|
| 128 |
+
"--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--image_size",
|
| 132 |
+
default=1024,
|
| 133 |
+
type=int,
|
| 134 |
+
choices=[256, 512, 1024],
|
| 135 |
+
required=False,
|
| 136 |
+
help="Image size of pretrained model, either 512 or 1024.",
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
| 139 |
+
parser.add_argument("--only_transformer", default=True, type=bool, required=True)
|
| 140 |
+
|
| 141 |
+
args = parser.parse_args()
|
| 142 |
+
main(args)
|
diffusers/scripts/convert_mochi_to_diffusers.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from contextlib import nullcontext
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from accelerate import init_empty_weights
|
| 6 |
+
from safetensors.torch import load_file
|
| 7 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 8 |
+
|
| 9 |
+
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
|
| 10 |
+
from diffusers.utils.import_utils import is_accelerate_available
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
| 14 |
+
|
| 15 |
+
TOKENIZER_MAX_LENGTH = 256
|
| 16 |
+
|
| 17 |
+
parser = argparse.ArgumentParser()
|
| 18 |
+
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
|
| 19 |
+
parser.add_argument("--vae_encoder_checkpoint_path", default=None, type=str)
|
| 20 |
+
parser.add_argument("--vae_decoder_checkpoint_path", default=None, type=str)
|
| 21 |
+
parser.add_argument("--output_path", required=True, type=str)
|
| 22 |
+
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
|
| 23 |
+
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
|
| 24 |
+
parser.add_argument("--dtype", type=str, default=None)
|
| 25 |
+
|
| 26 |
+
args = parser.parse_args()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# This is specific to `AdaLayerNormContinuous`:
|
| 30 |
+
# Diffusers implementation split the linear projection into the scale, shift while Mochi split it into shift, scale
|
| 31 |
+
def swap_scale_shift(weight, dim):
|
| 32 |
+
shift, scale = weight.chunk(2, dim=0)
|
| 33 |
+
new_weight = torch.cat([scale, shift], dim=0)
|
| 34 |
+
return new_weight
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def swap_proj_gate(weight):
|
| 38 |
+
proj, gate = weight.chunk(2, dim=0)
|
| 39 |
+
new_weight = torch.cat([gate, proj], dim=0)
|
| 40 |
+
return new_weight
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
|
| 44 |
+
original_state_dict = load_file(ckpt_path, device="cpu")
|
| 45 |
+
new_state_dict = {}
|
| 46 |
+
|
| 47 |
+
# Convert patch_embed
|
| 48 |
+
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
|
| 49 |
+
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")
|
| 50 |
+
|
| 51 |
+
# Convert time_embed
|
| 52 |
+
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight")
|
| 53 |
+
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias")
|
| 54 |
+
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight")
|
| 55 |
+
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias")
|
| 56 |
+
new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight")
|
| 57 |
+
new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias")
|
| 58 |
+
new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight")
|
| 59 |
+
new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias")
|
| 60 |
+
new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight")
|
| 61 |
+
new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias")
|
| 62 |
+
new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight")
|
| 63 |
+
new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias")
|
| 64 |
+
|
| 65 |
+
# Convert transformer blocks
|
| 66 |
+
num_layers = 48
|
| 67 |
+
for i in range(num_layers):
|
| 68 |
+
block_prefix = f"transformer_blocks.{i}."
|
| 69 |
+
old_prefix = f"blocks.{i}."
|
| 70 |
+
|
| 71 |
+
# norm1
|
| 72 |
+
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight")
|
| 73 |
+
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias")
|
| 74 |
+
if i < num_layers - 1:
|
| 75 |
+
new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop(
|
| 76 |
+
old_prefix + "mod_y.weight"
|
| 77 |
+
)
|
| 78 |
+
new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop(
|
| 79 |
+
old_prefix + "mod_y.bias"
|
| 80 |
+
)
|
| 81 |
+
else:
|
| 82 |
+
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
|
| 83 |
+
old_prefix + "mod_y.weight"
|
| 84 |
+
)
|
| 85 |
+
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
|
| 86 |
+
old_prefix + "mod_y.bias"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Visual attention
|
| 90 |
+
qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
|
| 91 |
+
q, k, v = qkv_weight.chunk(3, dim=0)
|
| 92 |
+
|
| 93 |
+
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
| 94 |
+
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
| 95 |
+
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
| 96 |
+
new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop(
|
| 97 |
+
old_prefix + "attn.q_norm_x.weight"
|
| 98 |
+
)
|
| 99 |
+
new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop(
|
| 100 |
+
old_prefix + "attn.k_norm_x.weight"
|
| 101 |
+
)
|
| 102 |
+
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
|
| 103 |
+
old_prefix + "attn.proj_x.weight"
|
| 104 |
+
)
|
| 105 |
+
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias")
|
| 106 |
+
|
| 107 |
+
# Context attention
|
| 108 |
+
qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight")
|
| 109 |
+
q, k, v = qkv_weight.chunk(3, dim=0)
|
| 110 |
+
|
| 111 |
+
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
|
| 112 |
+
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
|
| 113 |
+
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
|
| 114 |
+
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop(
|
| 115 |
+
old_prefix + "attn.q_norm_y.weight"
|
| 116 |
+
)
|
| 117 |
+
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop(
|
| 118 |
+
old_prefix + "attn.k_norm_y.weight"
|
| 119 |
+
)
|
| 120 |
+
if i < num_layers - 1:
|
| 121 |
+
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop(
|
| 122 |
+
old_prefix + "attn.proj_y.weight"
|
| 123 |
+
)
|
| 124 |
+
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop(
|
| 125 |
+
old_prefix + "attn.proj_y.bias"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# MLP
|
| 129 |
+
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
|
| 130 |
+
original_state_dict.pop(old_prefix + "mlp_x.w1.weight")
|
| 131 |
+
)
|
| 132 |
+
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight")
|
| 133 |
+
if i < num_layers - 1:
|
| 134 |
+
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
|
| 135 |
+
original_state_dict.pop(old_prefix + "mlp_y.w1.weight")
|
| 136 |
+
)
|
| 137 |
+
new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop(
|
| 138 |
+
old_prefix + "mlp_y.w2.weight"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Output layers
|
| 142 |
+
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
|
| 143 |
+
original_state_dict.pop("final_layer.mod.weight"), dim=0
|
| 144 |
+
)
|
| 145 |
+
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.bias"), dim=0)
|
| 146 |
+
new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
|
| 147 |
+
new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
|
| 148 |
+
|
| 149 |
+
new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies")
|
| 150 |
+
|
| 151 |
+
print("Remaining Keys:", original_state_dict.keys())
|
| 152 |
+
|
| 153 |
+
return new_state_dict
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_path):
|
| 157 |
+
encoder_state_dict = load_file(encoder_ckpt_path, device="cpu")
|
| 158 |
+
decoder_state_dict = load_file(decoder_ckpt_path, device="cpu")
|
| 159 |
+
new_state_dict = {}
|
| 160 |
+
|
| 161 |
+
# ==== Decoder =====
|
| 162 |
+
prefix = "decoder."
|
| 163 |
+
|
| 164 |
+
# Convert conv_in
|
| 165 |
+
new_state_dict[f"{prefix}conv_in.weight"] = decoder_state_dict.pop("blocks.0.0.weight")
|
| 166 |
+
new_state_dict[f"{prefix}conv_in.bias"] = decoder_state_dict.pop("blocks.0.0.bias")
|
| 167 |
+
|
| 168 |
+
# Convert block_in (MochiMidBlock3D)
|
| 169 |
+
for i in range(3): # layers_per_block[-1] = 3
|
| 170 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
| 171 |
+
f"blocks.0.{i + 1}.stack.0.weight"
|
| 172 |
+
)
|
| 173 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
| 174 |
+
f"blocks.0.{i + 1}.stack.0.bias"
|
| 175 |
+
)
|
| 176 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
| 177 |
+
f"blocks.0.{i + 1}.stack.2.weight"
|
| 178 |
+
)
|
| 179 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
| 180 |
+
f"blocks.0.{i + 1}.stack.2.bias"
|
| 181 |
+
)
|
| 182 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
| 183 |
+
f"blocks.0.{i + 1}.stack.3.weight"
|
| 184 |
+
)
|
| 185 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
| 186 |
+
f"blocks.0.{i + 1}.stack.3.bias"
|
| 187 |
+
)
|
| 188 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
| 189 |
+
f"blocks.0.{i + 1}.stack.5.weight"
|
| 190 |
+
)
|
| 191 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
| 192 |
+
f"blocks.0.{i + 1}.stack.5.bias"
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# Convert up_blocks (MochiUpBlock3D)
|
| 196 |
+
down_block_layers = [6, 4, 3] # layers_per_block[-2], layers_per_block[-3], layers_per_block[-4]
|
| 197 |
+
for block in range(3):
|
| 198 |
+
for i in range(down_block_layers[block]):
|
| 199 |
+
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
| 200 |
+
f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
|
| 201 |
+
)
|
| 202 |
+
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
| 203 |
+
f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
|
| 204 |
+
)
|
| 205 |
+
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
| 206 |
+
f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
|
| 207 |
+
)
|
| 208 |
+
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
| 209 |
+
f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
|
| 210 |
+
)
|
| 211 |
+
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
| 212 |
+
f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
|
| 213 |
+
)
|
| 214 |
+
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
| 215 |
+
f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
|
| 216 |
+
)
|
| 217 |
+
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
| 218 |
+
f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
|
| 219 |
+
)
|
| 220 |
+
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
| 221 |
+
f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
|
| 222 |
+
)
|
| 223 |
+
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
|
| 224 |
+
f"blocks.{block + 1}.proj.weight"
|
| 225 |
+
)
|
| 226 |
+
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
|
| 227 |
+
f"blocks.{block + 1}.proj.bias"
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Convert block_out (MochiMidBlock3D)
|
| 231 |
+
for i in range(3): # layers_per_block[0] = 3
|
| 232 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
| 233 |
+
f"blocks.4.{i}.stack.0.weight"
|
| 234 |
+
)
|
| 235 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
| 236 |
+
f"blocks.4.{i}.stack.0.bias"
|
| 237 |
+
)
|
| 238 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
| 239 |
+
f"blocks.4.{i}.stack.2.weight"
|
| 240 |
+
)
|
| 241 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
| 242 |
+
f"blocks.4.{i}.stack.2.bias"
|
| 243 |
+
)
|
| 244 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
| 245 |
+
f"blocks.4.{i}.stack.3.weight"
|
| 246 |
+
)
|
| 247 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
| 248 |
+
f"blocks.4.{i}.stack.3.bias"
|
| 249 |
+
)
|
| 250 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
| 251 |
+
f"blocks.4.{i}.stack.5.weight"
|
| 252 |
+
)
|
| 253 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
| 254 |
+
f"blocks.4.{i}.stack.5.bias"
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Convert proj_out (Conv1x1 ~= nn.Linear)
|
| 258 |
+
new_state_dict[f"{prefix}proj_out.weight"] = decoder_state_dict.pop("output_proj.weight")
|
| 259 |
+
new_state_dict[f"{prefix}proj_out.bias"] = decoder_state_dict.pop("output_proj.bias")
|
| 260 |
+
|
| 261 |
+
print("Remaining Decoder Keys:", decoder_state_dict.keys())
|
| 262 |
+
|
| 263 |
+
# ==== Encoder =====
|
| 264 |
+
prefix = "encoder."
|
| 265 |
+
|
| 266 |
+
new_state_dict[f"{prefix}proj_in.weight"] = encoder_state_dict.pop("layers.0.weight")
|
| 267 |
+
new_state_dict[f"{prefix}proj_in.bias"] = encoder_state_dict.pop("layers.0.bias")
|
| 268 |
+
|
| 269 |
+
# Convert block_in (MochiMidBlock3D)
|
| 270 |
+
for i in range(3): # layers_per_block[0] = 3
|
| 271 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
|
| 272 |
+
f"layers.{i + 1}.stack.0.weight"
|
| 273 |
+
)
|
| 274 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
| 275 |
+
f"layers.{i + 1}.stack.0.bias"
|
| 276 |
+
)
|
| 277 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
| 278 |
+
f"layers.{i + 1}.stack.2.weight"
|
| 279 |
+
)
|
| 280 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
| 281 |
+
f"layers.{i + 1}.stack.2.bias"
|
| 282 |
+
)
|
| 283 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
|
| 284 |
+
f"layers.{i + 1}.stack.3.weight"
|
| 285 |
+
)
|
| 286 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
| 287 |
+
f"layers.{i + 1}.stack.3.bias"
|
| 288 |
+
)
|
| 289 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
| 290 |
+
f"layers.{i + 1}.stack.5.weight"
|
| 291 |
+
)
|
| 292 |
+
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
| 293 |
+
f"layers.{i + 1}.stack.5.bias"
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# Convert down_blocks (MochiDownBlock3D)
|
| 297 |
+
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
|
| 298 |
+
for block in range(3):
|
| 299 |
+
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
|
| 300 |
+
f"layers.{block + 4}.layers.0.weight"
|
| 301 |
+
)
|
| 302 |
+
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
|
| 303 |
+
f"layers.{block + 4}.layers.0.bias"
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
for i in range(down_block_layers[block]):
|
| 307 |
+
# Convert resnets
|
| 308 |
+
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
|
| 309 |
+
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
|
| 310 |
+
)
|
| 311 |
+
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
| 312 |
+
f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
|
| 313 |
+
)
|
| 314 |
+
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
| 315 |
+
f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
|
| 316 |
+
)
|
| 317 |
+
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
| 318 |
+
f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
|
| 319 |
+
)
|
| 320 |
+
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
|
| 321 |
+
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
|
| 322 |
+
)
|
| 323 |
+
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
| 324 |
+
f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
|
| 325 |
+
)
|
| 326 |
+
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
| 327 |
+
f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
|
| 328 |
+
)
|
| 329 |
+
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
| 330 |
+
f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# Convert attentions
|
| 334 |
+
qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
|
| 335 |
+
q, k, v = qkv_weight.chunk(3, dim=0)
|
| 336 |
+
|
| 337 |
+
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
|
| 338 |
+
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
|
| 339 |
+
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
|
| 340 |
+
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
|
| 341 |
+
f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
|
| 342 |
+
)
|
| 343 |
+
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
|
| 344 |
+
f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
|
| 345 |
+
)
|
| 346 |
+
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
|
| 347 |
+
f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
|
| 348 |
+
)
|
| 349 |
+
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
|
| 350 |
+
f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
# Convert block_out (MochiMidBlock3D)
|
| 354 |
+
for i in range(3): # layers_per_block[-1] = 3
|
| 355 |
+
# Convert resnets
|
| 356 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
|
| 357 |
+
f"layers.{i + 7}.stack.0.weight"
|
| 358 |
+
)
|
| 359 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
| 360 |
+
f"layers.{i + 7}.stack.0.bias"
|
| 361 |
+
)
|
| 362 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
| 363 |
+
f"layers.{i + 7}.stack.2.weight"
|
| 364 |
+
)
|
| 365 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
| 366 |
+
f"layers.{i + 7}.stack.2.bias"
|
| 367 |
+
)
|
| 368 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
|
| 369 |
+
f"layers.{i + 7}.stack.3.weight"
|
| 370 |
+
)
|
| 371 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
| 372 |
+
f"layers.{i + 7}.stack.3.bias"
|
| 373 |
+
)
|
| 374 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
| 375 |
+
f"layers.{i + 7}.stack.5.weight"
|
| 376 |
+
)
|
| 377 |
+
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
| 378 |
+
f"layers.{i + 7}.stack.5.bias"
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
# Convert attentions
|
| 382 |
+
qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
|
| 383 |
+
q, k, v = qkv_weight.chunk(3, dim=0)
|
| 384 |
+
|
| 385 |
+
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
|
| 386 |
+
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
|
| 387 |
+
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
|
| 388 |
+
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
|
| 389 |
+
f"layers.{i + 7}.attn_block.attn.out.weight"
|
| 390 |
+
)
|
| 391 |
+
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
|
| 392 |
+
f"layers.{i + 7}.attn_block.attn.out.bias"
|
| 393 |
+
)
|
| 394 |
+
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
|
| 395 |
+
f"layers.{i + 7}.attn_block.norm.weight"
|
| 396 |
+
)
|
| 397 |
+
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
|
| 398 |
+
f"layers.{i + 7}.attn_block.norm.bias"
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
# Convert output layers
|
| 402 |
+
new_state_dict[f"{prefix}norm_out.norm_layer.weight"] = encoder_state_dict.pop("output_norm.weight")
|
| 403 |
+
new_state_dict[f"{prefix}norm_out.norm_layer.bias"] = encoder_state_dict.pop("output_norm.bias")
|
| 404 |
+
new_state_dict[f"{prefix}proj_out.weight"] = encoder_state_dict.pop("output_proj.weight")
|
| 405 |
+
|
| 406 |
+
print("Remaining Encoder Keys:", encoder_state_dict.keys())
|
| 407 |
+
|
| 408 |
+
return new_state_dict
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def main(args):
|
| 412 |
+
if args.dtype is None:
|
| 413 |
+
dtype = None
|
| 414 |
+
if args.dtype == "fp16":
|
| 415 |
+
dtype = torch.float16
|
| 416 |
+
elif args.dtype == "bf16":
|
| 417 |
+
dtype = torch.bfloat16
|
| 418 |
+
elif args.dtype == "fp32":
|
| 419 |
+
dtype = torch.float32
|
| 420 |
+
else:
|
| 421 |
+
raise ValueError(f"Unsupported dtype: {args.dtype}")
|
| 422 |
+
|
| 423 |
+
transformer = None
|
| 424 |
+
vae = None
|
| 425 |
+
|
| 426 |
+
if args.transformer_checkpoint_path is not None:
|
| 427 |
+
converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
|
| 428 |
+
args.transformer_checkpoint_path
|
| 429 |
+
)
|
| 430 |
+
transformer = MochiTransformer3DModel()
|
| 431 |
+
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
|
| 432 |
+
if dtype is not None:
|
| 433 |
+
transformer = transformer.to(dtype=dtype)
|
| 434 |
+
|
| 435 |
+
if args.vae_encoder_checkpoint_path is not None and args.vae_decoder_checkpoint_path is not None:
|
| 436 |
+
vae = AutoencoderKLMochi(latent_channels=12, out_channels=3)
|
| 437 |
+
converted_vae_state_dict = convert_mochi_vae_state_dict_to_diffusers(
|
| 438 |
+
args.vae_encoder_checkpoint_path, args.vae_decoder_checkpoint_path
|
| 439 |
+
)
|
| 440 |
+
vae.load_state_dict(converted_vae_state_dict, strict=True)
|
| 441 |
+
if dtype is not None:
|
| 442 |
+
vae = vae.to(dtype=dtype)
|
| 443 |
+
|
| 444 |
+
text_encoder_id = "google/t5-v1_1-xxl"
|
| 445 |
+
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
| 446 |
+
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
| 447 |
+
|
| 448 |
+
# Apparently, the conversion does not work anymore without this :shrug:
|
| 449 |
+
for param in text_encoder.parameters():
|
| 450 |
+
param.data = param.data.contiguous()
|
| 451 |
+
|
| 452 |
+
pipe = MochiPipeline(
|
| 453 |
+
scheduler=FlowMatchEulerDiscreteScheduler(invert_sigmas=True),
|
| 454 |
+
vae=vae,
|
| 455 |
+
text_encoder=text_encoder,
|
| 456 |
+
tokenizer=tokenizer,
|
| 457 |
+
transformer=transformer,
|
| 458 |
+
)
|
| 459 |
+
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
if __name__ == "__main__":
|
| 463 |
+
main(args)
|
diffusers/scripts/convert_models_diffuser_to_diffusers.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from diffusers import UNet1DModel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
|
| 10 |
+
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)
|
| 11 |
+
|
| 12 |
+
os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def unet(hor):
|
| 16 |
+
if hor == 128:
|
| 17 |
+
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
|
| 18 |
+
block_out_channels = (32, 128, 256)
|
| 19 |
+
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
|
| 20 |
+
|
| 21 |
+
elif hor == 32:
|
| 22 |
+
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
|
| 23 |
+
block_out_channels = (32, 64, 128, 256)
|
| 24 |
+
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
|
| 25 |
+
model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
|
| 26 |
+
state_dict = model.state_dict()
|
| 27 |
+
config = {
|
| 28 |
+
"down_block_types": down_block_types,
|
| 29 |
+
"block_out_channels": block_out_channels,
|
| 30 |
+
"up_block_types": up_block_types,
|
| 31 |
+
"layers_per_block": 1,
|
| 32 |
+
"use_timestep_embedding": True,
|
| 33 |
+
"out_block_type": "OutConv1DBlock",
|
| 34 |
+
"norm_num_groups": 8,
|
| 35 |
+
"downsample_each_block": False,
|
| 36 |
+
"in_channels": 14,
|
| 37 |
+
"out_channels": 14,
|
| 38 |
+
"extra_in_channels": 0,
|
| 39 |
+
"time_embedding_type": "positional",
|
| 40 |
+
"flip_sin_to_cos": False,
|
| 41 |
+
"freq_shift": 1,
|
| 42 |
+
"sample_size": 65536,
|
| 43 |
+
"mid_block_type": "MidResTemporalBlock1D",
|
| 44 |
+
"act_fn": "mish",
|
| 45 |
+
}
|
| 46 |
+
hf_value_function = UNet1DModel(**config)
|
| 47 |
+
print(f"length of state dict: {len(state_dict.keys())}")
|
| 48 |
+
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
|
| 49 |
+
mapping = dict(zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
|
| 50 |
+
for k, v in mapping.items():
|
| 51 |
+
state_dict[v] = state_dict.pop(k)
|
| 52 |
+
hf_value_function.load_state_dict(state_dict)
|
| 53 |
+
|
| 54 |
+
torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
|
| 55 |
+
with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
|
| 56 |
+
json.dump(config, f)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def value_function():
|
| 60 |
+
config = {
|
| 61 |
+
"in_channels": 14,
|
| 62 |
+
"down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
|
| 63 |
+
"up_block_types": (),
|
| 64 |
+
"out_block_type": "ValueFunction",
|
| 65 |
+
"mid_block_type": "ValueFunctionMidBlock1D",
|
| 66 |
+
"block_out_channels": (32, 64, 128, 256),
|
| 67 |
+
"layers_per_block": 1,
|
| 68 |
+
"downsample_each_block": True,
|
| 69 |
+
"sample_size": 65536,
|
| 70 |
+
"out_channels": 14,
|
| 71 |
+
"extra_in_channels": 0,
|
| 72 |
+
"time_embedding_type": "positional",
|
| 73 |
+
"use_timestep_embedding": True,
|
| 74 |
+
"flip_sin_to_cos": False,
|
| 75 |
+
"freq_shift": 1,
|
| 76 |
+
"norm_num_groups": 8,
|
| 77 |
+
"act_fn": "mish",
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
|
| 81 |
+
state_dict = model
|
| 82 |
+
hf_value_function = UNet1DModel(**config)
|
| 83 |
+
print(f"length of state dict: {len(state_dict.keys())}")
|
| 84 |
+
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
|
| 85 |
+
|
| 86 |
+
mapping = dict(zip(state_dict.keys(), hf_value_function.state_dict().keys()))
|
| 87 |
+
for k, v in mapping.items():
|
| 88 |
+
state_dict[v] = state_dict.pop(k)
|
| 89 |
+
|
| 90 |
+
hf_value_function.load_state_dict(state_dict)
|
| 91 |
+
|
| 92 |
+
torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
|
| 93 |
+
with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
|
| 94 |
+
json.dump(config, f)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
if __name__ == "__main__":
|
| 98 |
+
unet(32)
|
| 99 |
+
# unet(128)
|
| 100 |
+
value_function()
|
diffusers/scripts/convert_ms_text_to_video_to_diffusers.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Conversion script for the LDM checkpoints."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from diffusers import UNet3DConditionModel
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def assign_to_checkpoint(
|
| 25 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
| 26 |
+
):
|
| 27 |
+
"""
|
| 28 |
+
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
| 29 |
+
attention layers, and takes into account additional replacements that may arise.
|
| 30 |
+
|
| 31 |
+
Assigns the weights to the new checkpoint.
|
| 32 |
+
"""
|
| 33 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
| 34 |
+
|
| 35 |
+
# Splits the attention layers into three variables.
|
| 36 |
+
if attention_paths_to_split is not None:
|
| 37 |
+
for path, path_map in attention_paths_to_split.items():
|
| 38 |
+
old_tensor = old_checkpoint[path]
|
| 39 |
+
channels = old_tensor.shape[0] // 3
|
| 40 |
+
|
| 41 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
| 42 |
+
|
| 43 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
| 44 |
+
|
| 45 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
| 46 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
| 47 |
+
|
| 48 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
| 49 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
| 50 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
| 51 |
+
|
| 52 |
+
for path in paths:
|
| 53 |
+
new_path = path["new"]
|
| 54 |
+
|
| 55 |
+
# These have already been assigned
|
| 56 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
| 57 |
+
continue
|
| 58 |
+
|
| 59 |
+
if additional_replacements is not None:
|
| 60 |
+
for replacement in additional_replacements:
|
| 61 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
| 62 |
+
|
| 63 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
| 64 |
+
weight = old_checkpoint[path["old"]]
|
| 65 |
+
names = ["proj_attn.weight"]
|
| 66 |
+
names_2 = ["proj_out.weight", "proj_in.weight"]
|
| 67 |
+
if any(k in new_path for k in names):
|
| 68 |
+
checkpoint[new_path] = weight[:, :, 0]
|
| 69 |
+
elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path:
|
| 70 |
+
checkpoint[new_path] = weight[:, :, 0]
|
| 71 |
+
else:
|
| 72 |
+
checkpoint[new_path] = weight
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
| 76 |
+
"""
|
| 77 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 78 |
+
"""
|
| 79 |
+
mapping = []
|
| 80 |
+
for old_item in old_list:
|
| 81 |
+
new_item = old_item
|
| 82 |
+
|
| 83 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
| 84 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
| 85 |
+
|
| 86 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
| 87 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
| 88 |
+
|
| 89 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 90 |
+
|
| 91 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 92 |
+
|
| 93 |
+
return mapping
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
| 97 |
+
"""
|
| 98 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
| 99 |
+
"""
|
| 100 |
+
if n_shave_prefix_segments >= 0:
|
| 101 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
| 102 |
+
else:
|
| 103 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0):
|
| 107 |
+
"""
|
| 108 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 109 |
+
"""
|
| 110 |
+
mapping = []
|
| 111 |
+
for old_item in old_list:
|
| 112 |
+
mapping.append({"old": old_item, "new": old_item})
|
| 113 |
+
|
| 114 |
+
return mapping
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 118 |
+
"""
|
| 119 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 120 |
+
"""
|
| 121 |
+
mapping = []
|
| 122 |
+
for old_item in old_list:
|
| 123 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
| 124 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
| 125 |
+
|
| 126 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
| 127 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
| 128 |
+
|
| 129 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
| 130 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
| 131 |
+
|
| 132 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 133 |
+
|
| 134 |
+
if "temopral_conv" not in old_item:
|
| 135 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 136 |
+
|
| 137 |
+
return mapping
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
|
| 141 |
+
"""
|
| 142 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
# extract state_dict for UNet
|
| 146 |
+
unet_state_dict = {}
|
| 147 |
+
keys = list(checkpoint.keys())
|
| 148 |
+
|
| 149 |
+
unet_key = "model.diffusion_model."
|
| 150 |
+
|
| 151 |
+
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
| 152 |
+
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
| 153 |
+
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
| 154 |
+
print(
|
| 155 |
+
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
| 156 |
+
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
| 157 |
+
)
|
| 158 |
+
for key in keys:
|
| 159 |
+
if key.startswith("model.diffusion_model"):
|
| 160 |
+
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
| 161 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
| 162 |
+
else:
|
| 163 |
+
if sum(k.startswith("model_ema") for k in keys) > 100:
|
| 164 |
+
print(
|
| 165 |
+
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
| 166 |
+
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
for key in keys:
|
| 170 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
| 171 |
+
|
| 172 |
+
new_checkpoint = {}
|
| 173 |
+
|
| 174 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
| 175 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
| 176 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
| 177 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
| 178 |
+
|
| 179 |
+
if config["class_embed_type"] is None:
|
| 180 |
+
# No parameters to port
|
| 181 |
+
...
|
| 182 |
+
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
|
| 183 |
+
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
| 184 |
+
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
| 185 |
+
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
| 186 |
+
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
| 187 |
+
else:
|
| 188 |
+
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
|
| 189 |
+
|
| 190 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
| 191 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
| 192 |
+
|
| 193 |
+
first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")]
|
| 194 |
+
paths = renew_attention_paths(first_temp_attention)
|
| 195 |
+
meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"}
|
| 196 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
| 197 |
+
|
| 198 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
| 199 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
| 200 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
| 201 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
| 202 |
+
|
| 203 |
+
# Retrieves the keys for the input blocks only
|
| 204 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
| 205 |
+
input_blocks = {
|
| 206 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
| 207 |
+
for layer_id in range(num_input_blocks)
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
# Retrieves the keys for the middle blocks only
|
| 211 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
| 212 |
+
middle_blocks = {
|
| 213 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
| 214 |
+
for layer_id in range(num_middle_blocks)
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
# Retrieves the keys for the output blocks only
|
| 218 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
| 219 |
+
output_blocks = {
|
| 220 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
| 221 |
+
for layer_id in range(num_output_blocks)
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
for i in range(1, num_input_blocks):
|
| 225 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
| 226 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
| 227 |
+
|
| 228 |
+
resnets = [
|
| 229 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
| 230 |
+
]
|
| 231 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
| 232 |
+
temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key]
|
| 233 |
+
|
| 234 |
+
if f"input_blocks.{i}.op.weight" in unet_state_dict:
|
| 235 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
| 236 |
+
f"input_blocks.{i}.op.weight"
|
| 237 |
+
)
|
| 238 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
| 239 |
+
f"input_blocks.{i}.op.bias"
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
paths = renew_resnet_paths(resnets)
|
| 243 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 244 |
+
assign_to_checkpoint(
|
| 245 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
temporal_convs = [key for key in resnets if "temopral_conv" in key]
|
| 249 |
+
paths = renew_temp_conv_paths(temporal_convs)
|
| 250 |
+
meta_path = {
|
| 251 |
+
"old": f"input_blocks.{i}.0.temopral_conv",
|
| 252 |
+
"new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}",
|
| 253 |
+
}
|
| 254 |
+
assign_to_checkpoint(
|
| 255 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
if len(attentions):
|
| 259 |
+
paths = renew_attention_paths(attentions)
|
| 260 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
| 261 |
+
assign_to_checkpoint(
|
| 262 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if len(temp_attentions):
|
| 266 |
+
paths = renew_attention_paths(temp_attentions)
|
| 267 |
+
meta_path = {
|
| 268 |
+
"old": f"input_blocks.{i}.2",
|
| 269 |
+
"new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
|
| 270 |
+
}
|
| 271 |
+
assign_to_checkpoint(
|
| 272 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
resnet_0 = middle_blocks[0]
|
| 276 |
+
temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key]
|
| 277 |
+
attentions = middle_blocks[1]
|
| 278 |
+
temp_attentions = middle_blocks[2]
|
| 279 |
+
resnet_1 = middle_blocks[3]
|
| 280 |
+
temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key]
|
| 281 |
+
|
| 282 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
| 283 |
+
meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"}
|
| 284 |
+
assign_to_checkpoint(
|
| 285 |
+
resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0)
|
| 289 |
+
meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"}
|
| 290 |
+
assign_to_checkpoint(
|
| 291 |
+
temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
| 295 |
+
meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"}
|
| 296 |
+
assign_to_checkpoint(
|
| 297 |
+
resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1)
|
| 301 |
+
meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"}
|
| 302 |
+
assign_to_checkpoint(
|
| 303 |
+
temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
attentions_paths = renew_attention_paths(attentions)
|
| 307 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
| 308 |
+
assign_to_checkpoint(
|
| 309 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
temp_attentions_paths = renew_attention_paths(temp_attentions)
|
| 313 |
+
meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"}
|
| 314 |
+
assign_to_checkpoint(
|
| 315 |
+
temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
for i in range(num_output_blocks):
|
| 319 |
+
block_id = i // (config["layers_per_block"] + 1)
|
| 320 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
| 321 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
| 322 |
+
output_block_list = {}
|
| 323 |
+
|
| 324 |
+
for layer in output_block_layers:
|
| 325 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
| 326 |
+
if layer_id in output_block_list:
|
| 327 |
+
output_block_list[layer_id].append(layer_name)
|
| 328 |
+
else:
|
| 329 |
+
output_block_list[layer_id] = [layer_name]
|
| 330 |
+
|
| 331 |
+
if len(output_block_list) > 1:
|
| 332 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
| 333 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
| 334 |
+
temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
|
| 335 |
+
|
| 336 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
| 337 |
+
paths = renew_resnet_paths(resnets)
|
| 338 |
+
|
| 339 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 340 |
+
assign_to_checkpoint(
|
| 341 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
temporal_convs = [key for key in resnets if "temopral_conv" in key]
|
| 345 |
+
paths = renew_temp_conv_paths(temporal_convs)
|
| 346 |
+
meta_path = {
|
| 347 |
+
"old": f"output_blocks.{i}.0.temopral_conv",
|
| 348 |
+
"new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}",
|
| 349 |
+
}
|
| 350 |
+
assign_to_checkpoint(
|
| 351 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
| 355 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
| 356 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
| 357 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
| 358 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
| 359 |
+
]
|
| 360 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
| 361 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
| 362 |
+
]
|
| 363 |
+
|
| 364 |
+
# Clear attentions as they have been attributed above.
|
| 365 |
+
if len(attentions) == 2:
|
| 366 |
+
attentions = []
|
| 367 |
+
|
| 368 |
+
if len(attentions):
|
| 369 |
+
paths = renew_attention_paths(attentions)
|
| 370 |
+
meta_path = {
|
| 371 |
+
"old": f"output_blocks.{i}.1",
|
| 372 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
| 373 |
+
}
|
| 374 |
+
assign_to_checkpoint(
|
| 375 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
if len(temp_attentions):
|
| 379 |
+
paths = renew_attention_paths(temp_attentions)
|
| 380 |
+
meta_path = {
|
| 381 |
+
"old": f"output_blocks.{i}.2",
|
| 382 |
+
"new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
|
| 383 |
+
}
|
| 384 |
+
assign_to_checkpoint(
|
| 385 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 386 |
+
)
|
| 387 |
+
else:
|
| 388 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
| 389 |
+
for path in resnet_0_paths:
|
| 390 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
| 391 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
| 392 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
| 393 |
+
|
| 394 |
+
temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l]
|
| 395 |
+
for path in temopral_conv_paths:
|
| 396 |
+
pruned_path = path.split("temopral_conv.")[-1]
|
| 397 |
+
old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path])
|
| 398 |
+
new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path])
|
| 399 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
| 400 |
+
|
| 401 |
+
return new_checkpoint
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
if __name__ == "__main__":
|
| 405 |
+
parser = argparse.ArgumentParser()
|
| 406 |
+
|
| 407 |
+
parser.add_argument(
|
| 408 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
| 409 |
+
)
|
| 410 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 411 |
+
args = parser.parse_args()
|
| 412 |
+
|
| 413 |
+
unet_checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
|
| 414 |
+
unet = UNet3DConditionModel()
|
| 415 |
+
|
| 416 |
+
converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config)
|
| 417 |
+
|
| 418 |
+
diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys())
|
| 419 |
+
diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys())
|
| 420 |
+
|
| 421 |
+
assert len(diff_0) == len(diff_1) == 0, "Converted weights don't match"
|
| 422 |
+
|
| 423 |
+
# load state_dict
|
| 424 |
+
unet.load_state_dict(converted_ckpt)
|
| 425 |
+
|
| 426 |
+
unet.save_pretrained(args.dump_path)
|
| 427 |
+
|
| 428 |
+
# -- finish converting the unet --
|
diffusers/scripts/convert_music_spectrogram_to_diffusers.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import jax as jnp
|
| 6 |
+
import numpy as onp
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from music_spectrogram_diffusion import inference
|
| 10 |
+
from t5x import checkpoints
|
| 11 |
+
|
| 12 |
+
from diffusers import DDPMScheduler, OnnxRuntimeModel, SpectrogramDiffusionPipeline
|
| 13 |
+
from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
MODEL = "base_with_context"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_notes_encoder(weights, model):
|
| 20 |
+
model.token_embedder.weight = nn.Parameter(torch.Tensor(weights["token_embedder"]["embedding"]))
|
| 21 |
+
model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
|
| 22 |
+
for lyr_num, lyr in enumerate(model.encoders):
|
| 23 |
+
ly_weight = weights[f"layers_{lyr_num}"]
|
| 24 |
+
lyr.layer[0].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_attention_layer_norm"]["scale"]))
|
| 25 |
+
|
| 26 |
+
attention_weights = ly_weight["attention"]
|
| 27 |
+
lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
|
| 28 |
+
lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
|
| 29 |
+
lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
|
| 30 |
+
lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
|
| 31 |
+
|
| 32 |
+
lyr.layer[1].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
|
| 33 |
+
|
| 34 |
+
lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
|
| 35 |
+
lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
|
| 36 |
+
lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
|
| 37 |
+
|
| 38 |
+
model.layer_norm.weight = nn.Parameter(torch.Tensor(weights["encoder_norm"]["scale"]))
|
| 39 |
+
return model
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def load_continuous_encoder(weights, model):
|
| 43 |
+
model.input_proj.weight = nn.Parameter(torch.Tensor(weights["input_proj"]["kernel"].T))
|
| 44 |
+
|
| 45 |
+
model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
|
| 46 |
+
|
| 47 |
+
for lyr_num, lyr in enumerate(model.encoders):
|
| 48 |
+
ly_weight = weights[f"layers_{lyr_num}"]
|
| 49 |
+
attention_weights = ly_weight["attention"]
|
| 50 |
+
|
| 51 |
+
lyr.layer[0].SelfAttention.q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
|
| 52 |
+
lyr.layer[0].SelfAttention.k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
|
| 53 |
+
lyr.layer[0].SelfAttention.v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
|
| 54 |
+
lyr.layer[0].SelfAttention.o.weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
|
| 55 |
+
lyr.layer[0].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_attention_layer_norm"]["scale"]))
|
| 56 |
+
|
| 57 |
+
lyr.layer[1].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
|
| 58 |
+
lyr.layer[1].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
|
| 59 |
+
lyr.layer[1].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
|
| 60 |
+
lyr.layer[1].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
|
| 61 |
+
|
| 62 |
+
model.layer_norm.weight = nn.Parameter(torch.Tensor(weights["encoder_norm"]["scale"]))
|
| 63 |
+
|
| 64 |
+
return model
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def load_decoder(weights, model):
|
| 68 |
+
model.conditioning_emb[0].weight = nn.Parameter(torch.Tensor(weights["time_emb_dense0"]["kernel"].T))
|
| 69 |
+
model.conditioning_emb[2].weight = nn.Parameter(torch.Tensor(weights["time_emb_dense1"]["kernel"].T))
|
| 70 |
+
|
| 71 |
+
model.position_encoding.weight = nn.Parameter(torch.Tensor(weights["Embed_0"]["embedding"]), requires_grad=False)
|
| 72 |
+
|
| 73 |
+
model.continuous_inputs_projection.weight = nn.Parameter(
|
| 74 |
+
torch.Tensor(weights["continuous_inputs_projection"]["kernel"].T)
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
for lyr_num, lyr in enumerate(model.decoders):
|
| 78 |
+
ly_weight = weights[f"layers_{lyr_num}"]
|
| 79 |
+
lyr.layer[0].layer_norm.weight = nn.Parameter(
|
| 80 |
+
torch.Tensor(ly_weight["pre_self_attention_layer_norm"]["scale"])
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
lyr.layer[0].FiLMLayer.scale_bias.weight = nn.Parameter(
|
| 84 |
+
torch.Tensor(ly_weight["FiLMLayer_0"]["DenseGeneral_0"]["kernel"].T)
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
attention_weights = ly_weight["self_attention"]
|
| 88 |
+
lyr.layer[0].attention.to_q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
|
| 89 |
+
lyr.layer[0].attention.to_k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
|
| 90 |
+
lyr.layer[0].attention.to_v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
|
| 91 |
+
lyr.layer[0].attention.to_out[0].weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
|
| 92 |
+
|
| 93 |
+
attention_weights = ly_weight["MultiHeadDotProductAttention_0"]
|
| 94 |
+
lyr.layer[1].attention.to_q.weight = nn.Parameter(torch.Tensor(attention_weights["query"]["kernel"].T))
|
| 95 |
+
lyr.layer[1].attention.to_k.weight = nn.Parameter(torch.Tensor(attention_weights["key"]["kernel"].T))
|
| 96 |
+
lyr.layer[1].attention.to_v.weight = nn.Parameter(torch.Tensor(attention_weights["value"]["kernel"].T))
|
| 97 |
+
lyr.layer[1].attention.to_out[0].weight = nn.Parameter(torch.Tensor(attention_weights["out"]["kernel"].T))
|
| 98 |
+
lyr.layer[1].layer_norm.weight = nn.Parameter(
|
| 99 |
+
torch.Tensor(ly_weight["pre_cross_attention_layer_norm"]["scale"])
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
lyr.layer[2].layer_norm.weight = nn.Parameter(torch.Tensor(ly_weight["pre_mlp_layer_norm"]["scale"]))
|
| 103 |
+
lyr.layer[2].film.scale_bias.weight = nn.Parameter(
|
| 104 |
+
torch.Tensor(ly_weight["FiLMLayer_1"]["DenseGeneral_0"]["kernel"].T)
|
| 105 |
+
)
|
| 106 |
+
lyr.layer[2].DenseReluDense.wi_0.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_0"]["kernel"].T))
|
| 107 |
+
lyr.layer[2].DenseReluDense.wi_1.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wi_1"]["kernel"].T))
|
| 108 |
+
lyr.layer[2].DenseReluDense.wo.weight = nn.Parameter(torch.Tensor(ly_weight["mlp"]["wo"]["kernel"].T))
|
| 109 |
+
|
| 110 |
+
model.decoder_norm.weight = nn.Parameter(torch.Tensor(weights["decoder_norm"]["scale"]))
|
| 111 |
+
|
| 112 |
+
model.spec_out.weight = nn.Parameter(torch.Tensor(weights["spec_out_dense"]["kernel"].T))
|
| 113 |
+
|
| 114 |
+
return model
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def main(args):
|
| 118 |
+
t5_checkpoint = checkpoints.load_t5x_checkpoint(args.checkpoint_path)
|
| 119 |
+
t5_checkpoint = jnp.tree_util.tree_map(onp.array, t5_checkpoint)
|
| 120 |
+
|
| 121 |
+
gin_overrides = [
|
| 122 |
+
"from __gin__ import dynamic_registration",
|
| 123 |
+
"from music_spectrogram_diffusion.models.diffusion import diffusion_utils",
|
| 124 |
+
"diffusion_utils.ClassifierFreeGuidanceConfig.eval_condition_weight = 2.0",
|
| 125 |
+
"diffusion_utils.DiffusionConfig.classifier_free_guidance = @diffusion_utils.ClassifierFreeGuidanceConfig()",
|
| 126 |
+
]
|
| 127 |
+
|
| 128 |
+
gin_file = os.path.join(args.checkpoint_path, "..", "config.gin")
|
| 129 |
+
gin_config = inference.parse_training_gin_file(gin_file, gin_overrides)
|
| 130 |
+
synth_model = inference.InferenceModel(args.checkpoint_path, gin_config)
|
| 131 |
+
|
| 132 |
+
scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", variance_type="fixed_large")
|
| 133 |
+
|
| 134 |
+
notes_encoder = SpectrogramNotesEncoder(
|
| 135 |
+
max_length=synth_model.sequence_length["inputs"],
|
| 136 |
+
vocab_size=synth_model.model.module.config.vocab_size,
|
| 137 |
+
d_model=synth_model.model.module.config.emb_dim,
|
| 138 |
+
dropout_rate=synth_model.model.module.config.dropout_rate,
|
| 139 |
+
num_layers=synth_model.model.module.config.num_encoder_layers,
|
| 140 |
+
num_heads=synth_model.model.module.config.num_heads,
|
| 141 |
+
d_kv=synth_model.model.module.config.head_dim,
|
| 142 |
+
d_ff=synth_model.model.module.config.mlp_dim,
|
| 143 |
+
feed_forward_proj="gated-gelu",
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
continuous_encoder = SpectrogramContEncoder(
|
| 147 |
+
input_dims=synth_model.audio_codec.n_dims,
|
| 148 |
+
targets_context_length=synth_model.sequence_length["targets_context"],
|
| 149 |
+
d_model=synth_model.model.module.config.emb_dim,
|
| 150 |
+
dropout_rate=synth_model.model.module.config.dropout_rate,
|
| 151 |
+
num_layers=synth_model.model.module.config.num_encoder_layers,
|
| 152 |
+
num_heads=synth_model.model.module.config.num_heads,
|
| 153 |
+
d_kv=synth_model.model.module.config.head_dim,
|
| 154 |
+
d_ff=synth_model.model.module.config.mlp_dim,
|
| 155 |
+
feed_forward_proj="gated-gelu",
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
decoder = T5FilmDecoder(
|
| 159 |
+
input_dims=synth_model.audio_codec.n_dims,
|
| 160 |
+
targets_length=synth_model.sequence_length["targets_context"],
|
| 161 |
+
max_decoder_noise_time=synth_model.model.module.config.max_decoder_noise_time,
|
| 162 |
+
d_model=synth_model.model.module.config.emb_dim,
|
| 163 |
+
num_layers=synth_model.model.module.config.num_decoder_layers,
|
| 164 |
+
num_heads=synth_model.model.module.config.num_heads,
|
| 165 |
+
d_kv=synth_model.model.module.config.head_dim,
|
| 166 |
+
d_ff=synth_model.model.module.config.mlp_dim,
|
| 167 |
+
dropout_rate=synth_model.model.module.config.dropout_rate,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
notes_encoder = load_notes_encoder(t5_checkpoint["target"]["token_encoder"], notes_encoder)
|
| 171 |
+
continuous_encoder = load_continuous_encoder(t5_checkpoint["target"]["continuous_encoder"], continuous_encoder)
|
| 172 |
+
decoder = load_decoder(t5_checkpoint["target"]["decoder"], decoder)
|
| 173 |
+
|
| 174 |
+
melgan = OnnxRuntimeModel.from_pretrained("kashif/soundstream_mel_decoder")
|
| 175 |
+
|
| 176 |
+
pipe = SpectrogramDiffusionPipeline(
|
| 177 |
+
notes_encoder=notes_encoder,
|
| 178 |
+
continuous_encoder=continuous_encoder,
|
| 179 |
+
decoder=decoder,
|
| 180 |
+
scheduler=scheduler,
|
| 181 |
+
melgan=melgan,
|
| 182 |
+
)
|
| 183 |
+
if args.save:
|
| 184 |
+
pipe.save_pretrained(args.output_path)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
if __name__ == "__main__":
|
| 188 |
+
parser = argparse.ArgumentParser()
|
| 189 |
+
|
| 190 |
+
parser.add_argument("--output_path", default=None, type=str, required=True, help="Path to the converted model.")
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--save", default=True, type=bool, required=False, help="Whether to save the converted model or not."
|
| 193 |
+
)
|
| 194 |
+
parser.add_argument(
|
| 195 |
+
"--checkpoint_path",
|
| 196 |
+
default=f"{MODEL}/checkpoint_500000",
|
| 197 |
+
type=str,
|
| 198 |
+
required=False,
|
| 199 |
+
help="Path to the original jax model checkpoint.",
|
| 200 |
+
)
|
| 201 |
+
args = parser.parse_args()
|
| 202 |
+
|
| 203 |
+
main(args)
|
diffusers/scripts/convert_original_audioldm_to_diffusers.py
ADDED
|
@@ -0,0 +1,1042 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Conversion script for the AudioLDM checkpoints."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import re
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import yaml
|
| 22 |
+
from transformers import (
|
| 23 |
+
AutoTokenizer,
|
| 24 |
+
ClapTextConfig,
|
| 25 |
+
ClapTextModelWithProjection,
|
| 26 |
+
SpeechT5HifiGan,
|
| 27 |
+
SpeechT5HifiGanConfig,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
from diffusers import (
|
| 31 |
+
AudioLDMPipeline,
|
| 32 |
+
AutoencoderKL,
|
| 33 |
+
DDIMScheduler,
|
| 34 |
+
DPMSolverMultistepScheduler,
|
| 35 |
+
EulerAncestralDiscreteScheduler,
|
| 36 |
+
EulerDiscreteScheduler,
|
| 37 |
+
HeunDiscreteScheduler,
|
| 38 |
+
LMSDiscreteScheduler,
|
| 39 |
+
PNDMScheduler,
|
| 40 |
+
UNet2DConditionModel,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
|
| 45 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
| 46 |
+
"""
|
| 47 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
| 48 |
+
"""
|
| 49 |
+
if n_shave_prefix_segments >= 0:
|
| 50 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
| 51 |
+
else:
|
| 52 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths
|
| 56 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 57 |
+
"""
|
| 58 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 59 |
+
"""
|
| 60 |
+
mapping = []
|
| 61 |
+
for old_item in old_list:
|
| 62 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
| 63 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
| 64 |
+
|
| 65 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
| 66 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
| 67 |
+
|
| 68 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
| 69 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
| 70 |
+
|
| 71 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 72 |
+
|
| 73 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 74 |
+
|
| 75 |
+
return mapping
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths
|
| 79 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 80 |
+
"""
|
| 81 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 82 |
+
"""
|
| 83 |
+
mapping = []
|
| 84 |
+
for old_item in old_list:
|
| 85 |
+
new_item = old_item
|
| 86 |
+
|
| 87 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
| 88 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 89 |
+
|
| 90 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 91 |
+
|
| 92 |
+
return mapping
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths
|
| 96 |
+
def renew_attention_paths(old_list):
|
| 97 |
+
"""
|
| 98 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 99 |
+
"""
|
| 100 |
+
mapping = []
|
| 101 |
+
for old_item in old_list:
|
| 102 |
+
new_item = old_item
|
| 103 |
+
|
| 104 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
| 105 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
| 106 |
+
|
| 107 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
| 108 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
| 109 |
+
|
| 110 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 111 |
+
|
| 112 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 113 |
+
|
| 114 |
+
return mapping
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths
|
| 118 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
| 119 |
+
"""
|
| 120 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 121 |
+
"""
|
| 122 |
+
mapping = []
|
| 123 |
+
for old_item in old_list:
|
| 124 |
+
new_item = old_item
|
| 125 |
+
|
| 126 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
| 127 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
| 128 |
+
|
| 129 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
| 130 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
| 131 |
+
|
| 132 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
| 133 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
| 134 |
+
|
| 135 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
| 136 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
| 137 |
+
|
| 138 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
| 139 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
| 140 |
+
|
| 141 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 142 |
+
|
| 143 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 144 |
+
|
| 145 |
+
return mapping
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
|
| 149 |
+
def assign_to_checkpoint(
|
| 150 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
| 151 |
+
):
|
| 152 |
+
"""
|
| 153 |
+
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
| 154 |
+
attention layers, and takes into account additional replacements that may arise.
|
| 155 |
+
|
| 156 |
+
Assigns the weights to the new checkpoint.
|
| 157 |
+
"""
|
| 158 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
| 159 |
+
|
| 160 |
+
# Splits the attention layers into three variables.
|
| 161 |
+
if attention_paths_to_split is not None:
|
| 162 |
+
for path, path_map in attention_paths_to_split.items():
|
| 163 |
+
old_tensor = old_checkpoint[path]
|
| 164 |
+
channels = old_tensor.shape[0] // 3
|
| 165 |
+
|
| 166 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
| 167 |
+
|
| 168 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
| 169 |
+
|
| 170 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
| 171 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
| 172 |
+
|
| 173 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
| 174 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
| 175 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
| 176 |
+
|
| 177 |
+
for path in paths:
|
| 178 |
+
new_path = path["new"]
|
| 179 |
+
|
| 180 |
+
# These have already been assigned
|
| 181 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
| 182 |
+
continue
|
| 183 |
+
|
| 184 |
+
# Global renaming happens here
|
| 185 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
| 186 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
| 187 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
| 188 |
+
|
| 189 |
+
if additional_replacements is not None:
|
| 190 |
+
for replacement in additional_replacements:
|
| 191 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
| 192 |
+
|
| 193 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
| 194 |
+
if "proj_attn.weight" in new_path:
|
| 195 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
| 196 |
+
else:
|
| 197 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
|
| 201 |
+
def conv_attn_to_linear(checkpoint):
|
| 202 |
+
keys = list(checkpoint.keys())
|
| 203 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
| 204 |
+
for key in keys:
|
| 205 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
| 206 |
+
if checkpoint[key].ndim > 2:
|
| 207 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
| 208 |
+
elif "proj_attn.weight" in key:
|
| 209 |
+
if checkpoint[key].ndim > 2:
|
| 210 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def create_unet_diffusers_config(original_config, image_size: int):
|
| 214 |
+
"""
|
| 215 |
+
Creates a UNet config for diffusers based on the config of the original AudioLDM model.
|
| 216 |
+
"""
|
| 217 |
+
unet_params = original_config["model"]["params"]["unet_config"]["params"]
|
| 218 |
+
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
| 219 |
+
|
| 220 |
+
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
|
| 221 |
+
|
| 222 |
+
down_block_types = []
|
| 223 |
+
resolution = 1
|
| 224 |
+
for i in range(len(block_out_channels)):
|
| 225 |
+
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
|
| 226 |
+
down_block_types.append(block_type)
|
| 227 |
+
if i != len(block_out_channels) - 1:
|
| 228 |
+
resolution *= 2
|
| 229 |
+
|
| 230 |
+
up_block_types = []
|
| 231 |
+
for i in range(len(block_out_channels)):
|
| 232 |
+
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
|
| 233 |
+
up_block_types.append(block_type)
|
| 234 |
+
resolution //= 2
|
| 235 |
+
|
| 236 |
+
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
|
| 237 |
+
|
| 238 |
+
cross_attention_dim = (
|
| 239 |
+
unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None
|
| 243 |
+
projection_class_embeddings_input_dim = (
|
| 244 |
+
unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None
|
| 245 |
+
)
|
| 246 |
+
class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None
|
| 247 |
+
|
| 248 |
+
config = {
|
| 249 |
+
"sample_size": image_size // vae_scale_factor,
|
| 250 |
+
"in_channels": unet_params["in_channels"],
|
| 251 |
+
"out_channels": unet_params["out_channels"],
|
| 252 |
+
"down_block_types": tuple(down_block_types),
|
| 253 |
+
"up_block_types": tuple(up_block_types),
|
| 254 |
+
"block_out_channels": tuple(block_out_channels),
|
| 255 |
+
"layers_per_block": unet_params["num_res_blocks"],
|
| 256 |
+
"cross_attention_dim": cross_attention_dim,
|
| 257 |
+
"class_embed_type": class_embed_type,
|
| 258 |
+
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
| 259 |
+
"class_embeddings_concat": class_embeddings_concat,
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
return config
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config
|
| 266 |
+
def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
|
| 267 |
+
"""
|
| 268 |
+
Creates a VAE config for diffusers based on the config of the original AudioLDM model. Compared to the original
|
| 269 |
+
Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
|
| 270 |
+
"""
|
| 271 |
+
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
| 272 |
+
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
|
| 273 |
+
|
| 274 |
+
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
| 275 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
| 276 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
| 277 |
+
|
| 278 |
+
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
|
| 279 |
+
|
| 280 |
+
config = {
|
| 281 |
+
"sample_size": image_size,
|
| 282 |
+
"in_channels": vae_params["in_channels"],
|
| 283 |
+
"out_channels": vae_params["out_ch"],
|
| 284 |
+
"down_block_types": tuple(down_block_types),
|
| 285 |
+
"up_block_types": tuple(up_block_types),
|
| 286 |
+
"block_out_channels": tuple(block_out_channels),
|
| 287 |
+
"latent_channels": vae_params["z_channels"],
|
| 288 |
+
"layers_per_block": vae_params["num_res_blocks"],
|
| 289 |
+
"scaling_factor": float(scaling_factor),
|
| 290 |
+
}
|
| 291 |
+
return config
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
|
| 295 |
+
def create_diffusers_schedular(original_config):
|
| 296 |
+
schedular = DDIMScheduler(
|
| 297 |
+
num_train_timesteps=original_config["model"]["params"]["timesteps"],
|
| 298 |
+
beta_start=original_config["model"]["params"]["linear_start"],
|
| 299 |
+
beta_end=original_config["model"]["params"]["linear_end"],
|
| 300 |
+
beta_schedule="scaled_linear",
|
| 301 |
+
)
|
| 302 |
+
return schedular
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_unet_checkpoint
|
| 306 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
|
| 307 |
+
"""
|
| 308 |
+
Takes a state dict and a config, and returns a converted checkpoint. Compared to the original Stable Diffusion
|
| 309 |
+
conversion, this function additionally converts the learnt film embedding linear layer.
|
| 310 |
+
"""
|
| 311 |
+
|
| 312 |
+
# extract state_dict for UNet
|
| 313 |
+
unet_state_dict = {}
|
| 314 |
+
keys = list(checkpoint.keys())
|
| 315 |
+
|
| 316 |
+
unet_key = "model.diffusion_model."
|
| 317 |
+
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
| 318 |
+
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
| 319 |
+
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
| 320 |
+
print(
|
| 321 |
+
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
| 322 |
+
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
| 323 |
+
)
|
| 324 |
+
for key in keys:
|
| 325 |
+
if key.startswith("model.diffusion_model"):
|
| 326 |
+
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
| 327 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
| 328 |
+
else:
|
| 329 |
+
if sum(k.startswith("model_ema") for k in keys) > 100:
|
| 330 |
+
print(
|
| 331 |
+
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
| 332 |
+
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
for key in keys:
|
| 336 |
+
if key.startswith(unet_key):
|
| 337 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
| 338 |
+
|
| 339 |
+
new_checkpoint = {}
|
| 340 |
+
|
| 341 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
| 342 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
| 343 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
| 344 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
| 345 |
+
|
| 346 |
+
new_checkpoint["class_embedding.weight"] = unet_state_dict["film_emb.weight"]
|
| 347 |
+
new_checkpoint["class_embedding.bias"] = unet_state_dict["film_emb.bias"]
|
| 348 |
+
|
| 349 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
| 350 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
| 351 |
+
|
| 352 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
| 353 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
| 354 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
| 355 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
| 356 |
+
|
| 357 |
+
# Retrieves the keys for the input blocks only
|
| 358 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
| 359 |
+
input_blocks = {
|
| 360 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
| 361 |
+
for layer_id in range(num_input_blocks)
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
# Retrieves the keys for the middle blocks only
|
| 365 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
| 366 |
+
middle_blocks = {
|
| 367 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
| 368 |
+
for layer_id in range(num_middle_blocks)
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
# Retrieves the keys for the output blocks only
|
| 372 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
| 373 |
+
output_blocks = {
|
| 374 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
| 375 |
+
for layer_id in range(num_output_blocks)
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
for i in range(1, num_input_blocks):
|
| 379 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
| 380 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
| 381 |
+
|
| 382 |
+
resnets = [
|
| 383 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
| 384 |
+
]
|
| 385 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
| 386 |
+
|
| 387 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
| 388 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
| 389 |
+
f"input_blocks.{i}.0.op.weight"
|
| 390 |
+
)
|
| 391 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
| 392 |
+
f"input_blocks.{i}.0.op.bias"
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
paths = renew_resnet_paths(resnets)
|
| 396 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 397 |
+
assign_to_checkpoint(
|
| 398 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
if len(attentions):
|
| 402 |
+
paths = renew_attention_paths(attentions)
|
| 403 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
| 404 |
+
assign_to_checkpoint(
|
| 405 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
resnet_0 = middle_blocks[0]
|
| 409 |
+
attentions = middle_blocks[1]
|
| 410 |
+
resnet_1 = middle_blocks[2]
|
| 411 |
+
|
| 412 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
| 413 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
| 414 |
+
|
| 415 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
| 416 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
| 417 |
+
|
| 418 |
+
attentions_paths = renew_attention_paths(attentions)
|
| 419 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
| 420 |
+
assign_to_checkpoint(
|
| 421 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
for i in range(num_output_blocks):
|
| 425 |
+
block_id = i // (config["layers_per_block"] + 1)
|
| 426 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
| 427 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
| 428 |
+
output_block_list = {}
|
| 429 |
+
|
| 430 |
+
for layer in output_block_layers:
|
| 431 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
| 432 |
+
if layer_id in output_block_list:
|
| 433 |
+
output_block_list[layer_id].append(layer_name)
|
| 434 |
+
else:
|
| 435 |
+
output_block_list[layer_id] = [layer_name]
|
| 436 |
+
|
| 437 |
+
if len(output_block_list) > 1:
|
| 438 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
| 439 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
| 440 |
+
|
| 441 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
| 442 |
+
paths = renew_resnet_paths(resnets)
|
| 443 |
+
|
| 444 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 445 |
+
assign_to_checkpoint(
|
| 446 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
| 450 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
| 451 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
| 452 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
| 453 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
| 454 |
+
]
|
| 455 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
| 456 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
| 457 |
+
]
|
| 458 |
+
|
| 459 |
+
# Clear attentions as they have been attributed above.
|
| 460 |
+
if len(attentions) == 2:
|
| 461 |
+
attentions = []
|
| 462 |
+
|
| 463 |
+
if len(attentions):
|
| 464 |
+
paths = renew_attention_paths(attentions)
|
| 465 |
+
meta_path = {
|
| 466 |
+
"old": f"output_blocks.{i}.1",
|
| 467 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
| 468 |
+
}
|
| 469 |
+
assign_to_checkpoint(
|
| 470 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
| 471 |
+
)
|
| 472 |
+
else:
|
| 473 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
| 474 |
+
for path in resnet_0_paths:
|
| 475 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
| 476 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
| 477 |
+
|
| 478 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
| 479 |
+
|
| 480 |
+
return new_checkpoint
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
|
| 484 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
| 485 |
+
# extract state dict for VAE
|
| 486 |
+
vae_state_dict = {}
|
| 487 |
+
vae_key = "first_stage_model."
|
| 488 |
+
keys = list(checkpoint.keys())
|
| 489 |
+
for key in keys:
|
| 490 |
+
if key.startswith(vae_key):
|
| 491 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
| 492 |
+
|
| 493 |
+
new_checkpoint = {}
|
| 494 |
+
|
| 495 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
| 496 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
| 497 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
| 498 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
| 499 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
| 500 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
| 501 |
+
|
| 502 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
| 503 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
| 504 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
| 505 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
| 506 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
| 507 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
| 508 |
+
|
| 509 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
| 510 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
| 511 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
| 512 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
| 513 |
+
|
| 514 |
+
# Retrieves the keys for the encoder down blocks only
|
| 515 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
| 516 |
+
down_blocks = {
|
| 517 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
# Retrieves the keys for the decoder up blocks only
|
| 521 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
| 522 |
+
up_blocks = {
|
| 523 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
for i in range(num_down_blocks):
|
| 527 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
| 528 |
+
|
| 529 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
| 530 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
| 531 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
| 532 |
+
)
|
| 533 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
| 534 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 538 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
| 539 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 540 |
+
|
| 541 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
| 542 |
+
num_mid_res_blocks = 2
|
| 543 |
+
for i in range(1, num_mid_res_blocks + 1):
|
| 544 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
| 545 |
+
|
| 546 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 547 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
| 548 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 549 |
+
|
| 550 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
| 551 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
| 552 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
| 553 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 554 |
+
conv_attn_to_linear(new_checkpoint)
|
| 555 |
+
|
| 556 |
+
for i in range(num_up_blocks):
|
| 557 |
+
block_id = num_up_blocks - 1 - i
|
| 558 |
+
resnets = [
|
| 559 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
| 560 |
+
]
|
| 561 |
+
|
| 562 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
| 563 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
| 564 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
| 565 |
+
]
|
| 566 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
| 567 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
| 568 |
+
]
|
| 569 |
+
|
| 570 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 571 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
| 572 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 573 |
+
|
| 574 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
| 575 |
+
num_mid_res_blocks = 2
|
| 576 |
+
for i in range(1, num_mid_res_blocks + 1):
|
| 577 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
| 578 |
+
|
| 579 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 580 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
| 581 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 582 |
+
|
| 583 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
| 584 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
| 585 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
| 586 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 587 |
+
conv_attn_to_linear(new_checkpoint)
|
| 588 |
+
return new_checkpoint
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
CLAP_KEYS_TO_MODIFY_MAPPING = {
|
| 592 |
+
"text_branch": "text_model",
|
| 593 |
+
"attn": "attention.self",
|
| 594 |
+
"self.proj": "output.dense",
|
| 595 |
+
"attention.self_mask": "attn_mask",
|
| 596 |
+
"mlp.fc1": "intermediate.dense",
|
| 597 |
+
"mlp.fc2": "output.dense",
|
| 598 |
+
"norm1": "layernorm_before",
|
| 599 |
+
"norm2": "layernorm_after",
|
| 600 |
+
"bn0": "batch_norm",
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
CLAP_KEYS_TO_IGNORE = ["text_transform"]
|
| 604 |
+
|
| 605 |
+
CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"]
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def convert_open_clap_checkpoint(checkpoint):
|
| 609 |
+
"""
|
| 610 |
+
Takes a state dict and returns a converted CLAP checkpoint.
|
| 611 |
+
"""
|
| 612 |
+
# extract state dict for CLAP text embedding model, discarding the audio component
|
| 613 |
+
model_state_dict = {}
|
| 614 |
+
model_key = "cond_stage_model.model.text_"
|
| 615 |
+
keys = list(checkpoint.keys())
|
| 616 |
+
for key in keys:
|
| 617 |
+
if key.startswith(model_key):
|
| 618 |
+
model_state_dict[key.replace(model_key, "text_")] = checkpoint.get(key)
|
| 619 |
+
|
| 620 |
+
new_checkpoint = {}
|
| 621 |
+
|
| 622 |
+
sequential_layers_pattern = r".*sequential.(\d+).*"
|
| 623 |
+
text_projection_pattern = r".*_projection.(\d+).*"
|
| 624 |
+
|
| 625 |
+
for key, value in model_state_dict.items():
|
| 626 |
+
# check if key should be ignored in mapping
|
| 627 |
+
if key.split(".")[0] in CLAP_KEYS_TO_IGNORE:
|
| 628 |
+
continue
|
| 629 |
+
|
| 630 |
+
# check if any key needs to be modified
|
| 631 |
+
for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items():
|
| 632 |
+
if key_to_modify in key:
|
| 633 |
+
key = key.replace(key_to_modify, new_key)
|
| 634 |
+
|
| 635 |
+
if re.match(sequential_layers_pattern, key):
|
| 636 |
+
# replace sequential layers with list
|
| 637 |
+
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
| 638 |
+
|
| 639 |
+
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
|
| 640 |
+
elif re.match(text_projection_pattern, key):
|
| 641 |
+
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
| 642 |
+
|
| 643 |
+
# Because in CLAP they use `nn.Sequential`...
|
| 644 |
+
transformers_projection_layer = 1 if projecton_layer == 0 else 2
|
| 645 |
+
|
| 646 |
+
key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")
|
| 647 |
+
|
| 648 |
+
if "audio" and "qkv" in key:
|
| 649 |
+
# split qkv into query key and value
|
| 650 |
+
mixed_qkv = value
|
| 651 |
+
qkv_dim = mixed_qkv.size(0) // 3
|
| 652 |
+
|
| 653 |
+
query_layer = mixed_qkv[:qkv_dim]
|
| 654 |
+
key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
|
| 655 |
+
value_layer = mixed_qkv[qkv_dim * 2 :]
|
| 656 |
+
|
| 657 |
+
new_checkpoint[key.replace("qkv", "query")] = query_layer
|
| 658 |
+
new_checkpoint[key.replace("qkv", "key")] = key_layer
|
| 659 |
+
new_checkpoint[key.replace("qkv", "value")] = value_layer
|
| 660 |
+
else:
|
| 661 |
+
new_checkpoint[key] = value
|
| 662 |
+
|
| 663 |
+
return new_checkpoint
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
def create_transformers_vocoder_config(original_config):
|
| 667 |
+
"""
|
| 668 |
+
Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
|
| 669 |
+
"""
|
| 670 |
+
vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]
|
| 671 |
+
|
| 672 |
+
config = {
|
| 673 |
+
"model_in_dim": vocoder_params["num_mels"],
|
| 674 |
+
"sampling_rate": vocoder_params["sampling_rate"],
|
| 675 |
+
"upsample_initial_channel": vocoder_params["upsample_initial_channel"],
|
| 676 |
+
"upsample_rates": list(vocoder_params["upsample_rates"]),
|
| 677 |
+
"upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
|
| 678 |
+
"resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
|
| 679 |
+
"resblock_dilation_sizes": [
|
| 680 |
+
list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"]
|
| 681 |
+
],
|
| 682 |
+
"normalize_before": False,
|
| 683 |
+
}
|
| 684 |
+
|
| 685 |
+
return config
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
def convert_hifigan_checkpoint(checkpoint, config):
|
| 689 |
+
"""
|
| 690 |
+
Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint.
|
| 691 |
+
"""
|
| 692 |
+
# extract state dict for vocoder
|
| 693 |
+
vocoder_state_dict = {}
|
| 694 |
+
vocoder_key = "first_stage_model.vocoder."
|
| 695 |
+
keys = list(checkpoint.keys())
|
| 696 |
+
for key in keys:
|
| 697 |
+
if key.startswith(vocoder_key):
|
| 698 |
+
vocoder_state_dict[key.replace(vocoder_key, "")] = checkpoint.get(key)
|
| 699 |
+
|
| 700 |
+
# fix upsampler keys, everything else is correct already
|
| 701 |
+
for i in range(len(config.upsample_rates)):
|
| 702 |
+
vocoder_state_dict[f"upsampler.{i}.weight"] = vocoder_state_dict.pop(f"ups.{i}.weight")
|
| 703 |
+
vocoder_state_dict[f"upsampler.{i}.bias"] = vocoder_state_dict.pop(f"ups.{i}.bias")
|
| 704 |
+
|
| 705 |
+
if not config.normalize_before:
|
| 706 |
+
# if we don't set normalize_before then these variables are unused, so we set them to their initialised values
|
| 707 |
+
vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim)
|
| 708 |
+
vocoder_state_dict["scale"] = torch.ones(config.model_in_dim)
|
| 709 |
+
|
| 710 |
+
return vocoder_state_dict
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
# Adapted from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/84a0384742a22bd80c44e903e241f0623e874f1d/audioldm/utils.py#L72-L73
|
| 714 |
+
DEFAULT_CONFIG = {
|
| 715 |
+
"model": {
|
| 716 |
+
"params": {
|
| 717 |
+
"linear_start": 0.0015,
|
| 718 |
+
"linear_end": 0.0195,
|
| 719 |
+
"timesteps": 1000,
|
| 720 |
+
"channels": 8,
|
| 721 |
+
"scale_by_std": True,
|
| 722 |
+
"unet_config": {
|
| 723 |
+
"target": "audioldm.latent_diffusion.openaimodel.UNetModel",
|
| 724 |
+
"params": {
|
| 725 |
+
"extra_film_condition_dim": 512,
|
| 726 |
+
"extra_film_use_concat": True,
|
| 727 |
+
"in_channels": 8,
|
| 728 |
+
"out_channels": 8,
|
| 729 |
+
"model_channels": 128,
|
| 730 |
+
"attention_resolutions": [8, 4, 2],
|
| 731 |
+
"num_res_blocks": 2,
|
| 732 |
+
"channel_mult": [1, 2, 3, 5],
|
| 733 |
+
"num_head_channels": 32,
|
| 734 |
+
},
|
| 735 |
+
},
|
| 736 |
+
"first_stage_config": {
|
| 737 |
+
"target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
|
| 738 |
+
"params": {
|
| 739 |
+
"embed_dim": 8,
|
| 740 |
+
"ddconfig": {
|
| 741 |
+
"z_channels": 8,
|
| 742 |
+
"resolution": 256,
|
| 743 |
+
"in_channels": 1,
|
| 744 |
+
"out_ch": 1,
|
| 745 |
+
"ch": 128,
|
| 746 |
+
"ch_mult": [1, 2, 4],
|
| 747 |
+
"num_res_blocks": 2,
|
| 748 |
+
},
|
| 749 |
+
},
|
| 750 |
+
},
|
| 751 |
+
"vocoder_config": {
|
| 752 |
+
"target": "audioldm.first_stage_model.vocoder",
|
| 753 |
+
"params": {
|
| 754 |
+
"upsample_rates": [5, 4, 2, 2, 2],
|
| 755 |
+
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
|
| 756 |
+
"upsample_initial_channel": 1024,
|
| 757 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
| 758 |
+
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
| 759 |
+
"num_mels": 64,
|
| 760 |
+
"sampling_rate": 16000,
|
| 761 |
+
},
|
| 762 |
+
},
|
| 763 |
+
},
|
| 764 |
+
},
|
| 765 |
+
}
|
| 766 |
+
|
| 767 |
+
|
| 768 |
+
def load_pipeline_from_original_audioldm_ckpt(
|
| 769 |
+
checkpoint_path: str,
|
| 770 |
+
original_config_file: str = None,
|
| 771 |
+
image_size: int = 512,
|
| 772 |
+
prediction_type: str = None,
|
| 773 |
+
extract_ema: bool = False,
|
| 774 |
+
scheduler_type: str = "ddim",
|
| 775 |
+
num_in_channels: int = None,
|
| 776 |
+
model_channels: int = None,
|
| 777 |
+
num_head_channels: int = None,
|
| 778 |
+
device: str = None,
|
| 779 |
+
from_safetensors: bool = False,
|
| 780 |
+
) -> AudioLDMPipeline:
|
| 781 |
+
"""
|
| 782 |
+
Load an AudioLDM pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file.
|
| 783 |
+
|
| 784 |
+
Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
|
| 785 |
+
global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
|
| 786 |
+
recommended that you override the default values and/or supply an `original_config_file` wherever possible.
|
| 787 |
+
|
| 788 |
+
Args:
|
| 789 |
+
checkpoint_path (`str`): Path to `.ckpt` file.
|
| 790 |
+
original_config_file (`str`):
|
| 791 |
+
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
|
| 792 |
+
set to the audioldm-s-full-v2 config.
|
| 793 |
+
image_size (`int`, *optional*, defaults to 512):
|
| 794 |
+
The image size that the model was trained on.
|
| 795 |
+
prediction_type (`str`, *optional*):
|
| 796 |
+
The prediction type that the model was trained on. If `None`, will be automatically
|
| 797 |
+
inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
|
| 798 |
+
num_in_channels (`int`, *optional*, defaults to None):
|
| 799 |
+
The number of UNet input channels. If `None`, it will be automatically inferred from the config.
|
| 800 |
+
model_channels (`int`, *optional*, defaults to None):
|
| 801 |
+
The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override
|
| 802 |
+
to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.
|
| 803 |
+
num_head_channels (`int`, *optional*, defaults to None):
|
| 804 |
+
The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override
|
| 805 |
+
to 32 for the small and medium checkpoints, and 64 for the large.
|
| 806 |
+
scheduler_type (`str`, *optional*, defaults to 'pndm'):
|
| 807 |
+
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
|
| 808 |
+
"ddim"]`.
|
| 809 |
+
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
|
| 810 |
+
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
|
| 811 |
+
`False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
|
| 812 |
+
inference. Non-EMA weights are usually better to continue fine-tuning.
|
| 813 |
+
device (`str`, *optional*, defaults to `None`):
|
| 814 |
+
The device to use. Pass `None` to determine automatically.
|
| 815 |
+
from_safetensors (`str`, *optional*, defaults to `False`):
|
| 816 |
+
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
|
| 817 |
+
return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
| 818 |
+
"""
|
| 819 |
+
|
| 820 |
+
if from_safetensors:
|
| 821 |
+
from safetensors import safe_open
|
| 822 |
+
|
| 823 |
+
checkpoint = {}
|
| 824 |
+
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
|
| 825 |
+
for key in f.keys():
|
| 826 |
+
checkpoint[key] = f.get_tensor(key)
|
| 827 |
+
else:
|
| 828 |
+
if device is None:
|
| 829 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 830 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 831 |
+
else:
|
| 832 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 833 |
+
|
| 834 |
+
if "state_dict" in checkpoint:
|
| 835 |
+
checkpoint = checkpoint["state_dict"]
|
| 836 |
+
|
| 837 |
+
if original_config_file is None:
|
| 838 |
+
original_config = DEFAULT_CONFIG
|
| 839 |
+
else:
|
| 840 |
+
original_config = yaml.safe_load(original_config_file)
|
| 841 |
+
|
| 842 |
+
if num_in_channels is not None:
|
| 843 |
+
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
| 844 |
+
|
| 845 |
+
if model_channels is not None:
|
| 846 |
+
original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels
|
| 847 |
+
|
| 848 |
+
if num_head_channels is not None:
|
| 849 |
+
original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels
|
| 850 |
+
|
| 851 |
+
if (
|
| 852 |
+
"parameterization" in original_config["model"]["params"]
|
| 853 |
+
and original_config["model"]["params"]["parameterization"] == "v"
|
| 854 |
+
):
|
| 855 |
+
if prediction_type is None:
|
| 856 |
+
prediction_type = "v_prediction"
|
| 857 |
+
else:
|
| 858 |
+
if prediction_type is None:
|
| 859 |
+
prediction_type = "epsilon"
|
| 860 |
+
|
| 861 |
+
if image_size is None:
|
| 862 |
+
image_size = 512
|
| 863 |
+
|
| 864 |
+
num_train_timesteps = original_config["model"]["params"]["timesteps"]
|
| 865 |
+
beta_start = original_config["model"]["params"]["linear_start"]
|
| 866 |
+
beta_end = original_config["model"]["params"]["linear_end"]
|
| 867 |
+
|
| 868 |
+
scheduler = DDIMScheduler(
|
| 869 |
+
beta_end=beta_end,
|
| 870 |
+
beta_schedule="scaled_linear",
|
| 871 |
+
beta_start=beta_start,
|
| 872 |
+
num_train_timesteps=num_train_timesteps,
|
| 873 |
+
steps_offset=1,
|
| 874 |
+
clip_sample=False,
|
| 875 |
+
set_alpha_to_one=False,
|
| 876 |
+
prediction_type=prediction_type,
|
| 877 |
+
)
|
| 878 |
+
# make sure scheduler works correctly with DDIM
|
| 879 |
+
scheduler.register_to_config(clip_sample=False)
|
| 880 |
+
|
| 881 |
+
if scheduler_type == "pndm":
|
| 882 |
+
config = dict(scheduler.config)
|
| 883 |
+
config["skip_prk_steps"] = True
|
| 884 |
+
scheduler = PNDMScheduler.from_config(config)
|
| 885 |
+
elif scheduler_type == "lms":
|
| 886 |
+
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
| 887 |
+
elif scheduler_type == "heun":
|
| 888 |
+
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
|
| 889 |
+
elif scheduler_type == "euler":
|
| 890 |
+
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
|
| 891 |
+
elif scheduler_type == "euler-ancestral":
|
| 892 |
+
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
|
| 893 |
+
elif scheduler_type == "dpm":
|
| 894 |
+
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
| 895 |
+
elif scheduler_type == "ddim":
|
| 896 |
+
scheduler = scheduler
|
| 897 |
+
else:
|
| 898 |
+
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
| 899 |
+
|
| 900 |
+
# Convert the UNet2DModel
|
| 901 |
+
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
| 902 |
+
unet = UNet2DConditionModel(**unet_config)
|
| 903 |
+
|
| 904 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
| 905 |
+
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
|
| 906 |
+
)
|
| 907 |
+
|
| 908 |
+
unet.load_state_dict(converted_unet_checkpoint)
|
| 909 |
+
|
| 910 |
+
# Convert the VAE model
|
| 911 |
+
vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size)
|
| 912 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
| 913 |
+
|
| 914 |
+
vae = AutoencoderKL(**vae_config)
|
| 915 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
| 916 |
+
|
| 917 |
+
# Convert the text model
|
| 918 |
+
# AudioLDM uses the same configuration and tokenizer as the original CLAP model
|
| 919 |
+
config = ClapTextConfig.from_pretrained("laion/clap-htsat-unfused")
|
| 920 |
+
tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
|
| 921 |
+
|
| 922 |
+
converted_text_model = convert_open_clap_checkpoint(checkpoint)
|
| 923 |
+
text_model = ClapTextModelWithProjection(config)
|
| 924 |
+
|
| 925 |
+
missing_keys, unexpected_keys = text_model.load_state_dict(converted_text_model, strict=False)
|
| 926 |
+
# we expect not to have token_type_ids in our original state dict so let's ignore them
|
| 927 |
+
missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS))
|
| 928 |
+
|
| 929 |
+
if len(unexpected_keys) > 0:
|
| 930 |
+
raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}")
|
| 931 |
+
|
| 932 |
+
if len(missing_keys) > 0:
|
| 933 |
+
raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}")
|
| 934 |
+
|
| 935 |
+
# Convert the vocoder model
|
| 936 |
+
vocoder_config = create_transformers_vocoder_config(original_config)
|
| 937 |
+
vocoder_config = SpeechT5HifiGanConfig(**vocoder_config)
|
| 938 |
+
converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config)
|
| 939 |
+
|
| 940 |
+
vocoder = SpeechT5HifiGan(vocoder_config)
|
| 941 |
+
vocoder.load_state_dict(converted_vocoder_checkpoint)
|
| 942 |
+
|
| 943 |
+
# Instantiate the diffusers pipeline
|
| 944 |
+
pipe = AudioLDMPipeline(
|
| 945 |
+
vae=vae,
|
| 946 |
+
text_encoder=text_model,
|
| 947 |
+
tokenizer=tokenizer,
|
| 948 |
+
unet=unet,
|
| 949 |
+
scheduler=scheduler,
|
| 950 |
+
vocoder=vocoder,
|
| 951 |
+
)
|
| 952 |
+
|
| 953 |
+
return pipe
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
if __name__ == "__main__":
|
| 957 |
+
parser = argparse.ArgumentParser()
|
| 958 |
+
|
| 959 |
+
parser.add_argument(
|
| 960 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
| 961 |
+
)
|
| 962 |
+
parser.add_argument(
|
| 963 |
+
"--original_config_file",
|
| 964 |
+
default=None,
|
| 965 |
+
type=str,
|
| 966 |
+
help="The YAML config file corresponding to the original architecture.",
|
| 967 |
+
)
|
| 968 |
+
parser.add_argument(
|
| 969 |
+
"--num_in_channels",
|
| 970 |
+
default=None,
|
| 971 |
+
type=int,
|
| 972 |
+
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
|
| 973 |
+
)
|
| 974 |
+
parser.add_argument(
|
| 975 |
+
"--model_channels",
|
| 976 |
+
default=None,
|
| 977 |
+
type=int,
|
| 978 |
+
help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
|
| 979 |
+
" to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
|
| 980 |
+
)
|
| 981 |
+
parser.add_argument(
|
| 982 |
+
"--num_head_channels",
|
| 983 |
+
default=None,
|
| 984 |
+
type=int,
|
| 985 |
+
help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
|
| 986 |
+
" to 32 for the small and medium checkpoints, and 64 for the large.",
|
| 987 |
+
)
|
| 988 |
+
parser.add_argument(
|
| 989 |
+
"--scheduler_type",
|
| 990 |
+
default="ddim",
|
| 991 |
+
type=str,
|
| 992 |
+
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
|
| 993 |
+
)
|
| 994 |
+
parser.add_argument(
|
| 995 |
+
"--image_size",
|
| 996 |
+
default=None,
|
| 997 |
+
type=int,
|
| 998 |
+
help=("The image size that the model was trained on."),
|
| 999 |
+
)
|
| 1000 |
+
parser.add_argument(
|
| 1001 |
+
"--prediction_type",
|
| 1002 |
+
default=None,
|
| 1003 |
+
type=str,
|
| 1004 |
+
help=("The prediction type that the model was trained on."),
|
| 1005 |
+
)
|
| 1006 |
+
parser.add_argument(
|
| 1007 |
+
"--extract_ema",
|
| 1008 |
+
action="store_true",
|
| 1009 |
+
help=(
|
| 1010 |
+
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
|
| 1011 |
+
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
|
| 1012 |
+
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
|
| 1013 |
+
),
|
| 1014 |
+
)
|
| 1015 |
+
parser.add_argument(
|
| 1016 |
+
"--from_safetensors",
|
| 1017 |
+
action="store_true",
|
| 1018 |
+
help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
|
| 1019 |
+
)
|
| 1020 |
+
parser.add_argument(
|
| 1021 |
+
"--to_safetensors",
|
| 1022 |
+
action="store_true",
|
| 1023 |
+
help="Whether to store pipeline in safetensors format or not.",
|
| 1024 |
+
)
|
| 1025 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
| 1026 |
+
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
|
| 1027 |
+
args = parser.parse_args()
|
| 1028 |
+
|
| 1029 |
+
pipe = load_pipeline_from_original_audioldm_ckpt(
|
| 1030 |
+
checkpoint_path=args.checkpoint_path,
|
| 1031 |
+
original_config_file=args.original_config_file,
|
| 1032 |
+
image_size=args.image_size,
|
| 1033 |
+
prediction_type=args.prediction_type,
|
| 1034 |
+
extract_ema=args.extract_ema,
|
| 1035 |
+
scheduler_type=args.scheduler_type,
|
| 1036 |
+
num_in_channels=args.num_in_channels,
|
| 1037 |
+
model_channels=args.model_channels,
|
| 1038 |
+
num_head_channels=args.num_head_channels,
|
| 1039 |
+
from_safetensors=args.from_safetensors,
|
| 1040 |
+
device=args.device,
|
| 1041 |
+
)
|
| 1042 |
+
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|