lsmpp committed on
Commit
4960ef6
·
verified ·
1 Parent(s): d926b4c

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/.gitignore +1 -0
  2. .venv/CACHEDIR.TAG +1 -0
  3. .venv/pyvenv.cfg +6 -0
  4. arch/README.md +278 -0
  5. arch/__init__.py +59 -0
  6. arch/adapter.py +86 -0
  7. arch/data_loader.py +658 -0
  8. arch/example_train.py +377 -0
  9. arch/model_loader.py +325 -0
  10. arch/pipeline.py +348 -0
  11. arch/text_encoder.py +155 -0
  12. arch/training.py +307 -0
  13. diffusers/.github/PULL_REQUEST_TEMPLATE.md +61 -0
  14. diffusers/docs/README.md +268 -0
  15. diffusers/docs/TRANSLATING.md +69 -0
  16. diffusers/scripts/conversion_ldm_uncond.py +56 -0
  17. diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py +69 -0
  18. diffusers/scripts/convert_cogvideox_to_diffusers.py +346 -0
  19. diffusers/scripts/convert_consistency_decoder.py +1128 -0
  20. diffusers/scripts/convert_dance_diffusion_to_diffusers.py +346 -0
  21. diffusers/scripts/convert_dcae_to_diffusers.py +323 -0
  22. diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py +56 -0
  23. diffusers/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py +97 -0
  24. diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py +241 -0
  25. diffusers/scripts/convert_i2vgen_to_diffusers.py +510 -0
  26. diffusers/scripts/convert_if.py +1250 -0
  27. diffusers/scripts/convert_lora_safetensor_to_diffusers.py +128 -0
  28. diffusers/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py +185 -0
  29. diffusers/scripts/convert_omnigen_to_diffusers.py +203 -0
  30. diffusers/scripts/convert_original_audioldm2_to_diffusers.py +1135 -0
  31. diffusers/scripts/convert_original_musicldm_to_diffusers.py +1056 -0
  32. diffusers/scripts/convert_pixart_sigma_to_diffusers.py +223 -0
  33. diffusers/scripts/convert_sana_to_diffusers.py +456 -0
  34. diffusers/scripts/convert_stable_cascade.py +218 -0
  35. diffusers/scripts/convert_vae_pt_to_diffusers.py +177 -0
  36. diffusers/scripts/convert_wuerstchen.py +115 -0
  37. illustrious_generated/low_quality_images.json +0 -0
  38. illustrious_generated/natural_caption_generation_report.txt +14 -0
  39. illustrious_generated/optimization_final_results.json +0 -0
  40. illustrious_generated/optimization_summary_report.txt +20 -0
  41. illustrious_generated/regeneration_results.json +0 -0
  42. peft/.gitignore +145 -0
  43. peft/.pre-commit-config.yaml +13 -0
  44. peft/LICENSE +201 -0
  45. peft/Makefile +66 -0
  46. peft/README.md +189 -0
  47. peft/pyproject.toml +50 -0
  48. peft/requirements.txt +15 -0
  49. peft/setup.py +110 -0
  50. sentence-transformers/.gitignore +69 -0
.venv/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *
.venv/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
 
 
1
+ Signature: 8a477f597d28d172789f06886806bc55
.venv/pyvenv.cfg ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ home = /home/ubuntu/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/bin
2
+ implementation = CPython
3
+ uv = 0.7.3
4
+ version_info = 3.12.10
5
+ include-system-site-packages = false
6
+ prompt = QwenIllustrious
arch/README.md ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Qwen-SDXL Architecture Components
2
+
3
+ 本目录包含了 Qwen-SDXL 项目的解耦架构组件,便于训练和推理的模块化使用。
4
+
5
+ ## 📁 组件结构
6
+
7
+ ```
8
+ arch/
9
+ ├── __init__.py # 组件导入和导出
10
+ ├── adapter.py # Qwen 嵌入适配器
11
+ ├── text_encoder.py # Qwen 文本编码器
12
+ ├── model_loader.py # 模型加载工具
13
+ ├── pipeline.py # 推理管道
14
+ ├── training.py # 训练工具和损失函数
15
+ ├── data_loader.py # 数据加载器
16
+ ├── example_train.py # 示例训练脚本
17
+ └── README.md # 本文件
18
+ ```
19
+
20
+ ## 🧩 核心组件
21
+
22
+ ### 1. QwenEmbeddingAdapter (`adapter.py`)
23
+ 将 Qwen3 的 1024 维嵌入投影到 SDXL 兼容的维度:
24
+ - 文本嵌入: 1024 → 2048 (用于 encoder_hidden_states)
25
+ - 池化嵌入: 1024 → 1280 (用于 text_embeds)
26
+
27
+ ```python
28
+ from arch import QwenEmbeddingAdapter
29
+
30
+ adapter = QwenEmbeddingAdapter()
31
+ text_embeddings = adapter.forward_text_embeddings(qwen_embeddings)
32
+ pooled_embeddings = adapter.forward_pooled_embeddings(qwen_pooled)
33
+ ```
34
+
35
+ ### 2. QwenTextEncoder (`text_encoder.py`)
36
+ 封装 Qwen3 模型的文本编码功能:
37
+
38
+ ```python
39
+ from arch import QwenTextEncoder
40
+
41
+ text_encoder = QwenTextEncoder(model_path="path/to/qwen3")
42
+ text_emb, pooled_emb = text_encoder.encode_prompts(
43
+ ["a beautiful landscape"],
44
+ ["low quality"]
45
+ )
46
+ ```
47
+
48
+ ### 3. 模型加载器 (`model_loader.py`)
49
+ 提供各种模型组件的加载功能:
50
+
51
+ ```python
52
+ from arch import load_unet_from_safetensors, load_vae_from_safetensors
53
+
54
+ unet = load_unet_from_safetensors("unet.safetensors", "unet_config.json")
55
+ vae = load_vae_from_safetensors("vae.safetensors", "vae_config.json")
56
+ ```
57
+
58
+ ### 4. 推理管道 (`pipeline.py`)
59
+ 完整的 Qwen-SDXL 推理管道:
60
+
61
+ ```python
62
+ from arch import QwenIllustriousInference
63
+
64
+ pipeline = QwenIllustriousInference(
65
+ qwen_model_path="path/to/qwen3",
66
+ adapter_path="path/to/trained/adapter.safetensors" # 可选
67
+ )
68
+
69
+ images = pipeline.generate(
70
+ prompt="a beautiful landscape",
71
+ height=1024,
72
+ width=1024
73
+ )
74
+ ```
75
+
76
+ ## 🎓 训练组件
77
+
78
+ ### 1. DiffusionLoss (`training.py`)
79
+ 扩散训练损失函数,支持多种损失类型和 SNR 加权:
80
+
81
+ ```python
82
+ from arch import DiffusionLoss
83
+
84
+ loss_fn = DiffusionLoss(
85
+ noise_scheduler=scheduler,
86
+ loss_type="mse",
87
+ snr_gamma=5.0 # Min-SNR weighting
88
+ )
89
+ ```
90
+
91
+ ### 2. AdapterTrainingStep (`training.py`)
92
+ 适配器训练步骤,自动处理前向传播和损失计算:
93
+
94
+ ```python
95
+ from arch import AdapterTrainingStep
96
+
97
+ training_step = AdapterTrainingStep(
98
+ unet=unet,
99
+ vae=vae,
100
+ text_encoder=text_encoder,
101
+ adapter=adapter,
102
+ noise_scheduler=scheduler,
103
+ loss_fn=loss_fn
104
+ )
105
+
106
+ result = training_step.training_step(images, prompts)
107
+ loss = result["loss"]
108
+ ```
109
+
110
+ ### 3. 数据加载器 (`data_loader.py`)
111
+
112
+ #### 基础数据集
113
+ ```python
114
+ from arch import ImageCaptionDataset, create_dataloader
115
+
116
+ dataset = ImageCaptionDataset(
117
+ data_root="/path/to/images",
118
+ annotations_file="captions.jsonl"
119
+ )
120
+
121
+ dataloader = create_dataloader(dataset, batch_size=4)
122
+ ```
123
+
124
+ #### 多长宽比数据集
125
+ ```python
126
+ from arch import MultiAspectDataset
127
+
128
+ dataset = MultiAspectDataset(
129
+ data_root="/path/to/images",
130
+ annotations_file="captions.jsonl",
131
+ aspect_ratios=[(1024, 1024), (1152, 896), (896, 1152)]
132
+ )
133
+ ```
134
+
135
+ ## 🚀 使用示例
136
+
137
+ ### 1. 快速推理
138
+
139
+ ```python
140
+ from arch import QwenIllustriousInference
141
+
142
+ # 初始化管道
143
+ pipeline = QwenIllustriousInference()
144
+
145
+ # 生成图像
146
+ images = pipeline.generate(
147
+ prompt="a serene mountain landscape at sunset",
148
+ negative_prompt="low quality, blurry",
149
+ height=1024,
150
+ width=1024,
151
+ num_inference_steps=50,
152
+ guidance_scale=7.5
153
+ )
154
+
155
+ # 保存图像
156
+ images[0].save("generated_image.png")
157
+ ```
158
+
159
+ ### 2. 训练适配器
160
+
161
+ 参考 `example_train.py` 中的完整训练脚本:
162
+
163
+ ```bash
164
+ python arch/example_train.py \
165
+ --data_root /path/to/images \
166
+ --annotations_file /path/to/captions.jsonl \
167
+ --batch_size 4 \
168
+ --learning_rate 1e-4 \
169
+ --num_epochs 10 \
170
+ --output_dir ./checkpoints \
171
+ --use_wandb
172
+ ```
173
+
174
+ ### 3. 使用训练好的适配器
175
+
176
+ ```python
177
+ from arch import QwenIllustriousInference
178
+
179
+ # 加载带有训练好的适配器的管道
180
+ pipeline = QwenIllustriousInference(
181
+ adapter_path="checkpoints/adapter_epoch_10_step_5000.safetensors"
182
+ )
183
+
184
+ images = pipeline.generate("your prompt here")
185
+ ```
186
+
187
+ ## 📊 训练配置
188
+
189
+ ### 推荐的训练设置
190
+
191
+ - **学习率**: 1e-4 (AdamW)
192
+ - **批量大小**: 4-8 (根据 GPU 内存调整)
193
+ - **梯度累积**: 如果内存不足可使用
194
+ - **损失函数**: MSE 或 Huber
195
+ - **SNR 加权**: gamma=5.0 (Min-SNR)
196
+ - **EMA**: decay=0.9999
197
+ - **学习率调度**: 余弦退火 + 预热
198
+
199
+ ### 数据格式
200
+
201
+ 支持 JSON 和 JSONL 格式的标注文件:
202
+
203
+ ```jsonl
204
+ {"image": "image1.jpg", "caption": "A beautiful landscape"}
205
+ {"image": "image2.jpg", "caption": "A cute cat"}
206
+ ```
207
+
208
+ ## 🔧 自定义和扩展
209
+
210
+ ### 1. 自定义适配器架构
211
+
212
+ 修改 `adapter.py` 中的 `QwenEmbeddingAdapter` 类:
213
+
214
+ ```python
215
+ class CustomAdapter(QwenEmbeddingAdapter):
216
+ def __init__(self, qwen_dim=1024, sdxl_text_dim=2048, sdxl_pooled_dim=1280):
217
+ super().__init__(qwen_dim, sdxl_text_dim, sdxl_pooled_dim)
218
+ # 添加自定义层
219
+ self.custom_layer = nn.Linear(sdxl_text_dim, sdxl_text_dim)
220
+ ```
221
+
222
+ ### 2. 自定义损失函数
223
+
224
+ 继承 `DiffusionLoss` 类:
225
+
226
+ ```python
227
+ class CustomLoss(DiffusionLoss):
228
+ def forward(self, model_pred, target, timesteps, mask=None):
229
+ # 自定义损失计算
230
+ base_loss = super().forward(model_pred, target, timesteps, mask)
231
+ # 添加额外的正则化项
232
+ return base_loss + custom_regularization
233
+ ```
234
+
235
+ ### 3. 自定义数据加载
236
+
237
+ 继承数据集类:
238
+
239
+ ```python
240
+ class CustomDataset(ImageCaptionDataset):
241
+ def __getitem__(self, idx):
242
+ sample = super().__getitem__(idx)
243
+ # 添加自定义数据增强或预处理
244
+ return sample
245
+ ```
246
+
247
+ ## ⚠️ 注意事项
248
+
249
+ 1. **内存管理**: 使用大模型时注意 GPU 内存使用
250
+ 2. **数据类型**: 推荐使用 bfloat16 以平衡性能和精度
251
+ 3. **检查点保存**: 定期保存检查点以防止训练中断
252
+ 4. **验证集**: 使用验证集监控训练进度
253
+ 5. **梯度裁剪**: 防止梯度爆炸
254
+
255
+ ## 🐛 故障排除
256
+
257
+ ### 常见问题
258
+
259
+ 1. **CUDA 内存不足**: 减小批量大小或使用梯度累积
260
+ 2. **模型加载失败**: 检查模型路径和配置文件
261
+ 3. **生成图像质量差**: 调整学习率、损失函数或训练更多步数
262
+ 4. **训练不稳定**: 使用梯度裁剪和 EMA
263
+
264
+ ### 调试模式
265
+
266
+ 可以在各个组件中启用调试输出:
267
+
268
+ ```python
269
+ # 启用详细日志
270
+ import logging
271
+ logging.basicConfig(level=logging.DEBUG)
272
+ ```
273
+
274
+ ## 📚 参考资料
275
+
276
+ - [Stable Diffusion XL](https://arxiv.org/abs/2307.01952)
277
+ - [Qwen3 Embedding](https://github.com/QwenLM/Qwen)
278
+ - [Diffusers Library](https://github.com/huggingface/diffusers)
arch/__init__.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Architecture components for Qwen-SDXL
3
+ Qwen-SDXL 架构组件
4
+ """
5
+
6
+ # Core components
7
+ from .adapter import QwenEmbeddingAdapter
8
+ from .text_encoder import QwenTextEncoder, encode_text_with_qwen
9
+ from .model_loader import (
10
+ load_qwen_model,
11
+ load_unet_from_safetensors,
12
+ load_vae_from_safetensors,
13
+ create_scheduler,
14
+ save_model_components,
15
+ load_checkpoint
16
+ )
17
+ from .pipeline import QwenIllustriousInference
18
+
19
+ # Training components
20
+ from .training import (
21
+ DiffusionLoss,
22
+ AdapterTrainingStep,
23
+ get_cosine_schedule_with_warmup,
24
+ EMAModel
25
+ )
26
+
27
+ # Data loading components
28
+ from .data_loader import (
29
+ ImageCaptionDataset,
30
+ MultiAspectDataset,
31
+ collate_fn,
32
+ create_dataloader
33
+ )
34
+
35
+ __all__ = [
36
+ # Core components
37
+ "QwenEmbeddingAdapter",
38
+ "QwenTextEncoder",
39
+ "encode_text_with_qwen",
40
+ "load_qwen_model",
41
+ "load_unet_from_safetensors",
42
+ "load_vae_from_safetensors",
43
+ "create_scheduler",
44
+ "save_model_components",
45
+ "load_checkpoint",
46
+ "QwenIllustriousInference",
47
+
48
+ # Training components
49
+ "DiffusionLoss",
50
+ "AdapterTrainingStep",
51
+ "get_cosine_schedule_with_warmup",
52
+ "EMAModel",
53
+
54
+ # Data loading components
55
+ "ImageCaptionDataset",
56
+ "MultiAspectDataset",
57
+ "collate_fn",
58
+ "create_dataloader"
59
+ ]
arch/adapter.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Qwen Embedding Adapter
Projects Qwen3 embeddings into SDXL-compatible dimensions.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class QwenEmbeddingAdapter(nn.Module):
    """
    Adapter that projects Qwen3 embeddings to SDXL-compatible dimensions.

    - Text embeddings: ``qwen_dim -> sdxl_text_dim`` (default 1024 -> 2048),
      consumed as ``encoder_hidden_states``.
    - Pooled embeddings: ``qwen_dim -> sdxl_pooled_dim`` (default 1024 -> 1280),
      consumed as ``text_embeds`` in ``added_cond_kwargs``.

    Args:
        qwen_dim: Dimensionality of the incoming Qwen embeddings.
        sdxl_text_dim: Target dimensionality for per-token text embeddings.
        sdxl_pooled_dim: Target dimensionality for the pooled embedding.
        dropout: Dropout probability inside the pooled-embedding MLP.
            Previously hard-coded to 0.1; that value remains the default,
            so existing callers are unaffected.
    """

    def __init__(self, qwen_dim: int = 1024, sdxl_text_dim: int = 2048,
                 sdxl_pooled_dim: int = 1280, dropout: float = 0.1):
        super().__init__()
        self.qwen_dim = qwen_dim
        self.sdxl_text_dim = sdxl_text_dim
        self.sdxl_pooled_dim = sdxl_pooled_dim

        # Per-token projection (for encoder_hidden_states).
        self.text_projection = nn.Linear(qwen_dim, sdxl_text_dim)
        self.text_layer_norm = nn.LayerNorm(sdxl_text_dim)
        self.text_activation = nn.GELU()

        # MLP for the pooled embedding (text_embeds in added_cond_kwargs).
        self.pooled_mlp = nn.Sequential(
            nn.Linear(qwen_dim, qwen_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(qwen_dim * 2, sdxl_pooled_dim),
            nn.LayerNorm(sdxl_pooled_dim)
        )

        # Initialize weights for better training stability.
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for all linear weights, zero init for biases."""
        nn.init.xavier_uniform_(self.text_projection.weight)
        nn.init.zeros_(self.text_projection.bias)

        for module in self.pooled_mlp:
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

    def forward_text_embeddings(self, qwen_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Project per-token embeddings for encoder_hidden_states.

        Args:
            qwen_embeddings: Tensor of shape [batch_size, seq_len, qwen_dim].
        Returns:
            Tensor of shape [batch_size, seq_len, sdxl_text_dim].
        """
        projected = self.text_projection(qwen_embeddings)
        projected = self.text_activation(projected)
        return self.text_layer_norm(projected)

    def forward_pooled_embeddings(self, qwen_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Project the pooled embedding for text_embeds (via the MLP).

        Args:
            qwen_embeddings: Tensor of shape [batch_size, qwen_dim].
        Returns:
            Tensor of shape [batch_size, sdxl_pooled_dim].
        """
        return self.pooled_mlp(qwen_embeddings)

    def forward(self, text_embeddings: torch.Tensor, pooled_embeddings: torch.Tensor):
        """
        Project both embedding kinds in one call.

        Args:
            text_embeddings: Tensor of shape [batch_size, seq_len, qwen_dim].
            pooled_embeddings: Tensor of shape [batch_size, qwen_dim].
        Returns:
            tuple: (projected_text_embeddings, projected_pooled_embeddings)
        """
        projected_text = self.forward_text_embeddings(text_embeddings)
        projected_pooled = self.forward_pooled_embeddings(pooled_embeddings)
        return projected_text, projected_pooled
arch/data_loader.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data Loading Utilities for QwenIllustrious
3
+ 数据加载工具 - 处理训练数据的加载和预处理,支持预计算嵌入加速
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+ from PIL import Image
11
+ import torchvision.transforms as transforms
12
+ from pathlib import Path
13
+ from typing import Dict, List, Optional, Tuple
14
+ import pickle
15
+ from tqdm import tqdm
16
+
17
+
class QwenIllustriousDataset(Dataset):
    """
    Dataset for QwenIllustrious training.

    Supported features:
    - Load images and annotations from per-image metadata .json files
    - Image preprocessing (resize/normalize)
    - On-disk caching of Qwen text embeddings
    - On-disk caching of VAE latents
    - Optional in-memory precomputation to speed up training
    """

    def __init__(
        self,
        dataset_path: str,
        qwen_text_encoder=None,
        vae=None,
        image_size: int = 1024,
        cache_dir: Optional[str] = None,
        precompute_embeddings: bool = False
    ):
        """
        Args:
            dataset_path: Root directory containing images and a "metadata" subdir.
            qwen_text_encoder: Optional encoder with an ``encode_prompts`` method;
                when None, dummy zero embeddings are returned.
            vae: Optional VAE with ``encode``; when None, dummy zero latents are returned.
            image_size: Square side length images are resized to.
            cache_dir: Optional directory for on-disk embedding/latent caches.
            precompute_embeddings: When True, __getitem__ serves tensors from
                the in-memory store filled by ``precompute_all``.
        """
        self.dataset_path = Path(dataset_path)
        self.qwen_text_encoder = qwen_text_encoder
        self.vae = vae
        self.image_size = image_size
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.precompute_embeddings = precompute_embeddings

        # Setup image transforms: square resize, then normalize to [-1, 1].
        self.image_transforms = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
        ])

        # Load metadata (one record per image).
        self.metadata = self._load_metadata()

        # Setup cache directories.
        # NOTE(review): text_cache_dir/vae_cache_dir only exist when cache_dir
        # is set; the _get_*_cache_path helpers are only reached behind
        # ``if self.cache_dir:`` guards, so that invariant must be preserved.
        if self.cache_dir:
            self.cache_dir.mkdir(exist_ok=True)
            self.text_cache_dir = self.cache_dir / "text_embeddings"
            self.vae_cache_dir = self.cache_dir / "vae_latents"
            self.text_cache_dir.mkdir(exist_ok=True)
            self.vae_cache_dir.mkdir(exist_ok=True)

        # In-memory store filled by precompute_all(), keyed by filename_hash.
        self.precomputed_data = {}

    def _load_metadata(self) -> List[Dict]:
        """Load all metadata JSON files from <dataset_path>/metadata.

        Returns one dict per file, augmented with 'metadata_file' (source path)
        and 'image_file' (derived as <dataset_path>/<filename_hash>.png).
        Files that fail to parse are skipped with a printed error.
        """
        metadata_dir = self.dataset_path / "metadata"
        if not metadata_dir.exists():
            raise ValueError(f"Metadata directory not found: {metadata_dir}")

        metadata_files = list(metadata_dir.glob("*.json"))

        metadata_list = []
        print(f"Loading metadata from {len(metadata_files)} files...")

        for file_path in tqdm(metadata_files, desc="Loading metadata"):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # Add file path info; images are assumed to live next to the
                # metadata dir and be named <filename_hash>.png.
                data['metadata_file'] = str(file_path)
                data['image_file'] = str(self.dataset_path / f"{data['filename_hash']}.png")
                metadata_list.append(data)
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
                continue

        print(f"Successfully loaded {len(metadata_list)} metadata files")
        return metadata_list

    def _get_text_cache_path(self, filename_hash: str) -> Path:
        """Get path for cached text embeddings (requires cache_dir to be set)."""
        return self.text_cache_dir / f"{filename_hash}_text.pt"

    def _get_vae_cache_path(self, filename_hash: str) -> Path:
        """Get path for cached VAE latents (requires cache_dir to be set)."""
        return self.vae_cache_dir / f"{filename_hash}_vae.pt"

    def _compute_text_embeddings(self, prompt: str, device='cpu') -> Dict[str, torch.Tensor]:
        """Compute text embeddings using the Qwen text encoder.

        Returns a dict with 'text_embeddings' and 'pooled_embeddings' on CPU.
        When no encoder is configured, returns zero placeholders.
        """
        if not self.qwen_text_encoder:
            # Return dummy embeddings.
            # NOTE(review): the dummy text embedding is [1, 2048] with no
            # seq_len dimension, unlike a real encoder output — confirm
            # downstream consumers tolerate this shape.
            return {
                'text_embeddings': torch.zeros(1, 2048),  # SDXL text embedding size
                'pooled_embeddings': torch.zeros(1, 1280)  # SDXL pooled embedding size
            }

        with torch.no_grad():
            # Move to device temporarily for computation.
            original_device = next(self.qwen_text_encoder.parameters()).device
            self.qwen_text_encoder.to(device)

            embeddings = self.qwen_text_encoder.encode_prompts([prompt])

            # Move back to original device.
            self.qwen_text_encoder.to(original_device)

            # NOTE(review): when the encoder returns a single tensor, the
            # text embedding is reused as the pooled embedding — verify this
            # fallback is intended.
            return {
                'text_embeddings': embeddings[0].cpu(),
                'pooled_embeddings': embeddings[1].cpu() if len(embeddings) > 1 else embeddings[0].cpu()
            }

    def _compute_vae_latents(self, image: torch.Tensor, device='cpu') -> torch.Tensor:
        """Encode an image tensor to scaled VAE latents (returned on CPU).

        Args:
            image: [C, H, W] or [B, C, H, W] tensor in [-1, 1].
        """
        if not self.vae:
            # Return dummy latents (SDXL latent space: 4 channels, 8x downscale).
            return torch.zeros(1, 4, self.image_size // 8, self.image_size // 8)

        with torch.no_grad():
            # Move to device temporarily for computation.
            original_device = next(self.vae.parameters()).device
            self.vae.to(device)

            # Add batch dimension if needed.
            if image.dim() == 3:
                image = image.unsqueeze(0)

            image = image.to(device).to(self.vae.dtype)
            latents = self.vae.encode(image).latent_dist.sample()
            # Scale by the VAE's configured scaling factor, as expected by
            # the diffusion model.
            latents = latents * self.vae.config.scaling_factor

            # Move back to original device.
            self.vae.to(original_device)

            return latents.cpu()

    def _load_or_compute_text_embeddings(self, prompt: str, filename_hash: str, device='cpu') -> Dict[str, torch.Tensor]:
        """Load cached text embeddings or compute (and cache) new ones."""
        if self.cache_dir:
            cache_path = self._get_text_cache_path(filename_hash)

            # Try to load from cache; fall through to recompute on failure.
            if cache_path.exists():
                try:
                    return torch.load(cache_path, map_location='cpu')
                except Exception as e:
                    print(f"Error loading cached text embeddings {cache_path}: {e}")

        # Compute new embeddings.
        embeddings = self._compute_text_embeddings(prompt, device)

        # Cache the embeddings (cache_path was bound above whenever
        # cache_dir is set).
        if self.cache_dir:
            try:
                torch.save(embeddings, cache_path)
            except Exception as e:
                print(f"Error saving text embeddings cache {cache_path}: {e}")

        return embeddings

    def _load_or_compute_vae_latents(self, image_path: str, filename_hash: str, device='cpu') -> torch.Tensor:
        """Load cached VAE latents or load the image, encode, and cache."""
        if self.cache_dir:
            cache_path = self._get_vae_cache_path(filename_hash)

            # Try to load from cache; fall through to recompute on failure.
            if cache_path.exists():
                try:
                    return torch.load(cache_path, map_location='cpu')
                except Exception as e:
                    print(f"Error loading cached VAE latents {cache_path}: {e}")

        # Load and process image; on failure, substitute a black image so
        # one bad file does not abort the epoch.
        try:
            image = Image.open(image_path).convert('RGB')
            image = self.image_transforms(image)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            image = torch.zeros(3, self.image_size, self.image_size)

        # Compute latents.
        latents = self._compute_vae_latents(image, device)

        # Cache the latents.
        if self.cache_dir:
            try:
                torch.save(latents, cache_path)
            except Exception as e:
                print(f"Error saving VAE latents cache {cache_path}: {e}")

        return latents

    def precompute_all(self, device='cuda'):
        """Precompute all embeddings and latents for faster training.

        Fills ``self.precomputed_data`` (keyed by filename_hash) so that
        __getitem__ can serve tensors without touching disk or models.
        """
        print("Precomputing all embeddings and latents...")

        for idx in tqdm(range(len(self.metadata)), desc="Precomputing"):
            metadata = self.metadata[idx]
            filename_hash = metadata['filename_hash']

            # Get prompt: prefer the natural caption, fall back to the
            # original positive prompt.
            prompt = metadata.get('natural_caption_data', {}).get('natural_caption', '')
            if not prompt:
                prompt = metadata.get('original_prompt_data', {}).get('positive_prompt', '')

            # Precompute text embeddings.
            text_embeddings = self._load_or_compute_text_embeddings(prompt, filename_hash, device)

            # Precompute VAE latents.
            vae_latents = self._load_or_compute_vae_latents(metadata['image_file'], filename_hash, device)

            # Store in memory for fast access (batch dims squeezed off).
            self.precomputed_data[filename_hash] = {
                'text_embeddings': text_embeddings['text_embeddings'].squeeze(0),
                'pooled_embeddings': text_embeddings['pooled_embeddings'].squeeze(0),
                'latents': vae_latents.squeeze(0),
                'prompt': prompt
            }

        print(f"Precomputation completed for {len(self.precomputed_data)} items")

    def __len__(self):
        # One sample per metadata record.
        return len(self.metadata)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return one sample.

        Precomputed samples contain 'latents'; on-the-fly samples contain
        'images' instead (collate_fn distinguishes the two formats).
        """
        metadata = self.metadata[idx]
        filename_hash = metadata['filename_hash']

        if self.precompute_embeddings and filename_hash in self.precomputed_data:
            # Use precomputed data.
            data = self.precomputed_data[filename_hash]
            return {
                'text_embeddings': data['text_embeddings'],
                'pooled_embeddings': data['pooled_embeddings'],
                'latents': data['latents'],
                'prompts': data['prompt'],
                'filename_hash': filename_hash,
                'metadata': metadata
            }
        else:
            # Load data on-the-fly.

            # Load image; substitute a black image on failure.
            image_path = metadata['image_file']
            try:
                image = Image.open(image_path).convert('RGB')
                image = self.image_transforms(image)
            except Exception as e:
                print(f"Error loading image {image_path}: {e}")
                image = torch.zeros(3, self.image_size, self.image_size)

            # Get prompt (same preference order as precompute_all).
            prompt = metadata.get('natural_caption_data', {}).get('natural_caption', '')
            if not prompt:
                prompt = metadata.get('original_prompt_data', {}).get('positive_prompt', '')

            # Get text embeddings (will use cache if available).
            text_embeddings = self._load_or_compute_text_embeddings(prompt, filename_hash)

            return {
                'images': image,
                'prompts': prompt,
                'text_embeddings': text_embeddings['text_embeddings'].squeeze(0),
                'pooled_embeddings': text_embeddings['pooled_embeddings'].squeeze(0),
                'filename_hash': filename_hash,
                'metadata': metadata
            }
279
+
280
+
def collate_fn(examples: List[Dict]) -> Dict[str, torch.Tensor]:
    """Collate dataset samples into a training batch.

    Tensor fields are stacked along a new batch dimension; string and
    metadata fields are gathered into plain lists. Precomputed samples carry
    'latents', on-the-fly samples carry 'images' — the two formats are
    distinguished by the presence of 'latents' in the first example.
    """
    def stacked(key: str) -> torch.Tensor:
        return torch.stack([sample[key] for sample in examples])

    # Pick the image-side tensor field according to the sample format.
    tensor_key = 'latents' if 'latents' in examples[0] else 'images'

    batch = {
        tensor_key: stacked(tensor_key),
        'text_embeddings': stacked('text_embeddings'),
        'pooled_embeddings': stacked('pooled_embeddings'),
        'prompts': [sample['prompts'] for sample in examples],
        'filename_hash': [sample['filename_hash'] for sample in examples],
        'metadata': [sample['metadata'] for sample in examples],
    }
    return batch
303
+
304
+ import torch
305
+ from torch.utils.data import Dataset, DataLoader
306
+ from PIL import Image
307
+ import json
308
+ import os
309
+ from typing import List, Dict, Any, Optional, Tuple, Union
310
+ import torchvision.transforms as transforms
311
+ import random
312
+
313
+
class ImageCaptionDataset(Dataset):
    """Image-caption pair dataset backed by a JSON/JSONL annotation file.

    Each sample yields a preprocessed image tensor normalized to [-1, 1],
    its caption (truncated to ``max_caption_length`` characters), and the
    absolute image path.
    """

    def __init__(
        self,
        data_root: str,
        annotations_file: str,
        image_size: int = 1024,
        center_crop: bool = True,
        random_flip: bool = True,
        caption_column: str = "caption",
        image_column: str = "image",
        max_caption_length: int = 512
    ):
        self.data_root = data_root
        self.image_size = image_size
        self.caption_column = caption_column
        self.image_column = image_column
        self.max_caption_length = max_caption_length

        # Annotation records, filtered down to usable samples.
        self.annotations = self._load_annotations(annotations_file)

        # Preprocessing pipeline applied to every loaded image.
        self.image_transforms = self._setup_transforms(image_size, center_crop, random_flip)

        print(f"📚 数据集加载完成: {len(self.annotations)} 个样本")

    def _load_annotations(self, annotations_file: str) -> List[Dict]:
        """Read the annotation file and keep only records with a usable caption."""
        if annotations_file.endswith('.json'):
            with open(annotations_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        elif annotations_file.endswith('.jsonl'):
            with open(annotations_file, 'r', encoding='utf-8') as f:
                data = [json.loads(line) for line in f if line.strip()]
        else:
            raise ValueError(f"Unsupported annotation file format: {annotations_file}")

        # A record is valid when both columns are present and the caption is
        # a non-blank string.
        valid_data = [
            item for item in data
            if self.caption_column in item
            and self.image_column in item
            and isinstance(item[self.caption_column], str)
            and item[self.caption_column].strip()
        ]

        print(f"📋 有效样本数: {len(valid_data)} / {len(data)}")
        return valid_data

    def _setup_transforms(self, size: int, center_crop: bool, random_flip: bool):
        """Build the torchvision preprocessing pipeline."""
        if center_crop:
            # Resize the short side, then crop a centered square.
            steps = [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size),
            ]
        else:
            # Squash directly to a square (aspect ratio may change).
            steps = [
                transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
            ]

        if random_flip:
            steps.append(transforms.RandomHorizontalFlip(p=0.5))

        # Tensor conversion plus mapping [0, 1] -> [-1, 1].
        steps += [
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
        return transforms.Compose(steps)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one sample: {'images', 'captions', 'image_paths'}."""
        annotation = self.annotations[idx]
        image_path = os.path.join(self.data_root, annotation[self.image_column])

        try:
            image = Image.open(image_path)
            if image.mode != 'RGB':
                image = image.convert('RGB')
        except Exception as e:
            # Fall back to a black placeholder so one bad file cannot abort
            # a whole training epoch.
            print(f"⚠️ 加载图像失败 {image_path}: {e}")
            image = Image.new('RGB', (self.image_size, self.image_size), (0, 0, 0))

        pixels = self.image_transforms(image)

        # Truncate over-long captions; shorter captions pass through unchanged.
        caption = annotation[self.caption_column][:self.max_caption_length]

        return {
            "images": pixels,
            "captions": caption,
            "image_paths": image_path,
        }
427
+
428
+
429
class MultiAspectDataset(Dataset):
    """
    Dataset that supports multiple aspect ratios
    支持多种长宽比的数据集

    Each image is assigned to the aspect-ratio "bucket" whose width/height
    ratio is closest to the image's own ratio; at fetch time the image is
    resized to that bucket's resolution.

    Args:
        data_root: Root directory that image paths are relative to.
        annotations_file: JSON (list of dicts) or JSONL annotations file.
        base_size: Nominal base resolution (kept for API compatibility).
        aspect_ratios: List of (width, height) buckets; defaults to the
            standard SDXL bucket set.
        bucket_tolerance: Stored tolerance value (not currently used by the
            bucket-assignment logic).
        caption_column: Key holding the caption in each annotation dict.
        image_column: Key holding the relative image path.
        max_caption_length: Captions longer than this are truncated.
    """

    def __init__(
        self,
        data_root: str,
        annotations_file: str,
        base_size: int = 1024,
        aspect_ratios: List[Tuple[int, int]] = None,
        bucket_tolerance: float = 0.1,
        caption_column: str = "caption",
        image_column: str = "image",
        max_caption_length: int = 512
    ):
        self.data_root = data_root
        self.base_size = base_size
        self.caption_column = caption_column
        self.image_column = image_column
        self.max_caption_length = max_caption_length

        # Default aspect ratios for SDXL
        if aspect_ratios is None:
            aspect_ratios = [
                (1024, 1024),  # 1:1
                (1152, 896),   # 9:7
                (896, 1152),   # 7:9
                (1216, 832),   # 3:2
                (832, 1216),   # 2:3
                (1344, 768),   # 7:4
                (768, 1344),   # 4:7
                (1536, 640),   # 12:5
                (640, 1536),   # 5:12
            ]

        self.aspect_ratios = aspect_ratios
        self.bucket_tolerance = bucket_tolerance

        # Load and bucket annotations
        self.annotations = self._load_and_bucket_annotations(annotations_file)

        print(f"📚 多长宽比数据集加载完成: {len(self.annotations)} 个样本")
        self._print_bucket_stats()

    def _load_and_bucket_annotations(self, annotations_file: str) -> List[Dict]:
        """Load annotations and assign each entry to an aspect-ratio bucket.

        Raises:
            ValueError: If the annotations file is neither .json nor .jsonl.
                (Previously an unsupported extension fell through and caused an
                opaque NameError on the undefined ``data`` variable.)
        """
        if annotations_file.endswith('.json'):
            with open(annotations_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        elif annotations_file.endswith('.jsonl'):
            data = []
            with open(annotations_file, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data.append(json.loads(line))
        else:
            raise ValueError(
                f"Unsupported annotations file format: {annotations_file} "
                "(expected .json or .jsonl)"
            )

        bucketed_data = []

        for item in data:
            # Skip entries missing required columns or with empty captions.
            if self.caption_column not in item or self.image_column not in item:
                continue

            caption = item[self.caption_column]
            if not isinstance(caption, str) or not caption.strip():
                continue

            # Try to get image dimensions to assign a bucket.
            image_path = os.path.join(self.data_root, item[self.image_column])
            try:
                with Image.open(image_path) as img:
                    width, height = img.size
                    aspect_ratio = width / height

                    # Find best matching bucket for this image's ratio.
                    best_bucket = self._find_best_bucket(aspect_ratio)

                    item_copy = item.copy()
                    item_copy['bucket_width'] = best_bucket[0]
                    item_copy['bucket_height'] = best_bucket[1]
                    item_copy['original_width'] = width
                    item_copy['original_height'] = height

                    bucketed_data.append(item_copy)

            except Exception as e:
                print(f"⚠️ 无法获取图像尺寸 {image_path}: {e}")
                # Fall back to the default 1:1 bucket so the sample is kept.
                item_copy = item.copy()
                item_copy['bucket_width'] = 1024
                item_copy['bucket_height'] = 1024
                item_copy['original_width'] = 1024
                item_copy['original_height'] = 1024
                bucketed_data.append(item_copy)

        return bucketed_data

    def _find_best_bucket(self, aspect_ratio: float) -> Tuple[int, int]:
        """Return the (width, height) bucket whose ratio is closest to ``aspect_ratio``."""
        best_bucket = self.aspect_ratios[0]
        best_diff = float('inf')

        for bucket_w, bucket_h in self.aspect_ratios:
            bucket_ratio = bucket_w / bucket_h
            diff = abs(aspect_ratio - bucket_ratio)

            if diff < best_diff:
                best_diff = diff
                best_bucket = (bucket_w, bucket_h)

        return best_bucket

    def _print_bucket_stats(self):
        """Print statistics about bucket distribution."""
        bucket_counts = {}
        for item in self.annotations:
            bucket = (item['bucket_width'], item['bucket_height'])
            bucket_counts[bucket] = bucket_counts.get(bucket, 0) + 1

        print("📊 长宽比分布:")
        for bucket, count in sorted(bucket_counts.items()):
            ratio = bucket[0] / bucket[1]
            print(f"  {bucket[0]}×{bucket[1]} (比例 {ratio:.2f}): {count} 个样本")

    def _get_transforms(self, target_width: int, target_height: int):
        """Build the per-sample transform pipeline for a specific bucket size."""
        return transforms.Compose([
            transforms.Resize((target_height, target_width), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # scale pixels to [-1, 1]
        ])

    def __len__(self) -> int:
        """Return the number of bucketed samples."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one sample resized to its bucket resolution."""
        annotation = self.annotations[idx]

        # Target dimensions come from the pre-assigned bucket.
        target_width = annotation['bucket_width']
        target_height = annotation['bucket_height']

        # Load and transform image; black fallback keeps training alive.
        image_path = os.path.join(self.data_root, annotation[self.image_column])
        try:
            image = Image.open(image_path)
            if image.mode != 'RGB':
                image = image.convert('RGB')
        except Exception as e:
            print(f"⚠️ 加载图像失败 {image_path}: {e}")
            image = Image.new('RGB', (target_width, target_height), (0, 0, 0))

        transforms_fn = self._get_transforms(target_width, target_height)
        image = transforms_fn(image)

        # Truncate over-long captions.
        caption = annotation[self.caption_column]
        if len(caption) > self.max_caption_length:
            caption = caption[:self.max_caption_length]

        return {
            "images": image,
            "captions": caption,
            "image_paths": image_path,
            "width": target_width,
            "height": target_height
        }
601
+
602
+
603
+ # def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
604
+ # """
605
+ # Custom collate function for batching
606
+ # 自定义批处理整理函数
607
+ # """
608
+ # # Check if all images have the same size
609
+ # sizes = [(item["images"].shape[-2], item["images"].shape[-1]) for item in batch]
610
+ # if len(set(sizes)) == 1:
611
+ # # All same size, can batch normally
612
+ # images = torch.stack([item["images"] for item in batch])
613
+ # captions = [item["captions"] for item in batch]
614
+
615
+ # result = {
616
+ # "images": images,
617
+ # "captions": captions,
618
+ # "image_paths": [item["image_paths"] for item in batch]
619
+ # }
620
+
621
+ # # Add width/height if available
622
+ # if "width" in batch[0]:
623
+ # result["widths"] = [item["width"] for item in batch]
624
+ # result["heights"] = [item["height"] for item in batch]
625
+
626
+ # return result
627
+ # else:
628
+ # # Different sizes, return as list
629
+ # return {
630
+ # "images": [item["images"] for item in batch],
631
+ # "captions": [item["captions"] for item in batch],
632
+ # "image_paths": [item["image_paths"] for item in batch],
633
+ # "widths": [item.get("width", item["images"].shape[-1]) for item in batch],
634
+ # "heights": [item.get("height", item["images"].shape[-2]) for item in batch]
635
+ # }
636
+
637
+
638
def create_dataloader(
    dataset: Dataset,
    batch_size: int = 4,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    drop_last: bool = True,
    collate_fn=None
) -> DataLoader:
    """
    Create dataloader with appropriate settings
    创建具有适当设置的数据加载器

    Args:
        dataset: Dataset to wrap.
        batch_size: Samples per batch.
        shuffle: Whether to reshuffle every epoch.
        num_workers: Worker subprocesses for loading.
        pin_memory: Pin host memory for faster GPU transfer.
        drop_last: Drop the final incomplete batch.
        collate_fn: Optional custom batching function; ``None`` uses the
            PyTorch default collation.

    Returns:
        Configured ``DataLoader``.
    """
    # BUGFIX: this previously passed the module-level name ``collate_fn``,
    # whose only definition in this file is commented out, raising NameError
    # at call time. It is now an explicit optional parameter.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        collate_fn=collate_fn
    )
arch/example_train.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example Training Script using Arch Components
3
+ 使用架构组件的示例训练脚本
4
+ """
5
+
6
+ import torch
7
+ import torch.optim as optim
8
+ from torch.utils.data import DataLoader
9
+ import os
10
+ import argparse
11
+ from tqdm import tqdm
12
+ import wandb
13
+ from typing import Optional
14
+
15
+ # Import arch components
16
+ from arch import (
17
+ QwenTextEncoder,
18
+ QwenEmbeddingAdapter,
19
+ load_unet_from_safetensors,
20
+ load_vae_from_safetensors,
21
+ create_scheduler,
22
+ DiffusionLoss,
23
+ AdapterTrainingStep,
24
+ get_cosine_schedule_with_warmup,
25
+ EMAModel,
26
+ ImageCaptionDataset,
27
+ MultiAspectDataset,
28
+ create_dataloader
29
+ )
30
+
31
+
32
def parse_args():
    """Parse command-line options for adapter training.

    Only ``--data_root`` and ``--annotations_file`` are required; every other
    option carries a default matching the project layout.
    """
    p = argparse.ArgumentParser(description="Train Qwen-SDXL Adapter")

    # -- Model paths ------------------------------------------------------
    p.add_argument("--qwen_model_path", type=str, default="models/Qwen3-Embedding-0.6B")
    p.add_argument("--unet_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_unet.safetensors")
    p.add_argument("--unet_config_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_unet_config.json")
    p.add_argument("--vae_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors")
    p.add_argument("--vae_config_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_vae_config.json")

    # -- Data -------------------------------------------------------------
    p.add_argument("--data_root", type=str, required=True, help="Root directory of training images")
    p.add_argument("--annotations_file", type=str, required=True, help="Path to annotations file (JSON/JSONL)")
    p.add_argument("--caption_column", type=str, default="caption")
    p.add_argument("--image_column", type=str, default="image")
    p.add_argument("--use_multi_aspect", action="store_true", help="Use multi-aspect ratio dataset")

    # -- Training ---------------------------------------------------------
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--num_epochs", type=int, default=10)
    p.add_argument("--warmup_steps", type=int, default=500)
    p.add_argument("--gradient_accumulation_steps", type=int, default=1)
    p.add_argument("--max_grad_norm", type=float, default=1.0)

    # -- Loss -------------------------------------------------------------
    p.add_argument("--loss_type", type=str, default="mse", choices=["mse", "l1", "huber"])
    p.add_argument("--snr_gamma", type=float, default=None, help="SNR gamma for loss weighting")
    p.add_argument("--use_v_parameterization", action="store_true")

    # -- Optimization -----------------------------------------------------
    p.add_argument("--optimizer", type=str, default="adamw", choices=["adamw", "adam"])
    p.add_argument("--weight_decay", type=float, default=0.01)
    p.add_argument("--use_ema", action="store_true", help="Use EMA for adapter")
    p.add_argument("--ema_decay", type=float, default=0.9999)

    # -- Checkpointing ----------------------------------------------------
    p.add_argument("--output_dir", type=str, default="./checkpoints")
    p.add_argument("--save_steps", type=int, default=1000)
    p.add_argument("--resume_from_checkpoint", type=str, default=None)

    # -- Logging ----------------------------------------------------------
    p.add_argument("--logging_steps", type=int, default=50)
    p.add_argument("--use_wandb", action="store_true")
    p.add_argument("--wandb_project", type=str, default="qwen-sdxl-training")
    p.add_argument("--wandb_run_name", type=str, default=None)

    # -- Hardware ---------------------------------------------------------
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"])
    p.add_argument("--num_workers", type=int, default=4)

    return p.parse_args()
85
+
86
+
87
def setup_models(args):
    """Instantiate every model component required for training.

    Returns:
        Tuple of (text_encoder, adapter, unet, vae, noise_scheduler, dtype).
    """
    print("🚀 设置模型组件...")

    # Resolve the requested precision once; argparse restricts the choices,
    # so a direct lookup is safe.
    dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[args.dtype]

    # Frozen Qwen text encoder (only the adapter is trained).
    print("📝 加载 Qwen 文本编码器...")
    text_encoder = QwenTextEncoder(
        model_path=args.qwen_model_path,
        device=args.device,
        freeze_encoder=True,
    )

    # Trainable embedding adapter.
    print("🔧 初始化适配器...")
    adapter = QwenEmbeddingAdapter()
    adapter.to(args.device, dtype)

    # Diffusion backbone.
    print("🏗️ 加载 UNet...")
    unet = load_unet_from_safetensors(
        args.unet_path,
        args.unet_config_path,
        args.device,
        dtype,
    )

    # Latent autoencoder.
    print("🎨 加载 VAE...")
    vae = load_vae_from_safetensors(
        args.vae_path,
        args.vae_config_path,
        args.device,
        dtype,
    )

    # DDPM noise schedule for training.
    print("⏰ 创建调度器...")
    noise_scheduler = create_scheduler("DDPM")

    return text_encoder, adapter, unet, vae, noise_scheduler, dtype
135
+
136
+
137
def setup_data(args):
    """Build the training dataset (single- or multi-aspect) and its DataLoader."""
    print("📚 设置数据加载器...")

    # Both dataset classes share the same constructor signature, so choose
    # the class first and construct once.
    dataset_cls = MultiAspectDataset if args.use_multi_aspect else ImageCaptionDataset
    dataset = dataset_cls(
        data_root=args.data_root,
        annotations_file=args.annotations_file,
        caption_column=args.caption_column,
        image_column=args.image_column,
    )

    return create_dataloader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
166
+
167
+
168
def setup_training(args, adapter, noise_scheduler):
    """Create the loss function, optimizer, and (optionally) an EMA tracker.

    Returns:
        Tuple of (loss_fn, optimizer, ema) where ``ema`` is None unless
        ``--use_ema`` was passed.
    """
    print("🎯 设置训练组件...")

    # Diffusion training loss configured from the CLI flags.
    loss_fn = DiffusionLoss(
        noise_scheduler=noise_scheduler,
        loss_type=args.loss_type,
        snr_gamma=args.snr_gamma,
        use_v_parameterization=args.use_v_parameterization,
    )

    # Only the adapter's parameters are optimized.
    if args.optimizer == "adamw":
        optimizer = optim.AdamW(
            adapter.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
            betas=(0.9, 0.999),
        )
    else:
        optimizer = optim.Adam(
            adapter.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )

    ema = EMAModel(adapter, decay=args.ema_decay) if args.use_ema else None

    return loss_fn, optimizer, ema
201
+
202
+
203
def train_step(training_step_fn, batch, optimizer, args, ema=None):
    """Run one optimization step over a batch and return the average loss.

    Handles two batch layouts: a stacked tensor batch (uniform image sizes)
    and a list-of-tensors batch (mixed sizes from the multi-aspect dataset),
    which is processed sample-by-sample.

    NOTE(review): although the loss is divided by
    ``args.gradient_accumulation_steps``, ``optimizer.step()`` and
    ``zero_grad()`` run on every call, so gradients never actually
    accumulate across calls — with accumulation > 1 this effectively just
    scales the learning rate down. Confirm intended behavior.
    """
    # Handle different batch formats
    if isinstance(batch["images"], list):
        # Multi-size batch, train one by one (each sample becomes a batch of 1)
        total_loss = 0
        num_samples = 0

        for i in range(len(batch["images"])):
            images = batch["images"][i].unsqueeze(0)
            captions = [batch["captions"][i]]

            step_output = training_step_fn.training_step(images, captions)
            loss = step_output["loss"] / args.gradient_accumulation_steps

            # Gradients from each sample add up in .grad until optimizer.step()
            loss.backward()
            total_loss += loss.item()
            num_samples += 1

        avg_loss = total_loss / num_samples if num_samples > 0 else 0
    else:
        # Regular batch
        images = batch["images"]
        captions = batch["captions"]

        step_output = training_step_fn.training_step(images, captions)
        loss = step_output["loss"] / args.gradient_accumulation_steps

        loss.backward()
        avg_loss = loss.item()

    # Gradient clipping and optimization step (adapter params only)
    torch.nn.utils.clip_grad_norm_(training_step_fn.adapter.parameters(), args.max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()

    # Update EMA shadow weights after the parameter update
    if ema is not None:
        ema.update()

    return avg_loss
244
+
245
+
246
def save_checkpoint(adapter, optimizer, ema, epoch, step, args):
    """Persist adapter weights (plus an EMA copy) and resumable training state."""
    os.makedirs(args.output_dir, exist_ok=True)
    tag = f"epoch_{epoch}_step_{step}"

    # Adapter weights: prefer the model's own serializer when it has one.
    adapter_path = os.path.join(args.output_dir, f"adapter_{tag}.safetensors")
    if hasattr(adapter, 'save_adapter'):
        adapter.save_adapter(adapter_path)
    else:
        import safetensors.torch
        safetensors.torch.save_file(adapter.state_dict(), adapter_path)

    # EMA weights: swap in the shadow params, dump them, then restore.
    if ema is not None:
        ema.apply_shadow()
        import safetensors.torch
        safetensors.torch.save_file(
            adapter.state_dict(),
            os.path.join(args.output_dir, f"adapter_ema_{tag}.safetensors"),
        )
        ema.restore()

    # Optimizer state + bookkeeping so training can resume.
    torch.save(
        {
            "epoch": epoch,
            "step": step,
            "optimizer_state_dict": optimizer.state_dict(),
            "args": args,
        },
        os.path.join(args.output_dir, f"training_state_{tag}.pt"),
    )

    print(f"💾 检查点已保存: epoch {epoch}, step {step}")
276
+
277
+
278
def main():
    """Entry point: wire up models, data, and optimization, then run the
    training loop with periodic logging and checkpointing."""
    args = parse_args()

    # Setup wandb (optional experiment tracking)
    if args.use_wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            config=vars(args)
        )

    # Setup models (frozen encoder/UNet/VAE; only the adapter trains)
    text_encoder, adapter, unet, vae, noise_scheduler, dtype = setup_models(args)

    # Setup data
    dataloader = setup_data(args)

    # Setup training (loss, optimizer, optional EMA)
    loss_fn, optimizer, ema = setup_training(args, adapter, noise_scheduler)

    # Create training step function bundling all frozen components
    training_step_fn = AdapterTrainingStep(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        adapter=adapter,
        noise_scheduler=noise_scheduler,
        loss_fn=loss_fn,
        device=args.device,
        dtype=dtype
    )

    # Setup learning rate scheduler (cosine decay after linear warmup)
    total_steps = len(dataloader) * args.num_epochs // args.gradient_accumulation_steps
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_steps
    )

    print(f"🎓 开始训练: {args.num_epochs} epochs, {len(dataloader)} steps/epoch")
    print(f"📊 总训练步数: {total_steps}")

    # Training loop
    global_step = 0

    for epoch in range(args.num_epochs):
        adapter.train()
        epoch_loss = 0

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.num_epochs}")

        for step, batch in enumerate(progress_bar):
            step_loss = train_step(training_step_fn, batch, optimizer, args, ema)
            epoch_loss += step_loss

            # Update learning rate once per optimizer step
            lr_scheduler.step()

            global_step += 1

            # Logging (progress bar always; wandb when enabled)
            if global_step % args.logging_steps == 0:
                avg_loss = epoch_loss / (step + 1)
                current_lr = lr_scheduler.get_last_lr()[0]

                progress_bar.set_postfix({
                    "loss": f"{step_loss:.4f}",
                    "avg_loss": f"{avg_loss:.4f}",
                    "lr": f"{current_lr:.2e}"
                })

                if args.use_wandb:
                    wandb.log({
                        "train/loss": step_loss,
                        "train/avg_loss": avg_loss,
                        "train/learning_rate": current_lr,
                        "train/epoch": epoch,
                        "train/step": global_step
                    })

            # Save checkpoint at a fixed step interval
            if global_step % args.save_steps == 0:
                save_checkpoint(adapter, optimizer, ema, epoch, global_step, args)

        # End of epoch
        avg_epoch_loss = epoch_loss / len(dataloader)
        print(f"📈 Epoch {epoch+1} 完成,平均损失: {avg_epoch_loss:.4f}")

        # Save epoch checkpoint (note: epoch+1 labels the *completed* epoch)
        save_checkpoint(adapter, optimizer, ema, epoch+1, global_step, args)

    print("🎉 训练完成!")

    if args.use_wandb:
        wandb.finish()
374
+
375
+
376
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()
arch/model_loader.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model Loader Utilities
3
+ 模型加载工具 - 用于加载各种模型组件
4
+ """
5
+
6
+ import torch
7
+ import json
8
+ import safetensors.torch
9
+ from typing import Optional
10
+
11
+
12
def load_unet_from_safetensors(unet_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Load UNet from safetensors file
    从 safetensors 文件加载 UNet

    Args:
        unet_path: Path to UNet safetensors file
        config_path: Path to UNet config JSON file
        device: Device to load model on
        dtype: Data type for model weights

    Returns:
        UNet2DConditionModel, or None if any stage of loading fails
    """
    try:
        from diffusers import UNet2DConditionModel

        # Build the architecture from the saved config, then fill in weights.
        with open(config_path, 'r') as f:
            config = json.load(f)
        model = UNet2DConditionModel.from_config(config)

        model.load_state_dict(safetensors.torch.load_file(unet_path))
        model.to(device, dtype)
        return model
    except Exception as e:
        # Best-effort loader: callers check for None rather than catching.
        print(f"Error loading UNet: {e}")
        return None
45
+
46
+
47
def load_vae_from_safetensors(vae_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Load VAE from safetensors file
    从 safetensors 文件加载 VAE

    Args:
        vae_path: Path to VAE safetensors file
        config_path: Path to VAE config JSON file
        device: Device to load model on
        dtype: Data type for model weights

    Returns:
        AutoencoderKL, or None if any stage of loading fails
    """
    try:
        from diffusers import AutoencoderKL

        # Build the architecture from the saved config, then fill in weights.
        with open(config_path, 'r') as f:
            config = json.load(f)
        model = AutoencoderKL.from_config(config)

        model.load_state_dict(safetensors.torch.load_file(vae_path))
        model.to(device, dtype)
        return model
    except Exception as e:
        # Best-effort loader: callers check for None rather than catching.
        print(f"Error loading VAE: {e}")
        return None
80
+
81
+
82
def create_scheduler(scheduler_type: str = "EulerAncestral", model_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
    """
    Create scheduler for diffusion process
    创建扩散过程调度器

    Args:
        scheduler_type: One of "DDPM", "DDIM", "DPMSolverMultistep",
            "EulerAncestral"; any other value falls back to DDPM.
        model_id: Model ID whose "scheduler" subfolder holds the config.

    Returns:
        Scheduler object, or None if creation fails
    """
    try:
        # Resolve the scheduler class; unknown names fall back to DDPM.
        if scheduler_type == "DDPM":
            from diffusers import DDPMScheduler as scheduler_cls
        elif scheduler_type == "DDIM":
            from diffusers import DDIMScheduler as scheduler_cls
        elif scheduler_type == "DPMSolverMultistep":
            from diffusers import DPMSolverMultistepScheduler as scheduler_cls
        elif scheduler_type == "EulerAncestral":
            from diffusers import EulerAncestralDiscreteScheduler as scheduler_cls
        else:
            print(f"Unsupported scheduler type: {scheduler_type}, using DDPM")
            from diffusers import DDPMScheduler as scheduler_cls

        return scheduler_cls.from_pretrained(model_id, subfolder="scheduler")
    except Exception as e:
        print(f"Error creating scheduler: {e}")
        return None
116
+
117
+
118
def load_qwen_model(model_path: str, device: str = "cuda"):
    """
    Load Qwen3 embedding model
    加载 Qwen3 嵌入模型

    Args:
        model_path: Path to Qwen model
        device: Device to load model on

    Returns:
        SentenceTransformer model, or None when the dependency is missing
        or loading fails
    """
    try:
        from sentence_transformers import SentenceTransformer
        encoder = SentenceTransformer(model_path)
        encoder.to(device)
        return encoder
    except ImportError:
        # Optional dependency: callers fall back to mock embeddings.
        print("Warning: sentence-transformers not available. Using mock embeddings.")
        return None
    except Exception as e:
        print(f"Error loading Qwen model: {e}")
        return None
141
+
142
+
143
def save_model_components(
    unet,
    vae,
    adapter,
    text_encoder,
    save_dir: str,
    save_format: str = "safetensors"
):
    """
    Save model components for training checkpoints
    保存模型组件用于训练检查点

    Args:
        unet: UNet model (skipped when None)
        vae: VAE model (skipped when None)
        adapter: Qwen embedding adapter (skipped when None)
        text_encoder: Qwen text encoder (not serialized here)
        save_dir: Directory to save components into (created if missing)
        save_format: "safetensors" or anything else for PyTorch .pt files
    """
    import os
    os.makedirs(save_dir, exist_ok=True)

    try:
        # text_encoder is intentionally excluded: it stays frozen upstream.
        components = (("unet", unet), ("vae", vae), ("adapter", adapter))

        if save_format == "safetensors":
            for name, model in components:
                if model is not None:
                    safetensors.torch.save_file(
                        model.state_dict(),
                        os.path.join(save_dir, f"{name}.safetensors"),
                    )
        else:  # PyTorch format
            for name, model in components:
                if model is not None:
                    torch.save(model.state_dict(), os.path.join(save_dir, f"{name}.pt"))

        print(f"Model components saved to {save_dir}")

    except Exception as e:
        print(f"Error saving model components: {e}")
201
+
202
+
203
def load_unet_with_lora(
    unet_path: str,
    unet_config_path: str,
    lora_weights_path: Optional[str] = None,
    lora_config_path: Optional[str] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load UNet with optional LoRA weights
    加载带有可选LoRA权重的UNet

    Args:
        unet_path: Path to base UNet safetensors file
        unet_config_path: Path to the matching UNet config JSON file
        lora_weights_path: Optional path to LoRA weights (safetensors or .pt file)
        lora_config_path: Optional path to LoRA config directory
        device: Device to load model on
        dtype: Data type for model weights

    Returns:
        UNet model with LoRA applied if both LoRA paths are given,
        or None if loading fails
    """
    try:
        from diffusers import UNet2DConditionModel
        from peft import PeftModel, LoraConfig

        # Load base UNet from the extracted safetensors + config pair
        unet = load_unet_from_safetensors(unet_path, unet_config_path, device, dtype)

        # Apply LoRA if provided (both the weights and config are required)
        if lora_weights_path and lora_config_path:
            print(f"Loading LoRA weights from {lora_weights_path}")

            # Load LoRA weights
            if lora_weights_path.endswith(".safetensors"):
                import safetensors.torch
                lora_state_dict = safetensors.torch.load_file(lora_weights_path)
            else:
                lora_state_dict = torch.load(lora_weights_path, map_location=device)

            # Load LoRA config
            lora_config = LoraConfig.from_pretrained(lora_config_path)

            # Apply LoRA to UNet (wraps it in a PEFT model)
            from peft import get_peft_model, set_peft_model_state_dict
            unet = get_peft_model(unet, lora_config)
            set_peft_model_state_dict(unet, lora_state_dict)

            print("LoRA weights applied to UNet")

        unet.to(device, dtype)
        return unet

    except Exception as e:
        # Best-effort loader: callers check for None rather than catching.
        print(f"Error loading UNet with LoRA: {e}")
        return None
270
+
271
+
272
def load_fused_unet(
    fused_unet_path: str,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load UNet with fused LoRA weights
    加载融合了LoRA权重的UNet

    Args:
        fused_unet_path: Path to fused UNet model directory
        device: Device to load model on
        dtype: Data type for model weights

    Returns:
        UNet model with fused LoRA weights, or None on failure
    """
    try:
        from diffusers import UNet2DConditionModel

        model = UNet2DConditionModel.from_pretrained(fused_unet_path, torch_dtype=dtype)
        model.to(device, dtype)

        print(f"Fused UNet loaded from {fused_unet_path}")
        return model

    except Exception as e:
        # Best-effort loader: callers check for None rather than catching.
        print(f"Error loading fused UNet: {e}")
        return None
304
+
305
+
306
def load_checkpoint(checkpoint_path: str, device: str = "cuda"):
    """
    Load training checkpoint
    加载训练检查点

    Args:
        checkpoint_path: Path to checkpoint file (.safetensors or torch .pt)
        device: Device to load on

    Returns:
        Dictionary containing checkpoint data, or None on failure
    """
    try:
        # Dispatch on extension: safetensors for weight dumps, torch.load
        # for anything else (optimizer/bookkeeping state).
        if checkpoint_path.endswith(".safetensors"):
            return safetensors.torch.load_file(checkpoint_path, device=device)
        return torch.load(checkpoint_path, map_location=device)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None
arch/pipeline.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Qwen-SDXL Inference Pipeline
3
+ Qwen-SDXL 推理管道 - 使用 Qwen3 嵌入模型替代 CLIP 文本编码器的 SDXL 推理管道
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ from PIL import Image
10
+ from typing import List, Optional, Union, Tuple
11
+
12
+ from .adapter import QwenEmbeddingAdapter
13
+ from .text_encoder import QwenTextEncoder
14
+ from .model_loader import load_qwen_model, load_unet_from_safetensors, load_vae_from_safetensors, create_scheduler
15
+
16
+
17
class QwenIllustriousInference:
    """
    Qwen-SDXL inference pipeline.

    An SDXL pipeline whose CLIP text encoders are replaced by a Qwen3
    embedding model plus a trained adapter that projects the 1024-dim Qwen
    embeddings into the dimensions the SDXL UNet expects (2048-dim sequence
    embeddings and a 1280-dim pooled embedding).

    Fixes over the previous revision:
      - repaired mojibake in the init-failure message;
      - a single negative prompt is now broadcast over a multi-prompt batch
        in :meth:`generate` (previously the CFG batch sizes could mismatch);
      - removed an unused local in :meth:`encode_prompts`.
    """

    def __init__(
        self,
        qwen_model_path: str = "models/Qwen3-Embedding-0.6B",
        unet_path: str = "models/extracted_components/waiNSFWIllustrious_v140_unet.safetensors",
        unet_config_path: str = "models/extracted_components/waiNSFWIllustrious_v140_unet_config.json",
        vae_path: str = "models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors",
        vae_config_path: str = "models/extracted_components/waiNSFWIllustrious_v140_vae_config.json",
        adapter_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/adapter/adapter.safetensors",
        lora_weights_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/lora_weights/lora_weights.safetensors",
        lora_config_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/lora_weights/adapter_config.json",
        use_fused_unet: bool = False,
        fused_unet_path: Optional[str] = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        scheduler_type: str = "DDPM"
    ):
        self.device = device
        self.dtype = dtype
        self.vae_scale_factor = 8  # SDXL default: latents are 1/8 of image resolution

        print("🚀 初始化 Qwen-SDXL 推理管道...")

        # Frozen Qwen3 text encoder.
        print("📝 初始化 Qwen 文本编码器...")
        self.text_encoder = QwenTextEncoder(
            model_path=qwen_model_path,
            device=device,
            freeze_encoder=True
        )

        # Adapter projecting Qwen embeddings to SDXL dimensions.
        print("🔧 初始化适配器层...")
        self.adapter = QwenEmbeddingAdapter()
        self.adapter.to(device, dtype)

        # Optionally restore trained adapter weights.
        if adapter_path is not None:
            print(f"📥 加载适配器权重: {adapter_path}")
            try:
                if adapter_path.endswith(".safetensors"):
                    import safetensors.torch
                    adapter_state = safetensors.torch.load_file(adapter_path)
                else:
                    adapter_state = torch.load(adapter_path, map_location=device)
                self.adapter.load_state_dict(adapter_state)
            except Exception as e:
                print(f"⚠️ 加载适配器权重失败: {e}")

        # UNet: fused-LoRA checkpoint, base + separate LoRA weights, or plain.
        print("🏗️ 加载 UNet 模型...")
        from .model_loader import load_unet_with_lora, load_fused_unet

        if use_fused_unet and fused_unet_path:
            # UNet with LoRA weights already merged into the base parameters.
            print("📦 使用融合LoRA权重的UNet...")
            self.unet = load_fused_unet(fused_unet_path, device, dtype)
        elif lora_weights_path and lora_config_path:
            # Base UNet with separate LoRA weights applied via PEFT.
            print("🔧 加载UNet并应用LoRA权重...")
            self.unet = load_unet_with_lora(
                unet_path=unet_path,
                unet_config_path=unet_config_path,
                lora_weights_path=lora_weights_path,
                lora_config_path=lora_config_path,
                device=device,
                dtype=dtype
            )
        else:
            # Plain UNet from safetensors.
            self.unet = load_unet_from_safetensors(unet_path, unet_config_path, device, dtype)

        # VAE for latent decoding.
        print("🎨 加载 VAE 模型...")
        self.vae = load_vae_from_safetensors(vae_path, vae_config_path, device, dtype)

        # Noise scheduler.
        print(f"⏰ 创建调度器 ({scheduler_type})...")
        self.scheduler = create_scheduler(scheduler_type)

        # The pipeline is usable only when every component loaded.
        self.is_ready = all(
            component is not None
            for component in (self.text_encoder, self.adapter, self.unet, self.vae, self.scheduler)
        )

        if self.is_ready:
            print("✅ 管道初始化完成!")
        else:
            print("❌ 管道初始化失败,某些组件加载失败")

    def encode_prompts(
        self,
        prompts: List[str],
        negative_prompts: Optional[List[str]] = None,
        do_classifier_free_guidance: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode prompts with Qwen3 and project them through the adapter.

        Args:
            prompts: Positive prompts, one per image.
            negative_prompts: Optional negative prompts for CFG.
            do_classifier_free_guidance: When True, negative embeddings are
                prepended along the batch dimension by the text encoder.

        Returns:
            tuple: ``(prompt_embeds, pooled_prompt_embeds)`` with shapes
            ``[B, 512, 2048]`` and ``[B, 1280]`` (B doubled under CFG).
        """
        # Raw 1024-dim Qwen embeddings.
        text_embeddings, pooled_embeddings = self.text_encoder.encode_prompts(
            prompts, negative_prompts, do_classifier_free_guidance
        )

        # Broadcast the per-prompt embedding over a 512-token sequence axis
        # (this pipeline uses 512 tokens for SDXL conditioning).
        seq_len = 512
        text_embeddings_seq = text_embeddings.unsqueeze(1).expand(-1, seq_len, -1)  # [B, 512, 1024]

        # Project to SDXL dimensions using the adapter.
        prompt_embeds = self.adapter.forward_text_embeddings(text_embeddings_seq.to(self.dtype))  # [B, 512, 2048]
        pooled_prompt_embeds = self.adapter.forward_pooled_embeddings(pooled_embeddings.to(self.dtype))  # [B, 1280]

        return prompt_embeds, pooled_prompt_embeds

    def prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        generator: Optional[torch.Generator] = None
    ) -> torch.Tensor:
        """
        Prepare the initial noise latents for the denoising loop.

        Args:
            batch_size: Number of images to generate.
            height: Target image height in pixels.
            width: Target image width in pixels.
            generator: Optional RNG for reproducible noise.

        Returns:
            Noise tensor of shape ``[B, C, H/8, W/8]`` scaled by the
            scheduler's initial noise sigma.
        """
        if self.unet is None:
            # Mock latents for testing when no UNet is loaded.
            shape = (batch_size, 4, height // self.vae_scale_factor, width // self.vae_scale_factor)
            return torch.randn(shape, device=self.device, dtype=self.dtype)

        shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )

        try:
            from diffusers.utils import randn_tensor
            latents = randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype)
        except ImportError:
            latents = torch.randn(shape, device=self.device, dtype=self.dtype, generator=generator)

        # Scale initial noise to the scheduler's expected magnitude.
        if self.scheduler is not None:
            latents = latents * self.scheduler.init_noise_sigma

        return latents

    def get_time_ids(
        self,
        height: int,
        width: int,
        original_size: Tuple[int, int],
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None
    ) -> torch.Tensor:
        """
        Build the SDXL micro-conditioning time IDs.

        The IDs are the concatenation of original size, crop coordinates and
        target size, as expected by the SDXL UNet's ``added_cond_kwargs``.

        Returns:
            Tensor of shape ``[1, 6]`` on the pipeline device/dtype.
        """
        if target_size is None:
            target_size = (height, width)

        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=self.dtype, device=self.device)

        return add_time_ids

    @torch.no_grad()
    def generate(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        generator: Optional[torch.Generator] = None,
        return_type: str = "pil"
    ) -> List[Image.Image]:
        """
        Generate images with the Qwen-SDXL pipeline.

        Args:
            prompt: A prompt or list of prompts.
            negative_prompt: Optional negative prompt(s); a single string is
                broadcast across the whole batch.
            height: Output image height in pixels.
            width: Output image width in pixels.
            num_inference_steps: Number of denoising steps.
            guidance_scale: CFG scale; values > 1 enable guidance.
            generator: Optional RNG for reproducible generation.
            return_type: ``"pil"`` for PIL images, anything else returns the
                raw numpy array batch.

        Returns:
            List of generated images (empty when the pipeline is not ready).
        """
        if not self.is_ready:
            print("❌ 管道未准备就绪,无法生成图像")
            return []

        # Normalize prompt arguments to lists.
        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]

        batch_size = len(prompt)
        do_classifier_free_guidance = guidance_scale > 1.0

        # Broadcast a single negative prompt across the whole batch so the
        # CFG halves of the embedding batch line up.
        if negative_prompt is not None and len(negative_prompt) == 1 and batch_size > 1:
            negative_prompt = negative_prompt * batch_size

        print(f"🎯 开始生成 {batch_size} 张图像...")
        print(f"📏 尺寸: {width}x{height}")
        print(f"🔄 推理步数: {num_inference_steps}")
        print(f"🎚️ 引导强度: {guidance_scale}")

        # 1. Encode prompts.
        print("📝 编码提示词...")
        prompt_embeds, pooled_prompt_embeds = self.encode_prompts(
            prompt, negative_prompt, do_classifier_free_guidance
        )

        # 2. Prepare timesteps.
        print("⏰ 准备时间步...")
        if self.scheduler is not None:
            self.scheduler.set_timesteps(num_inference_steps, device=self.device)
            timesteps = self.scheduler.timesteps
        else:
            timesteps = torch.linspace(1000, 0, num_inference_steps, device=self.device)

        # 3. Prepare latents.
        print("🌀 准备潜在变量...")
        latents = self.prepare_latents(batch_size, height, width, generator)

        # 4. Prepare SDXL micro-conditioning time IDs.
        original_size = (height, width)
        target_size = (height, width)
        add_time_ids = self.get_time_ids(height, width, original_size, target_size=target_size)

        if do_classifier_free_guidance:
            add_time_ids = add_time_ids.repeat(2, 1)
        add_time_ids = add_time_ids.repeat(batch_size, 1)

        # 5. Denoising loop.
        print("🔄 开始去噪过程...")
        for i, t in enumerate(timesteps):
            # Duplicate latents for the unconditional/conditional CFG halves.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            if self.scheduler is not None:
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict noise with the UNet.
            if self.unet is not None:
                added_cond_kwargs = {
                    "text_embeds": pooled_prompt_embeds,
                    "time_ids": add_time_ids
                }

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # Classifier-free guidance.
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # Step the scheduler.
                if self.scheduler is not None:
                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            if (i + 1) % 5 == 0:
                print(f" 步骤 {i+1}/{len(timesteps)} 完成")

        # 6. Decode latents to image space.
        print("🎨 解码生成图像...")
        if self.vae is not None:
            latents = latents / self.vae.config.scaling_factor
            images = self.vae.decode(latents, return_dict=False)[0]
        else:
            # Mock image generation for testing.
            images = torch.randn(batch_size, 3, height, width, device=self.device)

        # 7. Convert [-1, 1] tensors to PIL images.
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()

        if return_type == "pil":
            images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images]

        print("✅ 图像生成完成!")
        return images

    def save_adapter(self, save_path: str):
        """
        Save the adapter weights to *save_path* (safetensors or torch format,
        chosen by file extension).
        """
        try:
            if save_path.endswith(".safetensors"):
                import safetensors.torch
                safetensors.torch.save_file(self.adapter.state_dict(), save_path)
            else:
                torch.save(self.adapter.state_dict(), save_path)
            print(f"✅ 适配器权重已保存到: {save_path}")
        except Exception as e:
            print(f"❌ 保存适配器权重失败: {e}")

    def load_adapter(self, load_path: str):
        """
        Load adapter weights from *load_path* (safetensors or torch format,
        chosen by file extension).
        """
        try:
            if load_path.endswith(".safetensors"):
                import safetensors.torch
                state_dict = safetensors.torch.load_file(load_path)
            else:
                state_dict = torch.load(load_path, map_location=self.device)

            self.adapter.load_state_dict(state_dict)
            print(f"✅ 适配器权重已从 {load_path} 加载")
        except Exception as e:
            print(f"❌ 加载适配器权重失败: {e}")
arch/text_encoder.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Qwen Text Encoder
3
+ Qwen 文本编码器 - 使用 Qwen3 模型进行文本编码
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import List, Optional, Union, Tuple
9
+
10
+
11
def load_qwen_model(model_path: str, device: str = "cuda"):
    """
    Load the Qwen3 embedding model via sentence-transformers.

    Args:
        model_path: Local path or hub id of the Qwen3 embedding model.
        device: Device to move the model onto.

    Returns:
        The loaded ``SentenceTransformer`` model, or ``None`` when the
        sentence-transformers package is not installed (callers then fall
        back to mock embeddings).
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        print("Warning: sentence-transformers not available. Using mock embeddings.")
        return None

    encoder = SentenceTransformer(model_path)
    encoder.to(device)
    return encoder
24
+
25
+
26
def encode_text_with_qwen(
    qwen_model,
    texts: List[str],
    device: str = "cuda",
    max_length: int = 512,
    use_query_mode: bool = False
) -> torch.Tensor:
    """
    Encode a batch of texts with the Qwen3 embedding model.

    Args:
        qwen_model: Loaded sentence-transformers model, or ``None`` to fall
            back to random mock embeddings (used when the package is absent).
        texts: Text strings to encode.
        device: Device to run the encoder on.
        max_length: Maximum sequence length passed to the encoder.
        use_query_mode: When True, use the "query" prompt and return a pooled
            sentence embedding; otherwise return stacked token embeddings.

    Returns:
        Embedding tensor; shape ``[batch, 1024]`` on the mock path.
    """
    if qwen_model is None:
        # Mock path for environments without sentence-transformers.
        return torch.randn(len(texts), 1024, device=device, dtype=torch.float32)

    with torch.no_grad():
        encoded = qwen_model.encode(
            texts,
            prompt_name="query" if use_query_mode else None,
            convert_to_tensor=True,
            device=device,
            max_seq_length=max_length,
            output_value="sentence_embedding" if use_query_mode else "token_embeddings",
        )

    if use_query_mode:
        return encoded
    # Token-embedding mode yields one tensor per text; stack into a batch.
    # NOTE(review): stacking assumes equal sequence lengths across texts —
    # confirm this holds for the configured encoder.
    return torch.stack(encoded, dim=0)
60
+
61
+
62
class QwenTextEncoder(nn.Module):
    """
    Wrapper around the Qwen3 embedding model for training and inference.

    Produces per-prompt text embeddings (normal mode) and pooled embeddings
    (query mode), with optional classifier-free-guidance concatenation of
    negative-prompt embeddings in front along the batch dimension.
    """

    def __init__(
        self,
        model_path: str = "models/Qwen3-Embedding-0.6B",
        device: str = "cuda",
        max_length: int = 512,
        freeze_encoder: bool = True
    ):
        super().__init__()
        self.device = device
        self.max_length = max_length
        self.freeze_encoder = freeze_encoder

        # Underlying sentence-transformers model (None when unavailable).
        self.qwen_model = load_qwen_model(model_path, device)

        # Freeze the encoder so only downstream modules train.
        if self.qwen_model is not None and self.freeze_encoder:
            for weight in self.qwen_model.parameters():
                weight.requires_grad = False

    def _encode(self, texts: List[str], use_query_mode: bool) -> torch.Tensor:
        """Encode *texts* in the requested mode using this encoder's settings."""
        return encode_text_with_qwen(
            self.qwen_model, texts, self.device,
            max_length=self.max_length, use_query_mode=use_query_mode
        )

    def encode_prompts(
        self,
        prompts: List[str],
        negative_prompts: Optional[List[str]] = None,
        do_classifier_free_guidance: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode prompts (and optionally negative prompts) with Qwen3.

        Returns:
            tuple: ``(text_embeddings, pooled_embeddings)``, each
            ``[batch, 1024]``; with CFG enabled the negative embeddings are
            concatenated in front, doubling the batch dimension.
        """
        text_embeddings = self._encode(prompts, use_query_mode=False)
        pooled_embeddings = self._encode(prompts, use_query_mode=True)

        if do_classifier_free_guidance:
            # Default to empty negatives when none were supplied.
            negatives = negative_prompts if negative_prompts is not None else [""] * len(prompts)

            neg_text = self._encode(negatives, use_query_mode=False)
            neg_pooled = self._encode(negatives, use_query_mode=True)

            # Unconditional embeddings come first, matching the CFG split.
            text_embeddings = torch.cat([neg_text, text_embeddings], dim=0)
            pooled_embeddings = torch.cat([neg_pooled, pooled_embeddings], dim=0)

        return text_embeddings, pooled_embeddings

    def forward(self, prompts: List[str], negative_prompts: Optional[List[str]] = None):
        """
        Encode prompts; CFG is enabled exactly when negative prompts are given.

        Returns:
            tuple: ``(text_embeddings, pooled_embeddings)``.
        """
        return self.encode_prompts(
            prompts, negative_prompts,
            do_classifier_free_guidance=(negative_prompts is not None)
        )

    def train(self, mode: bool = True):
        """Switch train mode while keeping a frozen encoder in eval mode."""
        super().train(mode)
        if self.freeze_encoder and self.qwen_model is not None:
            self.qwen_model.eval()
        return self
arch/training.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training Utilities for Qwen-SDXL
3
+ Qwen-SDXL 训练工具 - 包含损失函数、训练步骤等
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Dict, Any, Optional, Tuple
10
+ import math
11
+
12
+
13
class DiffusionLoss(nn.Module):
    """
    Diffusion training loss for SDXL with Qwen embeddings.

    Supports MSE / L1 / Huber element losses and optional SNR-based loss
    weighting. The ``use_v_parameterization`` flag is stored for the training
    step to select the target parameterization.
    """

    def __init__(
        self,
        noise_scheduler,
        loss_type: str = "mse",
        snr_gamma: Optional[float] = None,
        use_v_parameterization: bool = False
    ):
        super().__init__()
        self.noise_scheduler = noise_scheduler
        self.loss_type = loss_type
        self.snr_gamma = snr_gamma
        self.use_v_parameterization = use_v_parameterization

        # Element-wise losses; reduction happens in forward().
        factories = {
            "mse": lambda: nn.MSELoss(reduction="none"),
            "l1": lambda: nn.L1Loss(reduction="none"),
            "huber": lambda: nn.HuberLoss(reduction="none", delta=0.1),
        }
        if loss_type not in factories:
            raise ValueError(f"Unsupported loss type: {loss_type}")
        self.loss_fn = factories[loss_type]()

    def compute_snr(self, timesteps):
        """
        Signal-to-noise ratio at the given timesteps, used for loss weighting.

        SNR(t) = alpha_cumprod(t) / (1 - alpha_cumprod(t)).
        """
        alphas = self.noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
        signal_scale = alphas ** 0.5
        noise_scale = (1.0 - alphas) ** 0.5
        return (signal_scale / noise_scale) ** 2

    def forward(
        self,
        model_pred: torch.Tensor,
        target: torch.Tensor,
        timesteps: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute the (optionally masked and SNR-weighted) diffusion loss.

        Args:
            model_pred: Model prediction (noise or v-parameterization).
            target: Training target of the same shape.
            timesteps: Diffusion timesteps, one per batch element.
            mask: Optional element-wise mask applied before reduction.

        Returns:
            Scalar loss tensor.
        """
        per_element = self.loss_fn(model_pred, target)

        if mask is not None:
            per_element = per_element * mask

        # Mean over all non-batch dimensions -> one loss value per sample.
        per_sample = per_element.mean(dim=list(range(1, per_element.dim())))

        if self.snr_gamma is not None:
            snr = self.compute_snr(timesteps)
            if self.snr_gamma >= 1.0:
                # Min-SNR weighting: clip the SNR at gamma.
                weight = torch.minimum(snr, torch.full_like(snr, self.snr_gamma))
            else:
                # Power-law SNR weighting for gamma < 1.
                weight = snr ** self.snr_gamma
            per_sample = per_sample * weight

        return per_sample.mean()
98
+
99
+
100
class AdapterTrainingStep:
    """
    Training step for adapter-only training: the UNet, VAE and text encoder
    are frozen and only the Qwen->SDXL adapter receives gradients.
    """

    def __init__(
        self,
        unet,
        vae,
        text_encoder,
        adapter,
        noise_scheduler,
        loss_fn: DiffusionLoss,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16
    ):
        # Frozen diffusion components.
        self.unet = unet
        self.vae = vae
        self.text_encoder = text_encoder
        # The only trainable module.
        self.adapter = adapter
        self.noise_scheduler = noise_scheduler
        self.loss_fn = loss_fn
        self.device = device
        self.dtype = dtype

        # Freeze components except adapter.
        self._freeze_components()

    def _freeze_components(self):
        """Freeze all components except the adapter."""
        if self.unet is not None:
            for param in self.unet.parameters():
                param.requires_grad = False

        if self.vae is not None:
            for param in self.vae.parameters():
                param.requires_grad = False

        # Text encoder is already frozen in QwenTextEncoder.
        # Only adapter parameters should be trainable.
        for param in self.adapter.parameters():
            param.requires_grad = True

    def prepare_inputs(
        self,
        images: torch.Tensor,
        prompts: list,
        negative_prompts: Optional[list] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Prepare one training batch: VAE-encode images, add scheduler noise at
        random timesteps, and project Qwen text embeddings through the adapter.

        Args:
            images: Image batch ``[B, C, H, W]`` — presumably normalized to
                the VAE's expected range; TODO confirm against the data loader.
            prompts: Positive text prompts, one per image.
            negative_prompts: Unused under CFG-free training (passed through).

        Returns:
            Dict with noisy latents, timesteps, projected conditioning
            tensors, SDXL time IDs, and the sampled noise (the default target).
        """
        batch_size = images.shape[0]

        # Encode images to latents (frozen VAE, no gradients needed).
        with torch.no_grad():
            latents = self.vae.encode(images.to(self.dtype)).latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor

        # Add noise to latents at uniformly sampled timesteps.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (batch_size,), device=self.device
        ).long()

        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # Encode text (no CFG during training).
        text_embeddings, pooled_embeddings = self.text_encoder.encode_prompts(
            prompts, negative_prompts, do_classifier_free_guidance=False
        )

        # Broadcast the per-prompt embedding over a token-sequence axis and
        # project through the adapter.
        # NOTE(review): seq_len is 77 here but the inference pipeline uses
        # 512 — confirm this train/inference mismatch is intentional.
        seq_len = 77
        text_embeddings_seq = text_embeddings.unsqueeze(1).expand(-1, seq_len, -1)

        encoder_hidden_states = self.adapter.forward_text_embeddings(text_embeddings_seq.to(self.dtype))
        pooled_prompt_embeds = self.adapter.forward_pooled_embeddings(pooled_embeddings.to(self.dtype))

        # SDXL micro-conditioning time IDs (simplified for training: the
        # original and target sizes are the input size, no cropping).
        height, width = images.shape[-2:]
        original_size = (height, width)
        target_size = (height, width)
        crops_coords_top_left = (0, 0)

        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids] * batch_size, dtype=self.dtype, device=self.device)

        return {
            "noisy_latents": noisy_latents,
            "timesteps": timesteps,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_prompt_embeds": pooled_prompt_embeds,
            "add_time_ids": add_time_ids,
            "noise": noise
        }

    def training_step(
        self,
        images: torch.Tensor,
        prompts: list,
        negative_prompts: Optional[list] = None
    ) -> Dict[str, Any]:
        """
        Execute one training step: UNet forward pass on a noisy batch and
        diffusion loss against the noise (or v-parameterization) target.

        Returns:
            Dict with the scalar loss plus detached prediction/target tensors
            and the sampled timesteps, for logging.
        """
        # Prepare inputs.
        inputs = self.prepare_inputs(images, prompts, negative_prompts)

        # Forward pass through the UNet with SDXL added conditioning.
        added_cond_kwargs = {
            "text_embeds": inputs["pooled_prompt_embeds"],
            "time_ids": inputs["add_time_ids"]
        }

        model_pred = self.unet(
            inputs["noisy_latents"],
            inputs["timesteps"],
            encoder_hidden_states=inputs["encoder_hidden_states"],
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

        # Build the training target.
        if self.loss_fn.use_v_parameterization:
            # v-parameterization target: v = sqrt(a)*noise - sqrt(1-a)*x_t.
            alphas_cumprod = self.noise_scheduler.alphas_cumprod[inputs["timesteps"]]
            sqrt_alphas_cumprod = alphas_cumprod**0.5
            sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

            target = sqrt_alphas_cumprod.view(-1, 1, 1, 1) * inputs["noise"] - \
                sqrt_one_minus_alphas_cumprod.view(-1, 1, 1, 1) * inputs["noisy_latents"]
        else:
            # Standard epsilon (noise) prediction.
            target = inputs["noise"]

        loss = self.loss_fn(model_pred, target, inputs["timesteps"])

        return {
            "loss": loss,
            "model_pred": model_pred.detach(),
            "target": target.detach(),
            "timesteps": inputs["timesteps"]
        }
248
+
249
+
250
def get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Build a LambdaLR schedule: linear warmup followed by cosine decay.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Steps of linear warmup from 0 to the base LR.
        num_training_steps: Total number of training steps.
        num_cycles: Number of cosine half-cycles over the decay phase.
        last_epoch: Index of the last epoch when resuming training.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    from torch.optim.lr_scheduler import LambdaLR

    def lr_lambda(step):
        if step < num_warmup_steps:
            # Linear warmup; max(1, ...) guards against zero warmup steps.
            return step / max(1, num_warmup_steps)
        # Fraction of the decay phase completed, in [0, 1].
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
270
+
271
+
272
class EMAModel:
    """
    Exponential moving average (EMA) of a model's trainable parameters.

    Keeps a shadow copy updated as
    ``shadow = decay * shadow + (1 - decay) * param`` and can temporarily
    swap the shadow weights into the model via ``apply_shadow`` /
    ``restore``.
    """

    def __init__(self, model, decay: float = 0.9999):
        self.model = model
        self.decay = decay
        # Snapshot the initial trainable parameters as the EMA state.
        self.shadow = {
            name: param.data.clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        self.backup = {}

    def update(self):
        """Blend the current trainable parameters into the shadow copy."""
        d = self.decay
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.shadow[name] = d * self.shadow[name] + (1 - d) * param.data

    def apply_shadow(self):
        """Swap EMA weights into the model, backing up the live weights."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.data.clone()
                param.data = self.shadow[name]

    def restore(self):
        """Undo ``apply_shadow`` by putting the backed-up weights back."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
diffusers/.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # What does this PR do?
2
+
3
+ <!--
4
+ Congratulations! You've made it this far! You're not quite done yet though.
5
+
6
+ Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.
7
+
8
+ Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.
9
+
10
+ Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost.
11
+ -->
12
+
13
+ <!-- Remove if not applicable -->
14
+
15
+ Fixes # (issue)
16
+
17
+
18
+ ## Before submitting
19
+ - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
20
+ - [ ] Did you read the [contributor guideline](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md)?
21
+ - [ ] Did you read our [philosophy doc](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md) (important for complex PRs)?
22
+ - [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)? Please add a link to it if that's the case.
23
+ - [ ] Did you make sure to update the documentation with your changes? Here are the
24
+ [documentation guidelines](https://github.com/huggingface/diffusers/tree/main/docs), and
25
+ [here are tips on formatting docstrings](https://github.com/huggingface/diffusers/tree/main/docs#writing-source-documentation).
26
+ - [ ] Did you write any new necessary tests?
27
+
28
+
29
+ ## Who can review?
30
+
31
+ Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
32
+ members/contributors who may be interested in your PR.
33
+
34
+ <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @.
35
+
36
+ If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
37
+ Please tag fewer than 3 people.
38
+
39
+ Core library:
40
+
41
+ - Schedulers: @yiyixuxu
42
+ - Pipelines and pipeline callbacks: @yiyixuxu and @asomoza
43
+ - Training examples: @sayakpaul
44
+ - Docs: @stevhliu and @sayakpaul
45
+ - JAX and MPS: @pcuenca
46
+ - Audio: @sanchit-gandhi
47
+ - General functionalities: @sayakpaul @yiyixuxu @DN6
48
+
49
+ Integrations:
50
+
51
+ - deepspeed: HF Trainer/Accelerate: @SunMarc
52
+ - PEFT: @sayakpaul @BenjaminBossan
53
+
54
+ HF projects:
55
+
56
+ - accelerate: [different repo](https://github.com/huggingface/accelerate)
57
+ - datasets: [different repo](https://github.com/huggingface/datasets)
58
+ - transformers: [different repo](https://github.com/huggingface/transformers)
59
+ - safetensors: [different repo](https://github.com/huggingface/safetensors)
60
+
61
+ -->
diffusers/docs/README.md ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!---
2
+ Copyright 2024- The HuggingFace Team. All rights reserved.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ -->
16
+
17
+ # Generating the documentation
18
+
19
+ To generate the documentation, you first have to build it. Several packages are necessary to build the doc,
20
+ you can install them with the following command, at the root of the code repository:
21
+
22
+ ```bash
23
+ pip install -e ".[docs]"
24
+ ```
25
+
26
+ Then you need to install our open source documentation builder tool:
27
+
28
+ ```bash
29
+ pip install git+https://github.com/huggingface/doc-builder
30
+ ```
31
+
32
+ ---
33
+ **NOTE**
34
+
35
+ You only need to generate the documentation to inspect it locally (if you're planning changes and want to
36
+ check how they look before committing for instance). You don't have to commit the built documentation.
37
+
38
+ ---
39
+
40
+ ## Previewing the documentation
41
+
42
+ To preview the docs, first install the `watchdog` module with:
43
+
44
+ ```bash
45
+ pip install watchdog
46
+ ```
47
+
48
+ Then run the following command:
49
+
50
+ ```bash
51
+ doc-builder preview {package_name} {path_to_docs}
52
+ ```
53
+
54
+ For example:
55
+
56
+ ```bash
57
+ doc-builder preview diffusers docs/source/en
58
+ ```
59
+
60
+ The docs will be viewable at [http://localhost:3000](http://localhost:3000). You can also preview the docs once you have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives.
61
+
62
+ ---
63
+ **NOTE**
64
+
65
+ The `preview` command only works with existing doc files. When you add a completely new file, you need to update `_toctree.yml` & restart `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again).
66
+
67
+ ---
68
+
69
+ ## Adding a new element to the navigation bar
70
+
71
+ Accepted files are Markdown (.md).
72
+
73
+ Create a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting
74
+ the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml) file.
75
+
76
+ ## Renaming section headers and moving sections
77
+
78
+ It helps to keep the old links working when renaming the section header and/or moving sections from one document to another. This is because the old links are likely to be used in Issues, Forums, and Social media and it'd make for a much more superior user experience if users reading those months later could still easily navigate to the originally intended information.
79
+
80
+ Therefore, we simply keep a little map of moved sections at the end of the document where the original section was. The key is to preserve the original anchor.
81
+
82
+ So if you renamed a section from: "Section A" to "Section B", then you can add at the end of the file:
83
+
84
+ ```md
85
+ Sections that were moved:
86
+
87
+ [ <a href="#section-b">Section A</a><a id="section-a"></a> ]
88
+ ```
89
+ and of course, if you moved it to another file, then:
90
+
91
+ ```md
92
+ Sections that were moved:
93
+
94
+ [ <a href="../new-file#section-b">Section A</a><a id="section-a"></a> ]
95
+ ```
96
+
97
+ Use the relative style to link to the new file so that the versioned docs continue to work.
98
+
99
+ For an example of a rich moved section set please see the very end of [the transformers Trainer doc](https://github.com/huggingface/transformers/blob/main/docs/source/en/main_classes/trainer.md).
100
+
101
+
102
+ ## Writing Documentation - Specification
103
+
104
+ The `huggingface/diffusers` documentation follows the
105
+ [Google documentation](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) style for docstrings,
106
+ although we can write them directly in Markdown.
107
+
108
+ ### Adding a new tutorial
109
+
110
+ Adding a new tutorial or section is done in two steps:
111
+
112
+ - Add a new Markdown (.md) file under `docs/source/<languageCode>`.
113
+ - Link that file in `docs/source/<languageCode>/_toctree.yml` on the correct toc-tree.
114
+
115
+ Make sure to put your new file under the proper section. It's unlikely to go in the first section (*Get Started*), so
116
+ depending on the intended targets (beginners, more advanced users, or researchers) it should go in sections two, three, or four.
117
+
118
+ ### Adding a new pipeline/scheduler
119
+
120
+ When adding a new pipeline:
121
+
122
+ - Create a file `xxx.md` under `docs/source/<languageCode>/api/pipelines` (don't hesitate to copy an existing file as template).
123
+ - Link that file in (*Diffusers Summary*) section in `docs/source/api/pipelines/overview.md`, along with the link to the paper, and a colab notebook (if available).
124
+ - Write a short overview of the diffusion model:
125
+ - Overview with paper & authors
126
+ - Paper abstract
127
+ - Tips and tricks and how to use it best
128
+ - Possibly an end-to-end example of how to use it
129
+ - Add all the pipeline classes that should be linked in the diffusion model. These classes should be added using our Markdown syntax. By default as follows:
130
+
131
+ ```
132
+ [[autodoc]] XXXPipeline
133
+ - all
134
+ - __call__
135
+ ```
136
+
137
+ This will include every public method of the pipeline that is documented, as well as the `__call__` method that is not documented by default. If you just want to add additional methods that are not documented, you can put the list of all methods to add in a list that contains `all`.
138
+
139
+ ```
140
+ [[autodoc]] XXXPipeline
141
+ - all
142
+ - __call__
143
+ - enable_attention_slicing
144
+ - disable_attention_slicing
145
+ - enable_xformers_memory_efficient_attention
146
+ - disable_xformers_memory_efficient_attention
147
+ ```
148
+
149
+ You can follow the same process to create a new scheduler under the `docs/source/<languageCode>/api/schedulers` folder.
150
+
151
+ ### Writing source documentation
152
+
153
+ Values that should be put in `code` should either be surrounded by backticks: \`like so\`. Note that argument names
154
+ and objects like True, None, or any strings should usually be put in `code`.
155
+
156
+ When mentioning a class, function, or method, it is recommended to use our syntax for internal links so that our tool
157
+ adds a link to its documentation with this syntax: \[\`XXXClass\`\] or \[\`function\`\]. This requires the class or
158
+ function to be in the main package.
159
+
160
+ If you want to create a link to some internal class or function, you need to
161
+ provide its path. For instance: \[\`pipelines.ImagePipelineOutput\`\]. This will be converted into a link with
162
+ `pipelines.ImagePipelineOutput` in the description. To get rid of the path and only keep the name of the object you are
163
+ linking to in the description, add a ~: \[\`~pipelines.ImagePipelineOutput\`\] will generate a link with `ImagePipelineOutput` in the description.
164
+
165
+ The same works for methods so you can either use \[\`XXXClass.method\`\] or \[\`~XXXClass.method\`\].
166
+
167
+ #### Defining arguments in a method
168
+
169
+ Arguments should be defined with the `Args:` (or `Arguments:` or `Parameters:`) prefix, followed by a line return and
170
+ an indentation. The argument should be followed by its type, with its shape if it is a tensor, a colon, and its
171
+ description:
172
+
173
+ ```
174
+ Args:
175
+ n_layers (`int`): The number of layers of the model.
176
+ ```
177
+
178
+ If the description is too long to fit in one line, another indentation is necessary before writing the description
179
+ after the argument.
180
+
181
+ Here's an example showcasing everything so far:
182
+
183
+ ```
184
+ Args:
185
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
186
+ Indices of input sequence tokens in the vocabulary.
187
+
188
+ Indices can be obtained using [`AlbertTokenizer`]. See [`~PreTrainedTokenizer.encode`] and
189
+ [`~PreTrainedTokenizer.__call__`] for details.
190
+
191
+ [What are input IDs?](../glossary#input-ids)
192
+ ```
193
+
194
+ For optional arguments or arguments with defaults we follow the following syntax: imagine we have a function with the
195
+ following signature:
196
+
197
+ ```py
198
+ def my_function(x: str=None, a: float=3.14):
199
+ ```
200
+
201
+ then its documentation should look like this:
202
+
203
+ ```
204
+ Args:
205
+ x (`str`, *optional*):
206
+ This argument controls ...
207
+ a (`float`, *optional*, defaults to `3.14`):
208
+ This argument is used to ...
209
+ ```
210
+
211
+ Note that we always omit the "defaults to \`None\`" when None is the default for any argument. Also note that even
212
+ if the first line describing your argument type and its default gets long, you can't break it on several lines. You can
213
+ however write as many lines as you want in the indented description (see the example above with `input_ids`).
214
+
215
+ #### Writing a multi-line code block
216
+
217
+ Multi-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown:
218
+
219
+
220
+ ````
221
+ ```
222
+ # first line of code
223
+ # second line
224
+ # etc
225
+ ```
226
+ ````
227
+
228
+ #### Writing a return block
229
+
230
+ The return block should be introduced with the `Returns:` prefix, followed by a line return and an indentation.
231
+ The first line should be the type of the return, followed by a line return. No need to indent further for the elements
232
+ building the return.
233
+
234
+ Here's an example of a single value return:
235
+
236
+ ```
237
+ Returns:
238
+ `List[int]`: A list of integers in the range [0, 1] --- 1 for a special token, 0 for a sequence token.
239
+ ```
240
+
241
+ Here's an example of a tuple return, comprising several objects:
242
+
243
+ ```
244
+ Returns:
245
+ `tuple(torch.Tensor)` comprising various elements depending on the configuration ([`BertConfig`]) and inputs:
246
+ - **loss** (*optional*, returned when `masked_lm_labels` is provided) `torch.Tensor` of shape `(1,)` --
247
+ Total loss is the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
248
+ - **prediction_scores** (`torch.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`) --
249
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
250
+ ```
251
+
252
+ #### Adding an image
253
+
254
+ Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
255
+ the ones hosted on [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) in which to place these files and reference
256
+ them by URL. We recommend putting them in the following dataset: [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images).
257
+ If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images
258
+ to this dataset.
259
+
260
+ ## Styling the docstring
261
+
262
+ We have an automatic script running with the `make style` command that will make sure that:
263
+ - the docstrings fully take advantage of the line width
264
+ - all code examples are formatted using black, like the code of the Transformers library
265
+
266
+ This script may have some weird failures if you made a syntax mistake or if you uncover a bug. Therefore, it's
267
+ recommended to commit your changes before running `make style`, so you can revert the changes done by that script
268
+ easily.
diffusers/docs/TRANSLATING.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!--Copyright 2025 The HuggingFace Team. All rights reserved.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4
+ the License. You may obtain a copy of the License at
5
+
6
+ http://www.apache.org/licenses/LICENSE-2.0
7
+
8
+ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9
+ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10
+ specific language governing permissions and limitations under the License.
11
+ -->
12
+
13
+ ### Translating the Diffusers documentation into your language
14
+
15
+ As part of our mission to democratize machine learning, we'd love to make the Diffusers library available in many more languages! Follow the steps below if you want to help translate the documentation into your language 🙏.
16
+
17
+ **🗞️ Open an issue**
18
+
19
+ To get started, navigate to the [Issues](https://github.com/huggingface/diffusers/issues) page of this repo and check if anyone else has opened an issue for your language. If not, open a new issue by selecting the "🌐 Translating a New Language?" from the "New issue" button.
20
+
21
+ Once an issue exists, post a comment to indicate which chapters you'd like to work on, and we'll add your name to the list.
22
+
23
+
24
+ **🍴 Fork the repository**
25
+
26
+ First, you'll need to [fork the Diffusers repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo). You can do this by clicking on the **Fork** button on the top-right corner of this repo's page.
27
+
28
+ Once you've forked the repo, you'll want to get the files on your local machine for editing. You can do that by cloning the fork with Git as follows:
29
+
30
+ ```bash
31
+ git clone https://github.com/<YOUR-USERNAME>/diffusers.git
32
+ ```
33
+
34
+ **📋 Copy-paste the English version with a new language code**
35
+
36
+ The documentation files are in one leading directory:
37
+
38
+ - [`docs/source`](https://github.com/huggingface/diffusers/tree/main/docs/source): All the documentation materials are organized here by language.
39
+
40
+ You'll only need to copy the files in the [`docs/source/en`](https://github.com/huggingface/diffusers/tree/main/docs/source/en) directory, so first navigate to your fork of the repo and run the following:
41
+
42
+ ```bash
43
+ cd ~/path/to/diffusers/docs
44
+ cp -r source/en source/<LANG-ID>
45
+ ```
46
+
47
+ Here, `<LANG-ID>` should be one of the ISO 639-1 or ISO 639-2 language codes -- see [here](https://www.loc.gov/standards/iso639-2/php/code_list.php) for a handy table.
48
+
49
+ **✍️ Start translating**
50
+
51
+ The fun part comes - translating the text!
52
+
53
+ The first thing we recommend is translating the part of the `_toctree.yml` file that corresponds to your doc chapter. This file is used to render the table of contents on the website.
54
+
55
+ > 🙋 If the `_toctree.yml` file doesn't yet exist for your language, you can create one by copy-pasting from the English version and deleting the sections unrelated to your chapter. Just make sure it exists in the `docs/source/<LANG-ID>/` directory!
56
+
57
+ The fields you should add are `local` (with the name of the file containing the translation; e.g. `autoclass_tutorial`), and `title` (with the title of the doc in your language; e.g. `Load pretrained instances with an AutoClass`) -- as a reference, here is the `_toctree.yml` for [English](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml):
58
+
59
+ ```yaml
60
+ - sections:
61
+ - local: pipeline_tutorial # Do not change this! Use the same name for your .md file
62
+ title: Pipelines for inference # Translate this!
63
+ ...
64
+ title: Tutorials # Translate this!
65
+ ```
66
+
67
+ Once you have translated the `_toctree.yml` file, you can start translating the [MDX](https://mdxjs.com/) files associated with your docs chapter.
68
+
69
+ > 🙋 If you'd like others to help you with the translation, you should [open an issue](https://github.com/huggingface/diffusers/issues) and tag @patrickvonplaten.
diffusers/scripts/conversion_ldm_uncond.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ import yaml
5
+
6
+ from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
7
+
8
+
9
def convert_ldm_original(checkpoint_path, config_path, output_path):
    """Convert an original unconditional latent-diffusion checkpoint to a diffusers `LDMPipeline`.

    Args:
        checkpoint_path (`str`): Path to the original `.ckpt`/`.pt` checkpoint (expects a "model" key).
        config_path (`str`): Path to the original YAML training configuration file.
        output_path (`str`): Directory where the converted pipeline is saved.
    """
    # Bug fix: `yaml.safe_load` expects YAML text or a stream, not a filesystem path.
    # Passing the path string would parse the *path itself* as a YAML scalar and the
    # subsequent `config["model"]` lookups would fail.
    with open(config_path) as f:
        config = yaml.safe_load(f)
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
    keys = list(state_dict.keys())

    # extract state_dict for VQVAE (keys prefixed with "first_stage_model.")
    first_stage_dict = {}
    first_stage_key = "first_stage_model."
    for key in keys:
        if key.startswith(first_stage_key):
            first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]

    # extract state_dict for UNetLDM (keys prefixed with "model.diffusion_model.")
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = state_dict[key]

    vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"]
    unet_init_args = config["model"]["params"]["unet_config"]["params"]

    vqvae = VQModel(**vqvae_init_args).eval()
    vqvae.load_state_dict(first_stage_dict)

    unet = UNetLDMModel(**unet_init_args).eval()
    unet.load_state_dict(unet_state_dict)

    # Scheduler hyperparameters are taken directly from the original training config.
    noise_scheduler = DDIMScheduler(
        timesteps=config["model"]["params"]["timesteps"],
        beta_schedule="scaled_linear",
        beta_start=config["model"]["params"]["linear_start"],
        beta_end=config["model"]["params"]["linear_end"],
        clip_sample=False,
    )

    pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
    pipeline.save_pretrained(output_path)
47
+
48
+
49
if __name__ == "__main__":
    # CLI entry point: three required path arguments, then run the conversion.
    arg_parser = argparse.ArgumentParser()
    for flag in ("--checkpoint_path", "--config_path", "--output_path"):
        arg_parser.add_argument(flag, type=str, required=True)
    cli_args = arg_parser.parse_args()

    convert_ldm_original(cli_args.checkpoint_path, cli_args.config_path, cli_args.output_path)
diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from huggingface_hub import create_repo, upload_folder
6
+ from safetensors.torch import load_file, save_file
7
+
8
+
9
def convert_motion_module(original_state_dict):
    """Rename AnimateDiff motion-module keys to the diffusers naming scheme.

    Positional-encoding entries (keys containing "pos_encoder") are dropped;
    every other key is rewritten through a fixed sequence of substring
    substitutions. Values are passed through unchanged.
    """
    renames = (
        (".norms.0", ".norm1"),
        (".norms.1", ".norm2"),
        (".ff_norm", ".norm3"),
        (".attention_blocks.0", ".attn1"),
        (".attention_blocks.1", ".attn2"),
        (".temporal_transformer", ""),
    )

    converted_state_dict = {}
    for name, value in original_state_dict.items():
        # Positional encodings are not used by the diffusers implementation.
        if "pos_encoder" in name:
            continue
        for old, new in renames:
            name = name.replace(old, new)
        converted_state_dict[name] = value

    return converted_state_dict
26
+
27
+
28
def get_args():
    """Build and parse the command-line arguments for this conversion script."""
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
    argument_parser.add_argument("--output_path", type=str, required=True, help="Path to output directory")
    argument_parser.add_argument(
        "--push_to_hub",
        action="store_true",
        default=False,
        help="Whether to push the converted model to the HF or not",
    )
    return argument_parser.parse_args()
40
+
41
+
42
if __name__ == "__main__":
    args = get_args()

    # Load the raw checkpoint; safetensors files get the dedicated loader.
    if args.ckpt_path.endswith(".safetensors"):
        state_dict = load_file(args.ckpt_path)
    else:
        state_dict = torch.load(args.ckpt_path, map_location="cpu")

    # Some checkpoints nest the actual weights under a "state_dict" key.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    conv_state_dict = convert_motion_module(state_dict)

    # Keep tensor entries only, namespaced under "unet." for diffusers.
    output_dict = {
        f"unet.{module_name}": params
        for module_name, params in conv_state_dict.items()
        if type(params) is torch.Tensor
    }

    os.makedirs(args.output_path, exist_ok=True)
    save_file(output_dict, os.path.join(args.output_path, "diffusion_pytorch_model.safetensors"))

    if args.push_to_hub:
        repo_id = create_repo(args.output_path, exist_ok=True).repo_id
        upload_folder(repo_id=repo_id, folder_path=args.output_path, repo_type="model")
diffusers/scripts/convert_cogvideox_to_diffusers.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from typing import Any, Dict
3
+
4
+ import torch
5
+ from transformers import T5EncoderModel, T5Tokenizer
6
+
7
+ from diffusers import (
8
+ AutoencoderKLCogVideoX,
9
+ CogVideoXDDIMScheduler,
10
+ CogVideoXImageToVideoPipeline,
11
+ CogVideoXPipeline,
12
+ CogVideoXTransformer3DModel,
13
+ )
14
+
15
+
16
def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
    """Split a fused QKV projection entry into separate to_q/to_k/to_v entries (in place)."""
    fused = state_dict.pop(key)
    # The original checkpoint stacks the Q, K and V projections along dim 0.
    query, key_proj, value = torch.chunk(fused, chunks=3, dim=0)
    state_dict[key.replace("query_key_value", "to_q")] = query
    state_dict[key.replace("query_key_value", "to_k")] = key_proj
    state_dict[key.replace("query_key_value", "to_v")] = value
25
+
26
+
27
def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
    """Relocate a query/key layernorm parameter under attn1.norm_q / attn1.norm_k (in place).

    Callers only dispatch keys from "query_layernorm_list" or "key_layernorm_list",
    so exactly one of the two branches below is taken.
    """
    layer_id, weight_or_bias = key.split(".")[-2:]

    if "query" in key:
        target = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
    elif "key" in key:
        target = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"

    state_dict[target] = state_dict.pop(key)
36
+
37
+
38
def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
    """Split a fused adaLN modulation tensor into norm1/norm2 linear params (in place).

    The original checkpoint packs 12 equal chunks along dim 0: chunks 0-2 and 6-8
    belong to norm1.linear, chunks 3-5 and 9-11 to norm2.linear.
    """
    layer_id, _, weight_or_bias = key.split(".")[-3:]

    chunks = state_dict.pop(key).chunk(12, dim=0)
    state_dict[f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"] = torch.cat(
        chunks[0:3] + chunks[6:9]
    )
    state_dict[f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"] = torch.cat(
        chunks[3:6] + chunks[9:12]
    )
52
+
53
+
54
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
    """Drop `key` from `state_dict`; used for entries diffusers does not need."""
    del state_dict[key]
56
+
57
+
58
def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
    """Rename decoder "up.<i>" keys to "up_blocks.<3-i>" (in place).

    The original VAE numbers its four up blocks in the opposite order from
    diffusers, so index i becomes 4 - 1 - i.
    """
    parts = key.split(".")
    original_index = int(parts[2])

    parts[1] = "up_blocks"
    parts[2] = str(4 - 1 - original_index)

    state_dict[".".join(parts)] = state_dict.pop(key)
68
+
69
+
70
# Plain substring renames applied (in insertion order) to every transformer key.
# NOTE: order matters — "transformer.final_layernorm" must match before the bare
# "transformer" -> "transformer_blocks" rename consumes it.
TRANSFORMER_KEYS_RENAME_DICT = {
    "transformer.final_layernorm": "norm_final",
    "transformer": "transformer_blocks",
    "attention": "attn1",
    "mlp": "ff.net",
    "dense_h_to_4h": "0.proj",
    "dense_4h_to_h": "2",
    ".layers": "",
    "dense": "to_out.0",
    "input_layernorm": "norm1.norm",
    "post_attn1_layernorm": "norm2.norm",
    "time_embed.0": "time_embedding.linear_1",
    "time_embed.2": "time_embedding.linear_2",
    "ofs_embed.0": "ofs_embedding.linear_1",
    "ofs_embed.2": "ofs_embedding.linear_2",
    "mixins.patch_embed": "patch_embed",
    "mixins.final_layer.norm_final": "norm_out.norm",
    "mixins.final_layer.linear": "proj_out",
    "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
    "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding",  # Specific to CogVideoX-5b-I2V
}

# Transformer keys that need structural rewrites rather than a 1:1 rename;
# each handler mutates the state dict in place (splits, moves, or drops keys).
TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "query_key_value": reassign_query_key_value_inplace,
    "query_layernorm_list": reassign_query_key_layernorm_inplace,
    "key_layernorm_list": reassign_query_key_layernorm_inplace,
    "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
    "embed_tokens": remove_keys_inplace,
    "freqs_sin": remove_keys_inplace,
    "freqs_cos": remove_keys_inplace,
    "position_embedding": remove_keys_inplace,
}

# Plain substring renames mapping original VAE keys to the diffusers layout.
VAE_KEYS_RENAME_DICT = {
    "block.": "resnets.",
    "down.": "down_blocks.",
    "downsample": "downsamplers.0",
    "upsample": "upsamplers.0",
    "nin_shortcut": "conv_shortcut",
    "encoder.mid.block_1": "encoder.mid_block.resnets.0",
    "encoder.mid.block_2": "encoder.mid_block.resnets.1",
    "decoder.mid.block_1": "decoder.mid_block.resnets.0",
    "decoder.mid.block_2": "decoder.mid_block.resnets.1",
}

# VAE keys needing special handling: training-only "loss" weights are dropped,
# and decoder "up." blocks are re-indexed (diffusers numbers them in reverse).
VAE_SPECIAL_KEYS_REMAP = {
    "loss": remove_keys_inplace,
    "up.": replace_up_keys_inplace,
}

# Maximum T5 token sequence length used for CogVideoX prompts.
TOKENIZER_MAX_LENGTH = 226
121
+
122
+
123
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap the actual weights from common checkpoint wrapper layers.

    NOTE(review): each membership test is against the *top-level* dict while the
    lookup descends into the already-unwrapped one — preserved exactly as in the
    original script, but it looks suspicious; confirm against real checkpoints.
    """
    state_dict = saved_dict
    for wrapper in ("model", "module", "state_dict"):
        if wrapper in saved_dict:
            state_dict = state_dict[wrapper]
    return state_dict
132
+
133
+
134
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
    """Rename `old_key` to `new_key` inside `state_dict` (mutates; returns None)."""
    value = state_dict.pop(old_key)
    state_dict[new_key] = value
136
+
137
+
138
def convert_transformer(
    ckpt_path: str,
    num_layers: int,
    num_attention_heads: int,
    use_rotary_positional_embeddings: bool,
    i2v: bool,
    dtype: torch.dtype,
    init_kwargs: Dict[str, Any],
):
    """Load an original CogVideoX transformer checkpoint and convert it to diffusers.

    Args:
        ckpt_path: Path to the original transformer checkpoint.
        num_layers: Number of transformer blocks (30 for 2B, 42 for 5B).
        num_attention_heads: Number of attention heads (30 for 2B, 48 for 5B).
        use_rotary_positional_embeddings: Whether the variant uses RoPE.
        i2v: Image-to-video variant (doubles in_channels for the image latents).
        dtype: Target dtype for the converted weights.
        init_kwargs: Version-specific kwargs from `get_transformer_init_kwargs`.
    """
    prefix = "model.diffusion_model."

    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    transformer = CogVideoXTransformer3DModel(
        in_channels=32 if i2v else 16,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        use_rotary_positional_embeddings=use_rotary_positional_embeddings,
        ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None,  # CogVideoX1.5-5B-I2V
        use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
        **init_kwargs,
    ).to(dtype=dtype)

    # First pass: strip the checkpoint prefix and apply the plain substring renames.
    for old_key in list(original_state_dict.keys()):
        renamed = old_key[len(prefix):]
        for source, target in TRANSFORMER_KEYS_RENAME_DICT.items():
            renamed = renamed.replace(source, target)
        update_state_dict_inplace(original_state_dict, old_key, renamed)

    # Second pass: dispatch keys that need structural rewrites (QKV split, drops, ...).
    for current_key in list(original_state_dict.keys()):
        for pattern, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if pattern in current_key:
                handler_fn_inplace(current_key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer
174
+
175
+
176
def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
    """Load an original CogVideoX VAE checkpoint and convert it to `AutoencoderKLCogVideoX`.

    Args:
        ckpt_path: Path to the original VAE checkpoint.
        scaling_factor: Latent scaling factor (1.15258426 for 2B, 0.7 for 5B).
        version: "1.0" or "1.5"; 1.5 checkpoints store inverted latent scales.
        dtype: Target dtype for the converted weights.
    """
    init_kwargs = {"scaling_factor": scaling_factor}
    if version == "1.5":
        init_kwargs["invert_scale_latents"] = True

    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)

    # Plain substring renames first ...
    for old_key in list(original_state_dict.keys()):
        renamed = old_key
        for source, target in VAE_KEYS_RENAME_DICT.items():
            renamed = renamed.replace(source, target)
        update_state_dict_inplace(original_state_dict, old_key, renamed)

    # ... then handlers for keys that need more than a rename (drops, re-indexing).
    for current_key in list(original_state_dict.keys()):
        for pattern, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if pattern in current_key:
                handler_fn_inplace(current_key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True)
    return vae
198
+
199
+
200
def get_transformer_init_kwargs(version: str):
    """Return version-specific constructor kwargs for the CogVideoX transformer.

    Args:
        version: Either "1.0" (CogVideoX) or "1.5" (CogVideoX1.5).

    Raises:
        ValueError: If `version` is not one of the supported values.
    """
    spatial_compression = 8  # VAE spatial downscale factor

    if version == "1.0":
        return {
            "patch_size": 2,
            "patch_size_t": None,
            "patch_bias": True,
            "sample_height": 480 // spatial_compression,
            "sample_width": 720 // spatial_compression,
            "sample_frames": 49,
        }
    if version == "1.5":
        return {
            "patch_size": 2,
            "patch_size_t": 2,
            "patch_bias": False,
            "sample_height": 300,
            "sample_width": 300,
            "sample_frames": 81,
        }
    raise ValueError("Unsupported version of CogVideoX.")
226
+
227
+
228
def get_args():
    """Parse CLI arguments for the CogVideoX checkpoint conversion script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
    parser.add_argument(
        "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
    )
    parser.add_argument(
        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
    )
    parser.add_argument(
        "--typecast_text_encoder",
        action="store_true",
        default=False,
        help="Whether or not to apply fp16/bf16 precision to text_encoder",
    )
    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
    parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
    # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
    parser.add_argument(
        "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
    )
    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
    parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
    # Bug fix: help text previously duplicated the --scaling_factor description
    # ("Scaling factor in the VAE"), which was a copy-paste error.
    parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="SNR shift scale for the scheduler")
    parser.add_argument(
        "--i2v",
        action="store_true",
        default=False,
        help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
    )
    parser.add_argument(
        "--version",
        choices=["1.0", "1.5"],
        default="1.0",
        help="Which version of CogVideoX to use for initializing default modeling parameters.",
    )
    return parser.parse_args()
274
+
275
+
276
if __name__ == "__main__":
    args = get_args()

    transformer = None
    vae = None

    # The two precision flags are mutually exclusive.
    if args.fp16 and args.bf16:
        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")

    # Target dtype for the transformer weights (fp32 unless a flag says otherwise).
    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32

    # Each sub-model is only converted when its checkpoint path is provided.
    if args.transformer_ckpt_path is not None:
        init_kwargs = get_transformer_init_kwargs(args.version)
        transformer = convert_transformer(
            args.transformer_ckpt_path,
            args.num_layers,
            args.num_attention_heads,
            args.use_rotary_positional_embeddings,
            args.i2v,
            dtype,
            init_kwargs,
        )
    if args.vae_ckpt_path is not None:
        # Keep VAE in float32 for better quality
        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)

    # The text encoder is not converted — the stock T5-XXL weights are reused.
    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    if args.typecast_text_encoder:
        text_encoder = text_encoder.to(dtype=dtype)

    # Apparently, the conversion does not work anymore without this :shrug:
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

    # Scheduler config matching the original CogVideoX training setup; only the
    # SNR shift scale differs between the 2B (3.0) and 5B (1.0) variants.
    scheduler = CogVideoXDDIMScheduler.from_config(
        {
            "snr_shift_scale": args.snr_shift_scale,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "clip_sample": False,
            "num_train_timesteps": 1000,
            "prediction_type": "v_prediction",
            "rescale_betas_zero_snr": True,
            "set_alpha_to_one": True,
            "timestep_spacing": "trailing",
        }
    )
    if args.i2v:
        pipeline_cls = CogVideoXImageToVideoPipeline
    else:
        pipeline_cls = CogVideoXPipeline

    pipe = pipeline_cls(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )

    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
    # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
    # is either fp16/bf16 here).

    # Sharding at 5GB is necessary for users with insufficient memory,
    # such as those using Colab and notebooks, as it can save some memory used for model loading.
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
diffusers/scripts/convert_consistency_decoder.py ADDED
@@ -0,0 +1,1128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import math
import os
import urllib
import urllib.request
import warnings
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub.utils import insecure_hashlib
from safetensors.torch import load_file as stl
from tqdm import tqdm

from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
from diffusers.models.autoencoders.vae import Encoder
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
18
+
19
+
20
+ args = ArgumentParser()
21
+ args.add_argument("--save_pretrained", required=False, default=None, type=str)
22
+ args.add_argument("--test_image", required=True, type=str)
23
+ args = args.parse_args()
24
+
25
+
26
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
27
+ # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895 """
28
+ res = arr[timesteps].float()
29
+ dims_to_append = len(broadcast_shape) - len(res.shape)
30
+ return res[(...,) + (None,) * dims_to_append]
31
+
32
+
33
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """Build a beta schedule from a cumulative-noise function.

    ``alpha_bar(t)`` gives the cumulative product of ``1 - beta`` at
    continuous time ``t`` in ``[0, 1]``; each beta is capped at ``max_beta``
    to avoid singularities near the end of the schedule.
    Adapted from openai/guided-diffusion (gaussian_diffusion.py#L45).
    """
    steps = num_diffusion_timesteps
    schedule = [
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ]
    return torch.tensor(schedule)
41
+
42
+
43
def _download(url: str, root: str):
    """Download ``url`` into directory ``root``, verifying its SHA256 checksum.

    The expected checksum is taken from the second-to-last path component of
    the URL (OpenAI's public blob layout). An already-present file with a
    matching checksum is reused instead of re-downloaded.

    Returns:
        The local path of the verified file.

    Raises:
        RuntimeError: if the target exists but is not a regular file, or if
            the downloaded file fails checksum verification.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        # Reuse a previously downloaded checkpoint when it still verifies.
        if _sha256_of(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _sha256_of(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _sha256_of(path):
    """Stream-hash *path* in 1 MiB chunks so multi-GB checkpoints are never
    read fully into memory, and the file handle is always closed."""
    digest = insecure_hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
79
+
80
+
81
class ConsistencyDecoder:
    """Sampler wrapper around OpenAI's distilled consistency decoder.

    Downloads the jitted decoder checkpoint on construction and precomputes
    the consistency-model preconditioning coefficients for a 1024-step
    cosine schedule with ``sigma_data = 0.5``. Calling the instance runs the
    (by default two-step) consistency sampling loop that decodes Stable
    Diffusion latents into images in ``[-1, 1]``.
    """

    def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")):
        self.n_distilled_steps = 64
        download_target = _download(
            "https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt",
            download_root,
        )
        self.ckpt = torch.jit.load(download_target).to(device)
        self.device = device
        sigma_data = 0.5
        betas = betas_for_alpha_bar(1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).to(device)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
        sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
        # Consistency-model preconditioning coefficients (c_skip / c_out / c_in),
        # as in "Consistency Models" (Song et al.), with sigma_data = 0.5.
        self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
        self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
        self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5

    @staticmethod
    def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
        """Snap arbitrary timesteps onto the distilled model's coarse grid.

        With ``truncate_start=False`` any timestep that rounded to 0 is
        pushed up one grid ``space`` so sampling never starts at t == 0.
        """
        with torch.no_grad():
            space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
            rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space
            if truncate_start:
                rounded_timesteps[rounded_timesteps == total_timesteps] -= space
            else:
                rounded_timesteps[rounded_timesteps == total_timesteps] -= space
                rounded_timesteps[rounded_timesteps == 0] += space
            return rounded_timesteps

    @staticmethod
    def ldm_transform_latent(z, extra_scale_factor=1):
        """Scale SD latents by 0.18215 and whiten each of the four channels
        with the decoder's per-channel statistics."""
        channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294]
        channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034]

        if len(z.shape) != 4:
            raise ValueError()

        z = z * 0.18215
        channels = [z[:, i] for i in range(z.shape[1])]

        channels = [extra_scale_factor * (c - channel_means[i]) / channel_stds[i] for i, c in enumerate(channels)]
        return torch.stack(channels, dim=1)

    @torch.no_grad()
    def __call__(
        self,
        features: torch.Tensor,
        schedule=[1.0, 0.5],
        generator=None,
    ):
        """Decode latents ``features`` (B, 4, H, W) into images (B, 3, 8H, 8W)."""
        features = self.ldm_transform_latent(features)
        ts = self.round_timesteps(
            torch.arange(0, 1024),
            1024,
            self.n_distilled_steps,
            truncate_start=False,
        )
        shape = (
            features.size(0),
            3,
            8 * features.size(2),
            8 * features.size(3),
        )
        x_start = torch.zeros(shape, device=features.device, dtype=features.dtype)
        schedule_timesteps = [int((1024 - 1) * s) for s in schedule]
        for i in schedule_timesteps:
            t = ts[i].item()
            t_ = torch.tensor([t] * features.shape[0]).to(self.device)
            # An explicit generator (instead of randn_like) keeps results
            # reproducible across runs.
            noise = torch.randn(x_start.shape, dtype=x_start.dtype, generator=generator).to(device=x_start.device)
            x_start = (
                _extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape) * x_start
                + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape) * noise
            )
            c_in = _extract_into_tensor(self.c_in, t_, x_start.shape)

            # NOTE: the original re-imported F and UNet2DModel on every loop
            # iteration; both are already imported at module level, so the
            # redundant per-iteration imports were removed.
            if isinstance(self.ckpt, UNet2DModel):
                input = torch.concat([c_in * x_start, F.upsample_nearest(features, scale_factor=8)], dim=1)
                model_output = self.ckpt(input, t_).sample
            else:
                model_output = self.ckpt(c_in * x_start, t_, features=features)

            B, C = x_start.shape[:2]
            # The model predicts 2*C channels (mean + variance); keep the mean.
            model_output, _ = torch.split(model_output, C, dim=1)
            pred_xstart = (
                _extract_into_tensor(self.c_out, t_, x_start.shape) * model_output
                + _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start
            ).clamp(-1, 1)
            x_start = pred_xstart
        return x_start
179
+
180
+
181
def save_image(image, name):
    """Write the first image of a [-1, 1] CHW batch tensor to *name*."""
    import numpy as np
    from PIL import Image

    arr = image[0].cpu().numpy()
    arr = ((arr + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    Image.fromarray(arr.transpose(1, 2, 0)).save(name)
190
+
191
+
192
def load_image(uri, size=None, center_crop=False):
    """Load *uri* as a float CHW batch tensor scaled to [-1, 1].

    Optionally center-crops to a square and resizes to ``size`` first.
    """
    import numpy as np
    from PIL import Image

    image = Image.open(uri)
    if center_crop:
        side = min(image.width, image.height)
        left = (image.width - side) // 2
        top = (image.height - side) // 2
        image = image.crop((left, top, left + side, top + side))
    if size is not None:
        image = image.resize(size)
    tensor = torch.tensor(np.array(image).transpose(2, 0, 1)).unsqueeze(0).float()
    return tensor / 127.5 - 1.0
211
+
212
+
213
class TimestepEmbedding_(nn.Module):
    """Embed an integer timestep and project it through a two-layer silu MLP."""

    def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
        super().__init__()
        # Attribute names mirror the original checkpoint's state-dict keys.
        self.emb = nn.Embedding(n_time, n_emb)
        self.f_1 = nn.Linear(n_emb, n_out)
        self.f_2 = nn.Linear(n_out, n_out)

    def forward(self, x) -> torch.Tensor:
        hidden = self.f_1(self.emb(x))
        return self.f_2(F.silu(hidden))
225
+
226
+
227
class ImageEmbedding(nn.Module):
    """3x3 convolution projecting the concatenated (noisy image + upsampled
    latent) input into the UNet's base channel width."""

    def __init__(self, in_channels=7, out_channels=320) -> None:
        super().__init__()
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        embedded = self.f(x)
        return embedded
234
+
235
+
236
class ImageUnembedding(nn.Module):
    """Final GroupNorm -> silu -> 3x3 conv head producing the output channels
    (mean + variance pairs for an RGB image)."""

    def __init__(self, in_channels=320, out_channels=6) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(32, in_channels)
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        normalized = F.silu(self.gn(x))
        return self.f(normalized)
244
+
245
+
246
class ConvResblock(nn.Module):
    """GroupNorm residual conv block with scale/shift timestep conditioning.

    The time embedding is projected to a (scale, shift) pair applied to the
    second normalization (FiLM-style: ``(1 + scale) * norm(h) + shift``).
    """

    def __init__(self, in_features=320, out_features=320) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, out_features * 2)

        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)

        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)

        # 1x1 projection on the skip path only when channel counts differ.
        if in_features != out_features:
            self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0)
        else:
            self.f_s = nn.Identity()

    def forward(self, x, t):
        residual = x
        scale, shift = self.f_t(F.silu(t)).chunk(2, dim=1)
        scale = scale.unsqueeze(dim=2).unsqueeze(dim=3) + 1
        shift = shift.unsqueeze(dim=2).unsqueeze(dim=3)

        hidden = self.f_1(F.silu(self.gn_1(x)))
        hidden = self.gn_2(hidden)

        hidden = self.f_2(F.silu(hidden * scale + shift))
        return self.f_s(residual) + hidden
273
+
274
+
275
+ # Also ConvResblock
276
+ class Downsample(nn.Module):
277
+ def __init__(self, in_channels=320) -> None:
278
+ super().__init__()
279
+ self.f_t = nn.Linear(1280, in_channels * 2)
280
+
281
+ self.gn_1 = nn.GroupNorm(32, in_channels)
282
+ self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
283
+ self.gn_2 = nn.GroupNorm(32, in_channels)
284
+
285
+ self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
286
+
287
+ def forward(self, x, t) -> torch.Tensor:
288
+ x_skip = x
289
+
290
+ t = self.f_t(F.silu(t))
291
+ t_1, t_2 = t.chunk(2, dim=1)
292
+ t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
293
+ t_2 = t_2.unsqueeze(2).unsqueeze(3)
294
+
295
+ gn_1 = F.silu(self.gn_1(x))
296
+ avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
297
+
298
+ f_1 = self.f_1(avg_pool2d)
299
+ gn_2 = self.gn_2(f_1)
300
+
301
+ f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
302
+
303
+ return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None)
304
+
305
+
306
+ # Also ConvResblock
307
+ class Upsample(nn.Module):
308
+ def __init__(self, in_channels=1024) -> None:
309
+ super().__init__()
310
+ self.f_t = nn.Linear(1280, in_channels * 2)
311
+
312
+ self.gn_1 = nn.GroupNorm(32, in_channels)
313
+ self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
314
+ self.gn_2 = nn.GroupNorm(32, in_channels)
315
+
316
+ self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
317
+
318
+ def forward(self, x, t) -> torch.Tensor:
319
+ x_skip = x
320
+
321
+ t = self.f_t(F.silu(t))
322
+ t_1, t_2 = t.chunk(2, dim=1)
323
+ t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
324
+ t_2 = t_2.unsqueeze(2).unsqueeze(3)
325
+
326
+ gn_1 = F.silu(self.gn_1(x))
327
+ upsample = F.upsample_nearest(gn_1, scale_factor=2)
328
+ f_1 = self.f_1(upsample)
329
+ gn_2 = self.gn_2(f_1)
330
+
331
+ f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
332
+
333
+ return f_2 + F.upsample_nearest(x_skip, scale_factor=2)
334
+
335
+
336
class ConvUNetVAE(nn.Module):
    """Plain-PyTorch re-implementation of the consistency decoder UNet.

    Channel plan: 320 -> 640 -> 1024 down the encoder, mirrored back up, with
    skip connections concatenated at every resblock. The forward pass has two
    modes: the original sub-module path, and a "converted" path used after the
    conversion script has swapped the stages for diffusers blocks (the
    ``converted``, ``time_proj``, ``time_embedding`` attributes are attached
    externally by the conversion code below — not defined in __init__).
    """

    def __init__(self) -> None:
        super().__init__()
        self.embed_image = ImageEmbedding()
        self.embed_time = TimestepEmbedding_()

        # Encoder stages; the last one has no Downsample.
        down_0 = nn.ModuleList(
            [
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                Downsample(320),
            ]
        )
        down_1 = nn.ModuleList(
            [
                ConvResblock(320, 640),
                ConvResblock(640, 640),
                ConvResblock(640, 640),
                Downsample(640),
            ]
        )
        down_2 = nn.ModuleList(
            [
                ConvResblock(640, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                Downsample(1024),
            ]
        )
        down_3 = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )
        self.down = nn.ModuleList(
            [
                down_0,
                down_1,
                down_2,
                down_3,
            ]
        )

        self.mid = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )

        # Decoder stages; input channels include the concatenated skip.
        up_3 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                Upsample(1024),
            ]
        )
        up_2 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 + 640, 1024),
                Upsample(1024),
            ]
        )
        up_1 = nn.ModuleList(
            [
                ConvResblock(1024 + 640, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(320 + 640, 640),
                Upsample(640),
            ]
        )
        up_0 = nn.ModuleList(
            [
                ConvResblock(320 + 640, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
            ]
        )
        # Stored lowest-resolution-last; forward iterates self.up reversed.
        self.up = nn.ModuleList(
            [
                up_0,
                up_1,
                up_2,
                up_3,
            ]
        )

        self.output = ImageUnembedding()

    def forward(self, x, t, features) -> torch.Tensor:
        # After conversion, `self.converted` is True and the down/mid/up
        # attributes hold diffusers blocks with a different call signature.
        converted = hasattr(self, "converted") and self.converted

        # Condition on the latent by upsampling it 8x and concatenating it
        # to the noisy image along the channel dim (3 + 4 = 7 channels).
        x = torch.cat([x, F.upsample_nearest(features, scale_factor=8)], dim=1)

        if converted:
            t = self.time_embedding(self.time_proj(t))
        else:
            t = self.embed_time(t)

        x = self.embed_image(x)

        # Collect skip activations for the decoder (the embed output first).
        skips = [x]
        for i, down in enumerate(self.down):
            if converted and i in [0, 1, 2, 3]:
                # Diffusers blocks return (hidden, skip_states) directly.
                x, skips_ = down(x, t)
                for skip in skips_:
                    skips.append(skip)
            else:
                for block in down:
                    x = block(x, t)
                    skips.append(x)
            # Debug checksum print used to compare original vs converted paths.
            print(x.float().abs().sum())

        if converted:
            x = self.mid(x, t)
        else:
            for i in range(2):
                x = self.mid[i](x, t)
        print(x.float().abs().sum())

        for i, up in enumerate(self.up[::-1]):
            if converted and i in [0, 1, 2, 3]:
                # Diffusers up-blocks take the four skips as one tuple,
                # ordered shallowest-first.
                skip_4 = skips.pop()
                skip_3 = skips.pop()
                skip_2 = skips.pop()
                skip_1 = skips.pop()
                skips_ = (skip_1, skip_2, skip_3, skip_4)
                x = up(x, skips_, t)
            else:
                for block in up:
                    if isinstance(block, ConvResblock):
                        x = torch.concat([x, skips.pop()], dim=1)
                    x = block(x, t)

        return self.output(x)
481
+
482
+
483
def rename_state_dict_key(k):
    """Map a raw consistency-decoder checkpoint key onto this module's names.

    Substitution order matters: block-index renames run first, then the
    short parameter suffixes (w / b / g) are expanded to weight / bias.
    """
    k = k.replace("blocks.", "")
    for i in range(5):
        k = k.replace(f"down_{i}_", f"down.{i}.")
        k = k.replace(f"conv_{i}.", f"{i}.")
        k = k.replace(f"up_{i}_", f"up.{i}.")
        k = k.replace(f"mid_{i}", f"mid.{i}")
    substitutions = [
        ("upsamp.", "4."),
        ("downsamp.", "3."),
        ("f_t.w", "f_t.weight"),
        ("f_t.b", "f_t.bias"),
        ("f_1.w", "f_1.weight"),
        ("f_1.b", "f_1.bias"),
        ("f_2.w", "f_2.weight"),
        ("f_2.b", "f_2.bias"),
        ("f_s.w", "f_s.weight"),
        ("f_s.b", "f_s.bias"),
        ("f.w", "f.weight"),
        ("f.b", "f.bias"),
        ("gn_1.g", "gn_1.weight"),
        ("gn_1.b", "gn_1.bias"),
        ("gn_2.g", "gn_2.weight"),
        ("gn_2.b", "gn_2.bias"),
        ("gn.g", "gn.weight"),
        ("gn.b", "gn.bias"),
    ]
    for old, new in substitutions:
        k = k.replace(old, new)
    return k
501
+
502
+
503
def rename_state_dict(sd, embedding):
    """Rename every checkpoint key and splice in the timestep-embedding table
    (stored in a separate safetensors file)."""
    renamed = {rename_state_dict_key(key): value for key, value in sd.items()}
    renamed["embed_time.emb.weight"] = embedding["weight"]
    return renamed
507
+
508
+
509
+ # encode with stable diffusion vae
510
+ pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
511
+ pipe.vae.cuda()
512
+
513
+ # construct original decoder with jitted model
514
+ decoder_consistency = ConsistencyDecoder(device="cuda:0")
515
+
516
+ # construct UNet code, overwrite the decoder with conv_unet_vae
517
+ model = ConvUNetVAE()
518
+ model.load_state_dict(
519
+ rename_state_dict(
520
+ stl("consistency_decoder.safetensors"),
521
+ stl("embedding.safetensors"),
522
+ )
523
+ )
524
+ model = model.cuda()
525
+
526
+ decoder_consistency.ckpt = model
527
+
528
+ image = load_image(args.test_image, size=(256, 256), center_crop=True)
529
+ latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample()
530
+
531
+ # decode with gan
532
+ sample_gan = pipe.vae.decode(latent).sample.detach()
533
+ save_image(sample_gan, "gan.png")
534
+
535
+ # decode with conv_unet_vae
536
+ sample_consistency_orig = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
537
+ save_image(sample_consistency_orig, "con_orig.png")
538
+
539
+
540
+ ########### conversion
541
+
542
+ print("CONVERSION")
543
+
544
+ print("DOWN BLOCK ONE")
545
+
546
+ block_one_sd_orig = model.down[0].state_dict()
547
+ block_one_sd_new = {}
548
+
549
+ for i in range(3):
550
+ block_one_sd_new[f"resnets.{i}.norm1.weight"] = block_one_sd_orig.pop(f"{i}.gn_1.weight")
551
+ block_one_sd_new[f"resnets.{i}.norm1.bias"] = block_one_sd_orig.pop(f"{i}.gn_1.bias")
552
+ block_one_sd_new[f"resnets.{i}.conv1.weight"] = block_one_sd_orig.pop(f"{i}.f_1.weight")
553
+ block_one_sd_new[f"resnets.{i}.conv1.bias"] = block_one_sd_orig.pop(f"{i}.f_1.bias")
554
+ block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_one_sd_orig.pop(f"{i}.f_t.weight")
555
+ block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_one_sd_orig.pop(f"{i}.f_t.bias")
556
+ block_one_sd_new[f"resnets.{i}.norm2.weight"] = block_one_sd_orig.pop(f"{i}.gn_2.weight")
557
+ block_one_sd_new[f"resnets.{i}.norm2.bias"] = block_one_sd_orig.pop(f"{i}.gn_2.bias")
558
+ block_one_sd_new[f"resnets.{i}.conv2.weight"] = block_one_sd_orig.pop(f"{i}.f_2.weight")
559
+ block_one_sd_new[f"resnets.{i}.conv2.bias"] = block_one_sd_orig.pop(f"{i}.f_2.bias")
560
+
561
+ block_one_sd_new["downsamplers.0.norm1.weight"] = block_one_sd_orig.pop("3.gn_1.weight")
562
+ block_one_sd_new["downsamplers.0.norm1.bias"] = block_one_sd_orig.pop("3.gn_1.bias")
563
+ block_one_sd_new["downsamplers.0.conv1.weight"] = block_one_sd_orig.pop("3.f_1.weight")
564
+ block_one_sd_new["downsamplers.0.conv1.bias"] = block_one_sd_orig.pop("3.f_1.bias")
565
+ block_one_sd_new["downsamplers.0.time_emb_proj.weight"] = block_one_sd_orig.pop("3.f_t.weight")
566
+ block_one_sd_new["downsamplers.0.time_emb_proj.bias"] = block_one_sd_orig.pop("3.f_t.bias")
567
+ block_one_sd_new["downsamplers.0.norm2.weight"] = block_one_sd_orig.pop("3.gn_2.weight")
568
+ block_one_sd_new["downsamplers.0.norm2.bias"] = block_one_sd_orig.pop("3.gn_2.bias")
569
+ block_one_sd_new["downsamplers.0.conv2.weight"] = block_one_sd_orig.pop("3.f_2.weight")
570
+ block_one_sd_new["downsamplers.0.conv2.bias"] = block_one_sd_orig.pop("3.f_2.bias")
571
+
572
+ assert len(block_one_sd_orig) == 0
573
+
574
+ block_one = ResnetDownsampleBlock2D(
575
+ in_channels=320,
576
+ out_channels=320,
577
+ temb_channels=1280,
578
+ num_layers=3,
579
+ add_downsample=True,
580
+ resnet_time_scale_shift="scale_shift",
581
+ resnet_eps=1e-5,
582
+ )
583
+
584
+ block_one.load_state_dict(block_one_sd_new)
585
+
586
+ print("DOWN BLOCK TWO")
587
+
588
+ block_two_sd_orig = model.down[1].state_dict()
589
+ block_two_sd_new = {}
590
+
591
+ for i in range(3):
592
+ block_two_sd_new[f"resnets.{i}.norm1.weight"] = block_two_sd_orig.pop(f"{i}.gn_1.weight")
593
+ block_two_sd_new[f"resnets.{i}.norm1.bias"] = block_two_sd_orig.pop(f"{i}.gn_1.bias")
594
+ block_two_sd_new[f"resnets.{i}.conv1.weight"] = block_two_sd_orig.pop(f"{i}.f_1.weight")
595
+ block_two_sd_new[f"resnets.{i}.conv1.bias"] = block_two_sd_orig.pop(f"{i}.f_1.bias")
596
+ block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_two_sd_orig.pop(f"{i}.f_t.weight")
597
+ block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_two_sd_orig.pop(f"{i}.f_t.bias")
598
+ block_two_sd_new[f"resnets.{i}.norm2.weight"] = block_two_sd_orig.pop(f"{i}.gn_2.weight")
599
+ block_two_sd_new[f"resnets.{i}.norm2.bias"] = block_two_sd_orig.pop(f"{i}.gn_2.bias")
600
+ block_two_sd_new[f"resnets.{i}.conv2.weight"] = block_two_sd_orig.pop(f"{i}.f_2.weight")
601
+ block_two_sd_new[f"resnets.{i}.conv2.bias"] = block_two_sd_orig.pop(f"{i}.f_2.bias")
602
+
603
+ if i == 0:
604
+ block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_two_sd_orig.pop(f"{i}.f_s.weight")
605
+ block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_two_sd_orig.pop(f"{i}.f_s.bias")
606
+
607
+ block_two_sd_new["downsamplers.0.norm1.weight"] = block_two_sd_orig.pop("3.gn_1.weight")
608
+ block_two_sd_new["downsamplers.0.norm1.bias"] = block_two_sd_orig.pop("3.gn_1.bias")
609
+ block_two_sd_new["downsamplers.0.conv1.weight"] = block_two_sd_orig.pop("3.f_1.weight")
610
+ block_two_sd_new["downsamplers.0.conv1.bias"] = block_two_sd_orig.pop("3.f_1.bias")
611
+ block_two_sd_new["downsamplers.0.time_emb_proj.weight"] = block_two_sd_orig.pop("3.f_t.weight")
612
+ block_two_sd_new["downsamplers.0.time_emb_proj.bias"] = block_two_sd_orig.pop("3.f_t.bias")
613
+ block_two_sd_new["downsamplers.0.norm2.weight"] = block_two_sd_orig.pop("3.gn_2.weight")
614
+ block_two_sd_new["downsamplers.0.norm2.bias"] = block_two_sd_orig.pop("3.gn_2.bias")
615
+ block_two_sd_new["downsamplers.0.conv2.weight"] = block_two_sd_orig.pop("3.f_2.weight")
616
+ block_two_sd_new["downsamplers.0.conv2.bias"] = block_two_sd_orig.pop("3.f_2.bias")
617
+
618
+ assert len(block_two_sd_orig) == 0
619
+
620
+ block_two = ResnetDownsampleBlock2D(
621
+ in_channels=320,
622
+ out_channels=640,
623
+ temb_channels=1280,
624
+ num_layers=3,
625
+ add_downsample=True,
626
+ resnet_time_scale_shift="scale_shift",
627
+ resnet_eps=1e-5,
628
+ )
629
+
630
+ block_two.load_state_dict(block_two_sd_new)
631
+
632
+ print("DOWN BLOCK THREE")
633
+
634
+ block_three_sd_orig = model.down[2].state_dict()
635
+ block_three_sd_new = {}
636
+
637
+ for i in range(3):
638
+ block_three_sd_new[f"resnets.{i}.norm1.weight"] = block_three_sd_orig.pop(f"{i}.gn_1.weight")
639
+ block_three_sd_new[f"resnets.{i}.norm1.bias"] = block_three_sd_orig.pop(f"{i}.gn_1.bias")
640
+ block_three_sd_new[f"resnets.{i}.conv1.weight"] = block_three_sd_orig.pop(f"{i}.f_1.weight")
641
+ block_three_sd_new[f"resnets.{i}.conv1.bias"] = block_three_sd_orig.pop(f"{i}.f_1.bias")
642
+ block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_three_sd_orig.pop(f"{i}.f_t.weight")
643
+ block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_three_sd_orig.pop(f"{i}.f_t.bias")
644
+ block_three_sd_new[f"resnets.{i}.norm2.weight"] = block_three_sd_orig.pop(f"{i}.gn_2.weight")
645
+ block_three_sd_new[f"resnets.{i}.norm2.bias"] = block_three_sd_orig.pop(f"{i}.gn_2.bias")
646
+ block_three_sd_new[f"resnets.{i}.conv2.weight"] = block_three_sd_orig.pop(f"{i}.f_2.weight")
647
+ block_three_sd_new[f"resnets.{i}.conv2.bias"] = block_three_sd_orig.pop(f"{i}.f_2.bias")
648
+
649
+ if i == 0:
650
+ block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_three_sd_orig.pop(f"{i}.f_s.weight")
651
+ block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_three_sd_orig.pop(f"{i}.f_s.bias")
652
+
653
+ block_three_sd_new["downsamplers.0.norm1.weight"] = block_three_sd_orig.pop("3.gn_1.weight")
654
+ block_three_sd_new["downsamplers.0.norm1.bias"] = block_three_sd_orig.pop("3.gn_1.bias")
655
+ block_three_sd_new["downsamplers.0.conv1.weight"] = block_three_sd_orig.pop("3.f_1.weight")
656
+ block_three_sd_new["downsamplers.0.conv1.bias"] = block_three_sd_orig.pop("3.f_1.bias")
657
+ block_three_sd_new["downsamplers.0.time_emb_proj.weight"] = block_three_sd_orig.pop("3.f_t.weight")
658
+ block_three_sd_new["downsamplers.0.time_emb_proj.bias"] = block_three_sd_orig.pop("3.f_t.bias")
659
+ block_three_sd_new["downsamplers.0.norm2.weight"] = block_three_sd_orig.pop("3.gn_2.weight")
660
+ block_three_sd_new["downsamplers.0.norm2.bias"] = block_three_sd_orig.pop("3.gn_2.bias")
661
+ block_three_sd_new["downsamplers.0.conv2.weight"] = block_three_sd_orig.pop("3.f_2.weight")
662
+ block_three_sd_new["downsamplers.0.conv2.bias"] = block_three_sd_orig.pop("3.f_2.bias")
663
+
664
+ assert len(block_three_sd_orig) == 0
665
+
666
+ block_three = ResnetDownsampleBlock2D(
667
+ in_channels=640,
668
+ out_channels=1024,
669
+ temb_channels=1280,
670
+ num_layers=3,
671
+ add_downsample=True,
672
+ resnet_time_scale_shift="scale_shift",
673
+ resnet_eps=1e-5,
674
+ )
675
+
676
+ block_three.load_state_dict(block_three_sd_new)
677
+
678
+ print("DOWN BLOCK FOUR")
679
+
680
+ block_four_sd_orig = model.down[3].state_dict()
681
+ block_four_sd_new = {}
682
+
683
+ for i in range(3):
684
+ block_four_sd_new[f"resnets.{i}.norm1.weight"] = block_four_sd_orig.pop(f"{i}.gn_1.weight")
685
+ block_four_sd_new[f"resnets.{i}.norm1.bias"] = block_four_sd_orig.pop(f"{i}.gn_1.bias")
686
+ block_four_sd_new[f"resnets.{i}.conv1.weight"] = block_four_sd_orig.pop(f"{i}.f_1.weight")
687
+ block_four_sd_new[f"resnets.{i}.conv1.bias"] = block_four_sd_orig.pop(f"{i}.f_1.bias")
688
+ block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_four_sd_orig.pop(f"{i}.f_t.weight")
689
+ block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_four_sd_orig.pop(f"{i}.f_t.bias")
690
+ block_four_sd_new[f"resnets.{i}.norm2.weight"] = block_four_sd_orig.pop(f"{i}.gn_2.weight")
691
+ block_four_sd_new[f"resnets.{i}.norm2.bias"] = block_four_sd_orig.pop(f"{i}.gn_2.bias")
692
+ block_four_sd_new[f"resnets.{i}.conv2.weight"] = block_four_sd_orig.pop(f"{i}.f_2.weight")
693
+ block_four_sd_new[f"resnets.{i}.conv2.bias"] = block_four_sd_orig.pop(f"{i}.f_2.bias")
694
+
695
+ assert len(block_four_sd_orig) == 0
696
+
697
+ block_four = ResnetDownsampleBlock2D(
698
+ in_channels=1024,
699
+ out_channels=1024,
700
+ temb_channels=1280,
701
+ num_layers=3,
702
+ add_downsample=False,
703
+ resnet_time_scale_shift="scale_shift",
704
+ resnet_eps=1e-5,
705
+ )
706
+
707
+ block_four.load_state_dict(block_four_sd_new)
708
+
709
+
710
+ print("MID BLOCK 1")
711
+
712
+ mid_block_one_sd_orig = model.mid.state_dict()
713
+ mid_block_one_sd_new = {}
714
+
715
+ for i in range(2):
716
+ mid_block_one_sd_new[f"resnets.{i}.norm1.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.weight")
717
+ mid_block_one_sd_new[f"resnets.{i}.norm1.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.bias")
718
+ mid_block_one_sd_new[f"resnets.{i}.conv1.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_1.weight")
719
+ mid_block_one_sd_new[f"resnets.{i}.conv1.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_1.bias")
720
+ mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_t.weight")
721
+ mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_t.bias")
722
+ mid_block_one_sd_new[f"resnets.{i}.norm2.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.weight")
723
+ mid_block_one_sd_new[f"resnets.{i}.norm2.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.bias")
724
+ mid_block_one_sd_new[f"resnets.{i}.conv2.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_2.weight")
725
+ mid_block_one_sd_new[f"resnets.{i}.conv2.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_2.bias")
726
+
727
+ assert len(mid_block_one_sd_orig) == 0
728
+
729
+ mid_block_one = UNetMidBlock2D(
730
+ in_channels=1024,
731
+ temb_channels=1280,
732
+ num_layers=1,
733
+ resnet_time_scale_shift="scale_shift",
734
+ resnet_eps=1e-5,
735
+ add_attention=False,
736
+ )
737
+
738
+ mid_block_one.load_state_dict(mid_block_one_sd_new)
739
+
740
+ print("UP BLOCK ONE")
741
+
742
+ up_block_one_sd_orig = model.up[-1].state_dict()
743
+ up_block_one_sd_new = {}
744
+
745
+ for i in range(4):
746
+ up_block_one_sd_new[f"resnets.{i}.norm1.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_1.weight")
747
+ up_block_one_sd_new[f"resnets.{i}.norm1.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_1.bias")
748
+ up_block_one_sd_new[f"resnets.{i}.conv1.weight"] = up_block_one_sd_orig.pop(f"{i}.f_1.weight")
749
+ up_block_one_sd_new[f"resnets.{i}.conv1.bias"] = up_block_one_sd_orig.pop(f"{i}.f_1.bias")
750
+ up_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_one_sd_orig.pop(f"{i}.f_t.weight")
751
+ up_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_one_sd_orig.pop(f"{i}.f_t.bias")
752
+ up_block_one_sd_new[f"resnets.{i}.norm2.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_2.weight")
753
+ up_block_one_sd_new[f"resnets.{i}.norm2.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_2.bias")
754
+ up_block_one_sd_new[f"resnets.{i}.conv2.weight"] = up_block_one_sd_orig.pop(f"{i}.f_2.weight")
755
+ up_block_one_sd_new[f"resnets.{i}.conv2.bias"] = up_block_one_sd_orig.pop(f"{i}.f_2.bias")
756
+ up_block_one_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_one_sd_orig.pop(f"{i}.f_s.weight")
757
+ up_block_one_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_one_sd_orig.pop(f"{i}.f_s.bias")
758
+
759
+ up_block_one_sd_new["upsamplers.0.norm1.weight"] = up_block_one_sd_orig.pop("4.gn_1.weight")
760
+ up_block_one_sd_new["upsamplers.0.norm1.bias"] = up_block_one_sd_orig.pop("4.gn_1.bias")
761
+ up_block_one_sd_new["upsamplers.0.conv1.weight"] = up_block_one_sd_orig.pop("4.f_1.weight")
762
+ up_block_one_sd_new["upsamplers.0.conv1.bias"] = up_block_one_sd_orig.pop("4.f_1.bias")
763
+ up_block_one_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_one_sd_orig.pop("4.f_t.weight")
764
+ up_block_one_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_one_sd_orig.pop("4.f_t.bias")
765
+ up_block_one_sd_new["upsamplers.0.norm2.weight"] = up_block_one_sd_orig.pop("4.gn_2.weight")
766
+ up_block_one_sd_new["upsamplers.0.norm2.bias"] = up_block_one_sd_orig.pop("4.gn_2.bias")
767
+ up_block_one_sd_new["upsamplers.0.conv2.weight"] = up_block_one_sd_orig.pop("4.f_2.weight")
768
+ up_block_one_sd_new["upsamplers.0.conv2.bias"] = up_block_one_sd_orig.pop("4.f_2.bias")
769
+
770
+ assert len(up_block_one_sd_orig) == 0
771
+
772
+ up_block_one = ResnetUpsampleBlock2D(
773
+ in_channels=1024,
774
+ prev_output_channel=1024,
775
+ out_channels=1024,
776
+ temb_channels=1280,
777
+ num_layers=4,
778
+ add_upsample=True,
779
+ resnet_time_scale_shift="scale_shift",
780
+ resnet_eps=1e-5,
781
+ )
782
+
783
+ up_block_one.load_state_dict(up_block_one_sd_new)
784
+
785
+ print("UP BLOCK TWO")
786
+
787
+ up_block_two_sd_orig = model.up[-2].state_dict()
788
+ up_block_two_sd_new = {}
789
+
790
+ for i in range(4):
791
+ up_block_two_sd_new[f"resnets.{i}.norm1.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_1.weight")
792
+ up_block_two_sd_new[f"resnets.{i}.norm1.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_1.bias")
793
+ up_block_two_sd_new[f"resnets.{i}.conv1.weight"] = up_block_two_sd_orig.pop(f"{i}.f_1.weight")
794
+ up_block_two_sd_new[f"resnets.{i}.conv1.bias"] = up_block_two_sd_orig.pop(f"{i}.f_1.bias")
795
+ up_block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_two_sd_orig.pop(f"{i}.f_t.weight")
796
+ up_block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_two_sd_orig.pop(f"{i}.f_t.bias")
797
+ up_block_two_sd_new[f"resnets.{i}.norm2.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_2.weight")
798
+ up_block_two_sd_new[f"resnets.{i}.norm2.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_2.bias")
799
+ up_block_two_sd_new[f"resnets.{i}.conv2.weight"] = up_block_two_sd_orig.pop(f"{i}.f_2.weight")
800
+ up_block_two_sd_new[f"resnets.{i}.conv2.bias"] = up_block_two_sd_orig.pop(f"{i}.f_2.bias")
801
+ up_block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_two_sd_orig.pop(f"{i}.f_s.weight")
802
+ up_block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_two_sd_orig.pop(f"{i}.f_s.bias")
803
+
804
+ up_block_two_sd_new["upsamplers.0.norm1.weight"] = up_block_two_sd_orig.pop("4.gn_1.weight")
805
+ up_block_two_sd_new["upsamplers.0.norm1.bias"] = up_block_two_sd_orig.pop("4.gn_1.bias")
806
+ up_block_two_sd_new["upsamplers.0.conv1.weight"] = up_block_two_sd_orig.pop("4.f_1.weight")
807
+ up_block_two_sd_new["upsamplers.0.conv1.bias"] = up_block_two_sd_orig.pop("4.f_1.bias")
808
+ up_block_two_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_two_sd_orig.pop("4.f_t.weight")
809
+ up_block_two_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_two_sd_orig.pop("4.f_t.bias")
810
+ up_block_two_sd_new["upsamplers.0.norm2.weight"] = up_block_two_sd_orig.pop("4.gn_2.weight")
811
+ up_block_two_sd_new["upsamplers.0.norm2.bias"] = up_block_two_sd_orig.pop("4.gn_2.bias")
812
+ up_block_two_sd_new["upsamplers.0.conv2.weight"] = up_block_two_sd_orig.pop("4.f_2.weight")
813
+ up_block_two_sd_new["upsamplers.0.conv2.bias"] = up_block_two_sd_orig.pop("4.f_2.bias")
814
+
815
+ assert len(up_block_two_sd_orig) == 0
816
+
817
+ up_block_two = ResnetUpsampleBlock2D(
818
+ in_channels=640,
819
+ prev_output_channel=1024,
820
+ out_channels=1024,
821
+ temb_channels=1280,
822
+ num_layers=4,
823
+ add_upsample=True,
824
+ resnet_time_scale_shift="scale_shift",
825
+ resnet_eps=1e-5,
826
+ )
827
+
828
+ up_block_two.load_state_dict(up_block_two_sd_new)
829
+
830
+ print("UP BLOCK THREE")
831
+
832
+ up_block_three_sd_orig = model.up[-3].state_dict()
833
+ up_block_three_sd_new = {}
834
+
835
+ for i in range(4):
836
+ up_block_three_sd_new[f"resnets.{i}.norm1.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_1.weight")
837
+ up_block_three_sd_new[f"resnets.{i}.norm1.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_1.bias")
838
+ up_block_three_sd_new[f"resnets.{i}.conv1.weight"] = up_block_three_sd_orig.pop(f"{i}.f_1.weight")
839
+ up_block_three_sd_new[f"resnets.{i}.conv1.bias"] = up_block_three_sd_orig.pop(f"{i}.f_1.bias")
840
+ up_block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_three_sd_orig.pop(f"{i}.f_t.weight")
841
+ up_block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_three_sd_orig.pop(f"{i}.f_t.bias")
842
+ up_block_three_sd_new[f"resnets.{i}.norm2.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_2.weight")
843
+ up_block_three_sd_new[f"resnets.{i}.norm2.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_2.bias")
844
+ up_block_three_sd_new[f"resnets.{i}.conv2.weight"] = up_block_three_sd_orig.pop(f"{i}.f_2.weight")
845
+ up_block_three_sd_new[f"resnets.{i}.conv2.bias"] = up_block_three_sd_orig.pop(f"{i}.f_2.bias")
846
+ up_block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_three_sd_orig.pop(f"{i}.f_s.weight")
847
+ up_block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_three_sd_orig.pop(f"{i}.f_s.bias")
848
+
849
+ up_block_three_sd_new["upsamplers.0.norm1.weight"] = up_block_three_sd_orig.pop("4.gn_1.weight")
850
+ up_block_three_sd_new["upsamplers.0.norm1.bias"] = up_block_three_sd_orig.pop("4.gn_1.bias")
851
+ up_block_three_sd_new["upsamplers.0.conv1.weight"] = up_block_three_sd_orig.pop("4.f_1.weight")
852
+ up_block_three_sd_new["upsamplers.0.conv1.bias"] = up_block_three_sd_orig.pop("4.f_1.bias")
853
+ up_block_three_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_three_sd_orig.pop("4.f_t.weight")
854
+ up_block_three_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_three_sd_orig.pop("4.f_t.bias")
855
+ up_block_three_sd_new["upsamplers.0.norm2.weight"] = up_block_three_sd_orig.pop("4.gn_2.weight")
856
+ up_block_three_sd_new["upsamplers.0.norm2.bias"] = up_block_three_sd_orig.pop("4.gn_2.bias")
857
+ up_block_three_sd_new["upsamplers.0.conv2.weight"] = up_block_three_sd_orig.pop("4.f_2.weight")
858
+ up_block_three_sd_new["upsamplers.0.conv2.bias"] = up_block_three_sd_orig.pop("4.f_2.bias")
859
+
860
+ assert len(up_block_three_sd_orig) == 0
861
+
862
+ up_block_three = ResnetUpsampleBlock2D(
863
+ in_channels=320,
864
+ prev_output_channel=1024,
865
+ out_channels=640,
866
+ temb_channels=1280,
867
+ num_layers=4,
868
+ add_upsample=True,
869
+ resnet_time_scale_shift="scale_shift",
870
+ resnet_eps=1e-5,
871
+ )
872
+
873
+ up_block_three.load_state_dict(up_block_three_sd_new)
874
+
875
+ print("UP BLOCK FOUR")
876
+
877
+ up_block_four_sd_orig = model.up[-4].state_dict()
878
+ up_block_four_sd_new = {}
879
+
880
+ for i in range(4):
881
+ up_block_four_sd_new[f"resnets.{i}.norm1.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_1.weight")
882
+ up_block_four_sd_new[f"resnets.{i}.norm1.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_1.bias")
883
+ up_block_four_sd_new[f"resnets.{i}.conv1.weight"] = up_block_four_sd_orig.pop(f"{i}.f_1.weight")
884
+ up_block_four_sd_new[f"resnets.{i}.conv1.bias"] = up_block_four_sd_orig.pop(f"{i}.f_1.bias")
885
+ up_block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_four_sd_orig.pop(f"{i}.f_t.weight")
886
+ up_block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_four_sd_orig.pop(f"{i}.f_t.bias")
887
+ up_block_four_sd_new[f"resnets.{i}.norm2.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_2.weight")
888
+ up_block_four_sd_new[f"resnets.{i}.norm2.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_2.bias")
889
+ up_block_four_sd_new[f"resnets.{i}.conv2.weight"] = up_block_four_sd_orig.pop(f"{i}.f_2.weight")
890
+ up_block_four_sd_new[f"resnets.{i}.conv2.bias"] = up_block_four_sd_orig.pop(f"{i}.f_2.bias")
891
+ up_block_four_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_four_sd_orig.pop(f"{i}.f_s.weight")
892
+ up_block_four_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_four_sd_orig.pop(f"{i}.f_s.bias")
893
+
894
+ assert len(up_block_four_sd_orig) == 0
895
+
896
+ up_block_four = ResnetUpsampleBlock2D(
897
+ in_channels=320,
898
+ prev_output_channel=640,
899
+ out_channels=320,
900
+ temb_channels=1280,
901
+ num_layers=4,
902
+ add_upsample=False,
903
+ resnet_time_scale_shift="scale_shift",
904
+ resnet_eps=1e-5,
905
+ )
906
+
907
+ up_block_four.load_state_dict(up_block_four_sd_new)
908
+
909
+ print("initial projection (conv_in)")
910
+
911
+ conv_in_sd_orig = model.embed_image.state_dict()
912
+ conv_in_sd_new = {}
913
+
914
+ conv_in_sd_new["weight"] = conv_in_sd_orig.pop("f.weight")
915
+ conv_in_sd_new["bias"] = conv_in_sd_orig.pop("f.bias")
916
+
917
+ assert len(conv_in_sd_orig) == 0
918
+
919
+ block_out_channels = [320, 640, 1024, 1024]
920
+
921
+ in_channels = 7
922
+ conv_in_kernel = 3
923
+ conv_in_padding = (conv_in_kernel - 1) // 2
924
+ conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
925
+
926
+ conv_in.load_state_dict(conv_in_sd_new)
927
+
928
+ print("out projection (conv_out) (conv_norm_out)")
929
+ out_channels = 6
930
+ norm_num_groups = 32
931
+ norm_eps = 1e-5
932
+ act_fn = "silu"
933
+ conv_out_kernel = 3
934
+ conv_out_padding = (conv_out_kernel - 1) // 2
935
+ conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
936
+ # uses torch.functional in orig
937
+ # conv_act = get_activation(act_fn)
938
+ conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding)
939
+
940
+ conv_norm_out.load_state_dict(model.output.gn.state_dict())
941
+ conv_out.load_state_dict(model.output.f.state_dict())
942
+
943
+ print("timestep projection (time_proj) (time_embedding)")
944
+
945
+ f1_sd = model.embed_time.f_1.state_dict()
946
+ f2_sd = model.embed_time.f_2.state_dict()
947
+
948
+ time_embedding_sd = {
949
+ "linear_1.weight": f1_sd.pop("weight"),
950
+ "linear_1.bias": f1_sd.pop("bias"),
951
+ "linear_2.weight": f2_sd.pop("weight"),
952
+ "linear_2.bias": f2_sd.pop("bias"),
953
+ }
954
+
955
+ assert len(f1_sd) == 0
956
+ assert len(f2_sd) == 0
957
+
958
+ time_embedding_type = "learned"
959
+ num_train_timesteps = 1024
960
+ time_embedding_dim = 1280
961
+
962
+ time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
963
+ timestep_input_dim = block_out_channels[0]
964
+
965
+ time_embedding = TimestepEmbedding(timestep_input_dim, time_embedding_dim)
966
+
967
+ time_proj.load_state_dict(model.embed_time.emb.state_dict())
968
+ time_embedding.load_state_dict(time_embedding_sd)
969
+
970
+ print("CONVERT")
971
+
972
+ time_embedding.to("cuda")
973
+ time_proj.to("cuda")
974
+ conv_in.to("cuda")
975
+
976
+ block_one.to("cuda")
977
+ block_two.to("cuda")
978
+ block_three.to("cuda")
979
+ block_four.to("cuda")
980
+
981
+ mid_block_one.to("cuda")
982
+
983
+ up_block_one.to("cuda")
984
+ up_block_two.to("cuda")
985
+ up_block_three.to("cuda")
986
+ up_block_four.to("cuda")
987
+
988
+ conv_norm_out.to("cuda")
989
+ conv_out.to("cuda")
990
+
991
+ model.time_proj = time_proj
992
+ model.time_embedding = time_embedding
993
+ model.embed_image = conv_in
994
+
995
+ model.down[0] = block_one
996
+ model.down[1] = block_two
997
+ model.down[2] = block_three
998
+ model.down[3] = block_four
999
+
1000
+ model.mid = mid_block_one
1001
+
1002
+ model.up[-1] = up_block_one
1003
+ model.up[-2] = up_block_two
1004
+ model.up[-3] = up_block_three
1005
+ model.up[-4] = up_block_four
1006
+
1007
+ model.output.gn = conv_norm_out
1008
+ model.output.f = conv_out
1009
+
1010
+ model.converted = True
1011
+
1012
+ sample_consistency_new = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
1013
+ save_image(sample_consistency_new, "con_new.png")
1014
+
1015
+ assert (sample_consistency_orig == sample_consistency_new).all()
1016
+
1017
+ print("making unet")
1018
+
1019
+ unet = UNet2DModel(
1020
+ in_channels=in_channels,
1021
+ out_channels=out_channels,
1022
+ down_block_types=(
1023
+ "ResnetDownsampleBlock2D",
1024
+ "ResnetDownsampleBlock2D",
1025
+ "ResnetDownsampleBlock2D",
1026
+ "ResnetDownsampleBlock2D",
1027
+ ),
1028
+ up_block_types=(
1029
+ "ResnetUpsampleBlock2D",
1030
+ "ResnetUpsampleBlock2D",
1031
+ "ResnetUpsampleBlock2D",
1032
+ "ResnetUpsampleBlock2D",
1033
+ ),
1034
+ block_out_channels=block_out_channels,
1035
+ layers_per_block=3,
1036
+ norm_num_groups=norm_num_groups,
1037
+ norm_eps=norm_eps,
1038
+ resnet_time_scale_shift="scale_shift",
1039
+ time_embedding_type="learned",
1040
+ num_train_timesteps=num_train_timesteps,
1041
+ add_attention=False,
1042
+ )
1043
+
1044
unet_state_dict = {}


def add_state_dict(prefix, mod):
    """Flatten ``mod``'s parameters into ``unet_state_dict`` under ``prefix``."""
    for param_name, tensor in mod.state_dict().items():
        unet_state_dict[f"{prefix}.{param_name}"] = tensor
1050
+
1051
+
1052
+ add_state_dict("conv_in", conv_in)
1053
+ add_state_dict("time_proj", time_proj)
1054
+ add_state_dict("time_embedding", time_embedding)
1055
+ add_state_dict("down_blocks.0", block_one)
1056
+ add_state_dict("down_blocks.1", block_two)
1057
+ add_state_dict("down_blocks.2", block_three)
1058
+ add_state_dict("down_blocks.3", block_four)
1059
+ add_state_dict("mid_block", mid_block_one)
1060
+ add_state_dict("up_blocks.0", up_block_one)
1061
+ add_state_dict("up_blocks.1", up_block_two)
1062
+ add_state_dict("up_blocks.2", up_block_three)
1063
+ add_state_dict("up_blocks.3", up_block_four)
1064
+ add_state_dict("conv_norm_out", conv_norm_out)
1065
+ add_state_dict("conv_out", conv_out)
1066
+
1067
+ unet.load_state_dict(unet_state_dict)
1068
+
1069
+ print("running with diffusers unet")
1070
+
1071
+ unet.to("cuda")
1072
+
1073
+ decoder_consistency.ckpt = unet
1074
+
1075
+ sample_consistency_new_2 = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
1076
+ save_image(sample_consistency_new_2, "con_new_2.png")
1077
+
1078
+ assert (sample_consistency_orig == sample_consistency_new_2).all()
1079
+
1080
+ print("running with diffusers model")
1081
+
1082
+ Encoder.old_constructor = Encoder.__init__
1083
+
1084
+
1085
def new_constructor(self, **kwargs):
    """Patched ``Encoder.__init__`` that additionally records its kwargs."""
    self.old_constructor(**kwargs)
    # Stash the construction kwargs so the ConsistencyDecoderVAE built later
    # in this script can reuse the exact encoder configuration.
    self.constructor_arguments = kwargs
1088
+
1089
+
1090
+ Encoder.__init__ = new_constructor
1091
+
1092
+
1093
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
1094
+ consistency_vae = ConsistencyDecoderVAE(
1095
+ encoder_args=vae.encoder.constructor_arguments,
1096
+ decoder_args=unet.config,
1097
+ scaling_factor=vae.config.scaling_factor,
1098
+ block_out_channels=vae.config.block_out_channels,
1099
+ latent_channels=vae.config.latent_channels,
1100
+ )
1101
+ consistency_vae.encoder.load_state_dict(vae.encoder.state_dict())
1102
+ consistency_vae.quant_conv.load_state_dict(vae.quant_conv.state_dict())
1103
+ consistency_vae.decoder_unet.load_state_dict(unet.state_dict())
1104
+
1105
+ consistency_vae.to(dtype=torch.float16, device="cuda")
1106
+
1107
+ sample_consistency_new_3 = consistency_vae.decode(
1108
+ 0.18215 * latent, generator=torch.Generator("cpu").manual_seed(0)
1109
+ ).sample
1110
+
1111
+ print("max difference")
1112
+ print((sample_consistency_orig - sample_consistency_new_3).abs().max())
1113
+ print("total difference")
1114
+ print((sample_consistency_orig - sample_consistency_new_3).abs().sum())
1115
+ # assert (sample_consistency_orig == sample_consistency_new_3).all()
1116
+
1117
+ print("running with diffusers pipeline")
1118
+
1119
+ pipe = DiffusionPipeline.from_pretrained(
1120
+ "runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
1121
+ )
1122
+ pipe.to("cuda")
1123
+
1124
+ pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("horse.png")
1125
+
1126
+
1127
+ if args.save_pretrained is not None:
1128
+ consistency_vae.save_pretrained(args.save_pretrained)
diffusers/scripts/convert_dance_diffusion_to_diffusers.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import math
4
+ import os
5
+ from copy import deepcopy
6
+
7
+ import requests
8
+ import torch
9
+ from audio_diffusion.models import DiffusionAttnUnet1D
10
+ from diffusion import sampling
11
+ from torch import nn
12
+
13
+ from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
14
+ from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
15
+
16
+
17
# Official Dance Diffusion checkpoints: download URL plus the audio sample
# rate (Hz) and sample length (frames) each model was trained with.
MODELS_MAP = {
    "gwf-440k": {
        "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt",
        "sample_rate": 48000,
        "sample_size": 65536,
    },
    "jmann-small-190k": {
        "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
        "sample_rate": 48000,
        "sample_size": 65536,
    },
    "jmann-large-580k": {
        "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt",
        "sample_rate": 48000,
        "sample_size": 131072,
    },
    "maestro-uncond-150k": {
        "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
    "unlocked-uncond-250k": {
        "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
    "honk-140k": {
        "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt",
        "sample_rate": 16000,
        "sample_size": 65536,
    },
}
49
+
50
+
51
def alpha_sigma_to_t(alpha, sigma):
    """Return the continuous timestep for clean-signal scale ``alpha`` and noise scale ``sigma``.

    The timestep is the angle of the point (alpha, sigma) on the unit
    quarter-circle, normalized so that (1, 0) -> 0 and (0, 1) -> 1.
    """
    angle = torch.atan2(sigma, alpha)
    return angle / math.pi * 2


def get_crash_schedule(t):
    """Warp uniform timesteps ``t`` onto the 'crash' noise schedule."""
    noise_level = torch.sin(t * math.pi / 2) ** 2
    signal_level = (1 - noise_level**2) ** 0.5
    return alpha_sigma_to_t(signal_level, noise_level)
61
+
62
+
63
class Object:
    """Bare attribute container used as an ad-hoc config namespace."""
65
+
66
+
67
class DiffusionUncond(nn.Module):
    """Shell matching the original audio-diffusion training module layout.

    Only used so the checkpoint's ``state_dict`` keys (``diffusion.*`` /
    ``diffusion_ema.*``) line up when loading; the EMA copy is what gets
    converted downstream.
    """

    def __init__(self, global_args):
        super().__init__()

        self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
        self.diffusion_ema = deepcopy(self.diffusion)
        # Mirrors the original training code's sampler state; unused during conversion.
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
74
+
75
+
76
def download(model_name):
    """Download an official checkpoint to ``./{model_name}.ckpt`` and return the path.

    Raises:
        KeyError: if ``model_name`` is not an official name in ``MODELS_MAP``.
        requests.HTTPError: if the server responds with an error status.
    """
    url = MODELS_MAP[model_name]["url"]
    local_filename = f"./{model_name}.ckpt"

    # Stream in chunks so large checkpoints are not buffered in memory; the
    # `with` block guarantees the connection is released.
    with requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT) as r:
        # Fail fast instead of silently writing an HTML error page into the .ckpt.
        r.raise_for_status()
        with open(local_filename, "wb") as fp:
            for chunk in r.iter_content(chunk_size=8192):
                fp.write(chunk)

    return local_filename
86
+
87
+
88
# The original model stores each U-Net level as a flat sequential container;
# these tables map a child's positional index (as a string) to the diffusers
# sub-module name at the same position.

# Down-path children 1-6: alternating resnet / attention pairs.
DOWN_NUM_TO_LAYER = {
    "1": "resnets.0",
    "2": "attentions.0",
    "3": "resnets.1",
    "4": "attentions.1",
    "5": "resnets.2",
    "6": "attentions.2",
}
# Up-path children 8-13 (index 7 is the nested deeper level, which rename()
# strips via the "main.7." prefix).
UP_NUM_TO_LAYER = {
    "8": "resnets.0",
    "9": "attentions.0",
    "10": "resnets.1",
    "11": "attentions.1",
    "12": "resnets.2",
    "13": "attentions.2",
}
# Innermost (mid) block: six resnet/attention pairs; index 7 is absent here too.
MID_NUM_TO_LAYER = {
    "1": "resnets.0",
    "2": "attentions.0",
    "3": "resnets.1",
    "4": "attentions.1",
    "5": "resnets.2",
    "6": "attentions.2",
    "8": "resnets.3",
    "9": "attentions.3",
    "10": "resnets.4",
    "11": "attentions.4",
    "12": "resnets.5",
    "13": "attentions.5",
}
# Outermost level (depth 0): indices 0-2 map to down_blocks.0 resnets and
# 4-6 to the last up block's resnets (see the prefix selection in rename()).
DEPTH_0_TO_LAYER = {
    "0": "resnets.0",
    "1": "resnets.1",
    "2": "resnets.2",
    "4": "resnets.0",
    "5": "resnets.1",
    "6": "resnets.2",
}
126
+
127
# ResConvBlock child name -> diffusers parameter name.
RES_CONV_MAP = {
    "skip": "conv_skip",
    "main.0": "conv_1",
    "main.1": "group_norm_1",
    "main.3": "conv_2",
    "main.4": "group_norm_2",
}

# Attention child name -> diffusers name(s); a list marks a fused tensor that
# must be split across several target parameters.
ATTN_MAP = {
    "norm": "group_norm",
    "qkv_proj": ["query", "key", "value"],
    "out_proj": ["proj_attn"],
}


def convert_resconv_naming(name):
    """Translate a ResConvBlock parameter name into its diffusers equivalent."""
    if name.startswith("skip"):
        return name.replace("skip", RES_CONV_MAP["skip"])

    # Anything else must look like "main.<digit>...."
    if not name.startswith("main."):
        raise ValueError(f"ResConvBlock error with {name}")

    head = name[:6]
    return name.replace(head, RES_CONV_MAP[head])


def convert_attn_naming(name):
    """Translate an attention parameter name; fused qkv yields a list of names."""
    for old, new in ATTN_MAP.items():
        if not name.startswith(old):
            continue
        if isinstance(new, list):
            return [name.replace(old, replacement) for replacement in new]
        return name.replace(old, new)
    raise ValueError(f"Attn error with {name}")
160
+
161
+
162
def rename(input_string, max_depth=13):
    """Map an original audio-diffusion state_dict key to its diffusers key.

    ``max_depth`` is the nesting depth at which the innermost ("mid") block
    sits; the original model descends one level per ``main.7.`` prefix.
    Returns a string, or a list of strings for fused qkv tensors.
    """
    string = input_string

    if string.split(".")[0] == "timestep_embed":
        return string.replace("timestep_embed", "time_proj")

    depth = 0
    if string.startswith("net.3."):
        # "net.3" is already one level inside the outer wrapper.
        depth += 1
        string = string[6:]
    elif string.startswith("net."):
        string = string[4:]

    # Each "main.7." prefix descends one U-Net level.
    while string.startswith("main.7."):
        depth += 1
        string = string[7:]

    if string.startswith("main."):
        string = string[5:]

    # mid block
    # What remains starts with the child's positional index (1 or 2 digits).
    if string[:2].isdigit():
        layer_num = string[:2]
        string_left = string[2:]
    else:
        layer_num = string[0]
        string_left = string[1:]

    if depth == max_depth:
        new_layer = MID_NUM_TO_LAYER[layer_num]
        prefix = "mid_block"
    elif depth > 0 and int(layer_num) < 7:
        new_layer = DOWN_NUM_TO_LAYER[layer_num]
        prefix = f"down_blocks.{depth}"
    elif depth > 0 and int(layer_num) > 7:
        new_layer = UP_NUM_TO_LAYER[layer_num]
        prefix = f"up_blocks.{max_depth - depth - 1}"
    elif depth == 0:
        new_layer = DEPTH_0_TO_LAYER[layer_num]
        prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0"
    # NOTE(review): depth > 0 with layer_num == 7 falls through all branches and
    # would raise UnboundLocalError on new_layer below; presumably such keys
    # never occur in real checkpoints — confirm.

    if not string_left.startswith("."):
        raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.")

    string_left = string_left[1:]

    if "resnets" in new_layer:
        string_left = convert_resconv_naming(string_left)
    elif "attentions" in new_layer:
        new_string_left = convert_attn_naming(string_left)
        string_left = new_string_left

    # Fused qkv weights expand into a list of target keys.
    if not isinstance(string_left, list):
        new_string = prefix + "." + new_layer + "." + string_left
    else:
        new_string = [prefix + "." + new_layer + "." + s for s in string_left]
    return new_string
219
+
220
+
221
def rename_orig_weights(state_dict):
    """Return a new state_dict with every key renamed to the diffusers layout."""
    converted = {}
    for old_key, tensor in state_dict.items():
        # Up-/downsample "kernel" entries carry no trainable weights; drop them.
        if old_key.endswith("kernel"):
            continue

        target = rename(old_key)

        if isinstance(target, list):
            # Fused qkv conv weight: split it across the q/k/v linear layers.
            converted = transform_conv_attns(converted, target, tensor)
        else:
            converted[target] = tensor

    return converted
237
+
238
+
239
+ def transform_conv_attns(new_state_dict, new_k, v):
240
+ if len(new_k) == 1:
241
+ if len(v.shape) == 3:
242
+ # weight
243
+ new_state_dict[new_k[0]] = v[:, :, 0]
244
+ else:
245
+ # bias
246
+ new_state_dict[new_k[0]] = v
247
+ else:
248
+ # qkv matrices
249
+ trippled_shape = v.shape[0]
250
+ single_shape = trippled_shape // 3
251
+ for i in range(3):
252
+ if len(v.shape) == 3:
253
+ new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0]
254
+ else:
255
+ new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape]
256
+ return new_state_dict
257
+
258
+
259
def main(args):
    """Convert a Dance Diffusion checkpoint to diffusers and verify outputs match."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Accept either a local .ckpt path or one of the official names in MODELS_MAP.
    model_name = args.model_path.split("/")[-1].split(".")[0]
    if not os.path.isfile(args.model_path):
        assert model_name == args.model_path, (
            f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
        )
        args.model_path = download(model_name)

    sample_rate = MODELS_MAP[model_name]["sample_rate"]
    sample_size = MODELS_MAP[model_name]["sample_size"]

    config = Object()
    config.sample_size = sample_size
    config.sample_rate = sample_rate
    config.latent_dim = 0

    diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)
    diffusers_state_dict = diffusers_model.state_dict()

    # Load the original module and convert its EMA weights.
    orig_model = DiffusionUncond(config)
    orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
    orig_model = orig_model.diffusion_ema.eval()
    orig_model_state_dict = orig_model.state_dict()
    renamed_state_dict = rename_orig_weights(orig_model_state_dict)

    # Every renamed key must exist in the diffusers model; the only diffusers
    # keys allowed to be missing from the rename are the weight-free "kernel"
    # buffers.
    renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())
    diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())

    assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}"
    assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"

    for key, value in renamed_state_dict.items():
        assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
            f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
        )
        if key == "time_proj.weight":
            # time_proj is linear in diffusers; drop the trailing conv kernel dim.
            value = value.squeeze()

        diffusers_state_dict[key] = value

    diffusers_model.load_state_dict(diffusers_state_dict)

    steps = 100
    seed = 33

    diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)

    generator = torch.manual_seed(seed)
    noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)

    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    step_list = get_crash_schedule(t)

    pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)

    # Re-seed so both samplers start from the same noise.
    generator = torch.manual_seed(33)
    audio = pipe(num_inference_steps=steps, generator=generator).audios

    # NOTE(review): `noise` is moved to `device` but `orig_model` is not; on a
    # CUDA machine this looks like a device mismatch — confirm.
    generated = sampling.iplms_sample(orig_model, noise, step_list, {})
    generated = generated.clamp(-1, 1)

    diff_sum = (generated - audio).abs().sum()
    diff_max = (generated - audio).abs().max()

    if args.save:
        pipe.save_pretrained(args.checkpoint_path)

    print("Diff sum", diff_sum)
    print("Diff max", diff_max)

    # Guard the conversion: outputs must agree elementwise to ~1e-3.
    assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/"

    print(f"Conversion for {model_name} successful!")
334
+
335
+
336
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument(
        "--save",
        default=True,
        # `type=bool` is an argparse pitfall: bool("False") is True, so any
        # non-empty value (including "False") enabled saving. Parse the
        # string explicitly so `--save False` actually disables saving.
        type=lambda s: str(s).lower() not in ("false", "0", "no"),
        required=False,
        help="Whether to save the converted model or not.",
    )
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    main(args)
diffusers/scripts/convert_dcae_to_diffusers.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from typing import Any, Dict
3
+
4
+ import torch
5
+ from huggingface_hub import hf_hub_download
6
+ from safetensors.torch import load_file
7
+
8
+ from diffusers import AutoencoderDC
9
+
10
+
11
def remap_qkv_(key: str, state_dict: Dict[str, Any]):
    """Split a fused qkv conv weight into to_q/to_k/to_v linear weights (in place)."""
    parent_module, _, _ = key.rpartition(".qkv.conv.weight")
    fused = state_dict.pop(key)
    # Equal thirds along dim 0; squeeze drops the 1x1 conv kernel dims.
    for proj_name, chunk in zip(("to_q", "to_k", "to_v"), torch.chunk(fused, 3, dim=0)):
        state_dict[f"{parent_module}.{proj_name}.weight"] = chunk.squeeze()
18
+
19
+
20
def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
    """Convert an output-projection 1x1 conv weight to linear layout (in place)."""
    parent_module = key.rpartition(".proj.conv.weight")[0]
    state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
23
+
24
+
25
# Substring renames applied to every checkpoint key (first pass in convert_ae).
AE_KEYS_RENAME_DICT = {
    # common
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "local_module": "conv_out",
    # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
    # If there were more scales, there would be more layers, so a loop would be better to handle this
    "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
    "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
    "depth_conv.conv": "conv_depth",
    "inverted_conv.conv": "conv_inverted",
    "point_conv.conv": "conv_point",
    "point_conv.norm": "norm",
    "conv.conv.": "conv.",
    "conv1.conv": "conv1",
    "conv2.conv": "conv2",
    "conv2.norm": "norm",
    "proj.norm": "norm_out",
    # encoder
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.project_out.0.conv": "encoder.conv_out",
    "encoder.stages": "encoder.down_blocks",
    # decoder
    "decoder.project_in.conv": "decoder.conv_in",
    "decoder.project_out.0": "decoder.norm_out",
    "decoder.project_out.2.conv": "decoder.conv_out",
    "decoder.stages": "decoder.up_blocks",
}

# Per-variant key overrides. NOTE(review): the three dicts below are identical
# and are not referenced anywhere in the visible part of this script —
# presumably consumed by per-config branches of get_ae_config; confirm.
AE_F32C32_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F64C128_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F128C512_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

# Keys matched by substring and rewritten structurally, in place
# (second pass in convert_ae).
AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
    "proj.conv.weight": remap_proj_conv_,
}
80
+
81
+
82
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap common checkpoint containers (``model`` / ``module`` / ``state_dict``).

    Fix: the original tested membership on the *outer* ``saved_dict`` while
    indexing the already-unwrapped ``state_dict``, so nested containers such
    as ``{"model": {"state_dict": ...}}`` were never fully unwrapped. Flat
    checkpoints behave exactly as before.
    """
    state_dict = saved_dict
    if "model" in state_dict.keys():
        state_dict = state_dict["model"]
    if "module" in state_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in state_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict
91
+
92
+
93
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    """Rename ``old_key`` to ``new_key`` in ``state_dict``, mutating it in place.

    FIX: the return annotation previously claimed ``Dict[str, Any]`` although the
    function returns nothing — the trailing underscore marks it as in-place.
    """
    state_dict[new_key] = state_dict.pop(old_key)
95
+
96
+
97
def convert_ae(config_name: str, dtype: torch.dtype):
    """Download the original DC-AE checkpoint, remap its keys, and load it into a
    diffusers `AutoencoderDC` instance cast to ``dtype``."""
    config = get_ae_config(config_name)
    ckpt_path = hf_hub_download(f"mit-han-lab/{config_name}", "model.safetensors")
    original_state_dict = get_state_dict(load_file(ckpt_path))

    ae = AutoencoderDC(**config).to(dtype=dtype)

    # First pass: substring-based renaming of every key (in rename-map order).
    for old_key in list(original_state_dict.keys()):
        renamed = old_key
        for src, dst in AE_KEYS_RENAME_DICT.items():
            renamed = renamed.replace(src, dst)
        update_state_dict_(original_state_dict, old_key, renamed)

    # Second pass: special handlers mutate matching tensors in place.
    for key in list(original_state_dict.keys()):
        for pattern, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if pattern in key:
                handler_fn_inplace(key, original_state_dict)

    ae.load_state_dict(original_state_dict, strict=True)
    return ae
119
+
120
+
121
def get_ae_config(name: str):
    """Return the `AutoencoderDC` constructor kwargs for a named DC-AE checkpoint.

    NOTE(review): for the non-SANA variants this also mutates the module-level
    `AE_KEYS_RENAME_DICT` in place with variant-specific key overrides, so the
    rename map used by `convert_ae` depends on which config was requested last.
    Raises ValueError for unknown names.
    """
    if name in ["dc-ae-f32c32-sana-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }
    elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
        # These variants keep an extra `.conv` level on in/out projections.
        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
        config = {
            "latent_channels": 32,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f32c32-in-1.0":
            config["scaling_factor"] = 0.3189
        elif name == "dc-ae-f32c32-mix-1.0":
            config["scaling_factor"] = 0.4552
    elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
        config = {
            "latent_channels": 128,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f64c128-in-1.0":
            config["scaling_factor"] = 0.2889
        elif name == "dc-ae-f64c128-mix-1.0":
            config["scaling_factor"] = 0.4538
    elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
        config = {
            "latent_channels": 512,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f128c512-in-1.0":
            config["scaling_factor"] = 0.4883
        elif name == "dc-ae-f128c512-mix-1.0":
            config["scaling_factor"] = 0.3620
    else:
        raise ValueError("Invalid config name provided.")

    return config
279
+
280
+
281
def get_args():
    """Parse the command-line arguments for the DC-AE conversion script."""
    arg_parser = argparse.ArgumentParser()
    supported_configs = [
        "dc-ae-f32c32-sana-1.0",
        "dc-ae-f32c32-in-1.0",
        "dc-ae-f32c32-mix-1.0",
        "dc-ae-f64c128-in-1.0",
        "dc-ae-f64c128-mix-1.0",
        "dc-ae-f128c512-in-1.0",
        "dc-ae-f128c512-mix-1.0",
    ]
    arg_parser.add_argument(
        "--config_name",
        type=str,
        default="dc-ae-f32c32-sana-1.0",
        choices=supported_configs,
        help="The DCAE checkpoint to convert",
    )
    arg_parser.add_argument(
        "--output_path", type=str, required=True, help="Path where converted model should be saved"
    )
    arg_parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    return arg_parser.parse_args()
301
+
302
+
303
# Maps the `--dtype` CLI string to the torch dtype used for the converted weights.
DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

# Maps the `--dtype` CLI string to the diffusers weight-variant suffix
# (fp32 is the default, unsuffixed variant, hence None).
VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}


if __name__ == "__main__":
    args = get_args()

    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

    # Convert the original checkpoint and save it in diffusers (safetensors) format.
    ae = convert_ae(args.config_name, dtype)
    ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
2
+ # This means that you can input your diffusers-trained LoRAs and
3
+ # Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
4
+
5
+ # To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
6
+ # https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
7
+ # and run the script:
8
+ # python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
9
+ # now you can use corgy.safetensors in your WebUI of choice!
10
+
11
+ # To train your own, here are some diffusers training scripts and utils that you can use and then convert:
12
+ # LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
13
+ # Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
14
+ # - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
15
+ # - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
16
+ # Canonical diffusers training scripts:
17
+ # - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
18
+ # - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
19
+
20
+ import argparse
21
+ import os
22
+
23
+ from safetensors.torch import load_file, save_file
24
+
25
+ from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
26
+
27
+
28
def convert_and_save(input_lora, output_lora=None):
    """Convert a diffusers-format SDXL LoRA safetensors file to Kohya/WebUI format.

    When ``output_lora`` is None, the result is written next to the input with a
    ``_webui`` suffix.
    """
    if output_lora is None:
        stem = os.path.splitext(input_lora)[0]
        output_lora = f"{stem}_webui.safetensors"

    # diffusers -> PEFT -> Kohya, then write the converted weights out.
    state_dict = load_file(input_lora)
    state_dict = convert_all_state_dict_to_peft(state_dict)
    state_dict = convert_state_dict_to_kohya(state_dict)
    save_file(state_dict, output_lora)
37
+
38
+
39
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.")
    parser.add_argument(
        "--input_lora",
        type=str,
        required=True,
        help="Path to the input LoRA model file in the diffusers format.",
    )
    parser.add_argument(
        "--output_lora",
        type=str,
        required=False,
        help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.",
    )

    args = parser.parse_args()

    # `output_lora` may be None; convert_and_save then derives a `_webui` name.
    convert_and_save(args.input_lora, args.output_lora)
diffusers/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
from contextlib import nullcontext

import safetensors.torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available


# The CLIP image encoder is only exported when transformers is installed.
if is_transformers_available():
    from transformers import CLIPVisionModelWithProjection

    vision = True
else:
    vision = False

"""
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
--filename "flux-ip-adapter.safetensors" \
--output_path "flux-ip-adapter-hf/"
"""


# FIX: `is_accelerate_available` is a function; the previous code tested the
# function object itself (always truthy), so `nullcontext` could never be
# selected. Call it to actually check availability.
CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)

args = parser.parse_args()
36
+
37
+
38
def load_original_checkpoint(args):
    """Resolve the checkpoint location (Hub repo or local path) and load its state dict."""
    repo_id = args.original_state_dict_repo_id
    if repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    return safetensors.torch.load_file(ckpt_path)
48
+
49
+
50
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
    """Remap an XLabs Flux IP-Adapter state dict to the diffusers key layout.

    Args:
        original_state_dict: mapping of original parameter names to tensors;
            consumed destructively (keys are popped as they are converted).
        num_layers: number of `double_blocks.{i}` transformer blocks to remap.

    Returns:
        A new dict keyed with diffusers-style names.
    """
    converted_state_dict = {}

    # image_proj
    ## norm
    converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
    converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
    ## proj
    # FIX: previously popped the (already removed) norm tensors again, which
    # raised KeyError and mapped the wrong weights; use the proj tensors.
    converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.proj.weight")
    converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.proj.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"ip_adapter.{i}."
        # to_k_ip
        converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
        )
        converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
        )
        # to_v_ip
        converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
        )
        # FIX: the v-projection weight was previously stored under `to_k_ip.weight`,
        # overwriting the k weight and leaving `to_v_ip.weight` unset.
        converted_state_dict[f"{block_prefix}to_v_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
        )

    return converted_state_dict
80
+
81
+
82
def main(args):
    """Convert the XLabs Flux IP-Adapter checkpoint and save it in diffusers layout."""
    original_ckpt = load_original_checkpoint(args)

    # 19 `double_blocks.{i}` entries are remapped; each contributes one
    # IP-Adapter k/v projection pair.
    num_layers = 19
    converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)

    print("Saving Flux IP-Adapter in Diffusers format.")
    safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")

    # Export the CLIP image encoder alongside the adapter when transformers is installed.
    if vision:
        model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
        model.save_pretrained(f"{args.output_path}/image_encoder")


if __name__ == "__main__":
    main(args)
diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+
5
+ from diffusers import HunyuanDiT2DControlNetModel
6
+
7
+
8
def main(args):
    """Remap a HunyuanDiT ControlNet ``.pt`` state dict to the diffusers
    `HunyuanDiT2DControlNetModel` key layout, load it, and optionally save it.

    FIX: the second half of the KeyError message below was missing its ``f``
    prefix, so ``{state_dict.keys()}`` was printed literally instead of the
    available keys.
    """
    state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")

    if args.load_key != "none":
        try:
            state_dict = state_dict[args.load_key]
        except KeyError:
            raise KeyError(
                f"{args.load_key} not found in the checkpoint."
                f"Please load from the following keys:{state_dict.keys()}"
            )
    # NOTE(review): hard-coded CUDA device; conversion fails on CPU-only hosts.
    device = "cuda"

    model_config = HunyuanDiT2DControlNetModel.load_config(
        "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
    )
    model_config["use_style_cond_and_image_meta_size"] = (
        args.use_style_cond_and_image_meta_size
    )  ### version <= v1.1: True; version >= v1.2: False
    print(model_config)

    for key in state_dict:
        print("local:", key)

    model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device)

    for key in model.state_dict():
        print("diffusers:", key)

    num_layers = 19
    for i in range(num_layers):
        # attn1
        # Wkqv -> to_q, to_k, to_v
        q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
        state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
        state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
        state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
        state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
        state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
        state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
        state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
        state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")

        # q_norm, k_norm -> norm_q, norm_k
        state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
        state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
        state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
        state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]

        state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
        state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
        state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
        state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")

        # out_proj -> to_out
        state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
        state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
        state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
        state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")

        # attn2
        # kq_proj -> to_k, to_v
        k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
        state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
        state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
        state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
        state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
        state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
        state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")

        # q_proj -> to_q
        state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
        state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
        state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
        state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")

        # q_norm, k_norm -> norm_q, norm_k
        state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
        state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
        state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
        state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]

        state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
        state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
        state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
        state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")

        # out_proj -> to_out
        state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
        state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
        state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
        state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")

        # switch norm 2 and norm 3
        norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
        norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
        state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
        state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
        state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
        state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias

        # norm1 -> norm1.norm
        # default_modulation.1 -> norm1.linear
        state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
        state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
        state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
        state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
        state_dict.pop(f"blocks.{i}.norm1.weight")
        state_dict.pop(f"blocks.{i}.norm1.bias")
        state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
        state_dict.pop(f"blocks.{i}.default_modulation.1.bias")

        # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
        state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
        state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
        state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
        state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
        state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
        state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
        state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
        state_dict.pop(f"blocks.{i}.mlp.fc2.bias")

        # after_proj_list -> controlnet_blocks
        state_dict[f"controlnet_blocks.{i}.weight"] = state_dict[f"after_proj_list.{i}.weight"]
        state_dict[f"controlnet_blocks.{i}.bias"] = state_dict[f"after_proj_list.{i}.bias"]
        state_dict.pop(f"after_proj_list.{i}.weight")
        state_dict.pop(f"after_proj_list.{i}.bias")

    # before_proj -> input_block
    state_dict["input_block.weight"] = state_dict["before_proj.weight"]
    state_dict["input_block.bias"] = state_dict["before_proj.bias"]
    state_dict.pop("before_proj.weight")
    state_dict.pop("before_proj.bias")

    # pooler -> time_extra_emb
    state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
    state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
    state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
    state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
    state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
    state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
    state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
    state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
    state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
    state_dict.pop("pooler.k_proj.weight")
    state_dict.pop("pooler.k_proj.bias")
    state_dict.pop("pooler.q_proj.weight")
    state_dict.pop("pooler.q_proj.bias")
    state_dict.pop("pooler.v_proj.weight")
    state_dict.pop("pooler.v_proj.bias")
    state_dict.pop("pooler.c_proj.weight")
    state_dict.pop("pooler.c_proj.bias")
    state_dict.pop("pooler.positional_embedding")

    # t_embedder -> time_embedding (`TimestepEmbedding`)
    state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
    state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
    state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
    state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]

    state_dict.pop("t_embedder.mlp.0.bias")
    state_dict.pop("t_embedder.mlp.0.weight")
    state_dict.pop("t_embedder.mlp.2.bias")
    state_dict.pop("t_embedder.mlp.2.weight")

    # x_embedder -> pos_embd (`PatchEmbed`)
    state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
    state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
    state_dict.pop("x_embedder.proj.weight")
    state_dict.pop("x_embedder.proj.bias")

    # mlp_t5 -> text_embedder
    state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
    state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
    state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
    state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
    state_dict.pop("mlp_t5.0.bias")
    state_dict.pop("mlp_t5.0.weight")
    state_dict.pop("mlp_t5.2.bias")
    state_dict.pop("mlp_t5.2.weight")

    # extra_embedder -> extra_embedder
    state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
    state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
    state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
    state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
    state_dict.pop("extra_embedder.0.bias")
    state_dict.pop("extra_embedder.0.weight")
    state_dict.pop("extra_embedder.2.bias")
    state_dict.pop("extra_embedder.2.weight")

    # style_embedder
    if model_config["use_style_cond_and_image_meta_size"]:
        print(state_dict["style_embedder.weight"])
        print(state_dict["style_embedder.weight"].shape)
        # Only the first style row is kept for the diffusers embedding.
        state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
        state_dict.pop("style_embedder.weight")

    model.load_state_dict(state_dict)

    if args.save:
        model.save_pretrained(args.output_checkpoint_path)
212
+
213
+
214
+ if __name__ == "__main__":
215
+ parser = argparse.ArgumentParser()
216
+
217
+ parser.add_argument(
218
+ "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
219
+ )
220
+ parser.add_argument(
221
+ "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
222
+ )
223
+ parser.add_argument(
224
+ "--output_checkpoint_path",
225
+ default=None,
226
+ type=str,
227
+ required=False,
228
+ help="Path to the output converted diffusers pipeline.",
229
+ )
230
+ parser.add_argument(
231
+ "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
232
+ )
233
+ parser.add_argument(
234
+ "--use_style_cond_and_image_meta_size",
235
+ type=bool,
236
+ default=False,
237
+ help="version <= v1.1: True; version >= v1.2: False",
238
+ )
239
+
240
+ args = parser.parse_args()
241
+ main(args)
diffusers/scripts/convert_i2vgen_to_diffusers.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Conversion script for the LDM checkpoints."""
16
+
17
+ import argparse
18
+
19
+ import torch
20
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
21
+
22
+ from diffusers import DDIMScheduler, I2VGenXLPipeline, I2VGenXLUNet, StableDiffusionPipeline
23
+
24
+
25
+ CLIP_ID = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
26
+
27
+
28
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            # Fused qkv tensor: dim 0 holds q, k and v stacked together.
            channels = old_tensor.shape[0] // 3

            # NOTE(review): in the 2-D branch this is the int -1, not a 1-tuple;
            # `reshape(-1)` still works (flattens), preserved from upstream.
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            # Regroup per head before splitting so q/k/v slices stay head-contiguous.
            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        weight = old_checkpoint[path["old"]]
        names = ["proj_attn.weight"]
        names_2 = ["proj_out.weight", "proj_in.weight"]
        if any(k in new_path for k in names):
            # Drop the trailing conv-kernel dim of size 1 -> linear weight.
            checkpoint[new_path] = weight[:, :, 0]
        elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path:
            checkpoint[new_path] = weight[:, :, 0]
        else:
            checkpoint[new_path] = weight
77
+
78
+
79
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # Attention keys keep their original names: emit identity old->new pairs.
    # (`n_shave_prefix_segments` is accepted for signature parity but unused.)
    return [{"old": item, "new": item} for item in old_list]
89
+
90
+
91
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
99
+
100
+
101
def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming).

    Temporal-conv keys keep their original names, so the mapping is the identity.
    """
    return [{"old": key, "new": key} for key in old_list]
110
+
111
+
112
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming).

    Keys containing "temopral_conv" (sic -- the typo is present in the original
    checkpoints) belong to the temporal convs, which are mapped elsewhere, so
    they are excluded from the returned mapping.
    """
    # Original LDM sub-module names and their diffusers counterparts.
    substitutions = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for old_item in old_list:
        if "temopral_conv" in old_item:
            continue
        new_item = old_item
        for src, dst in substitutions:
            new_item = new_item.replace(src, dst)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})

    return mapping
133
+
134
+
135
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Maps an original i2vgen-xl LDM-style UNet state dict onto the diffusers
    `I2VGenXLUNet` key layout. `config` is the diffusers UNet config dict
    (reads `layers_per_block` and, via `assign_to_checkpoint`,
    `num_head_channels`). `path` is only used in log messages; `extract_ema`
    selects the EMA weight copies when the checkpoint carries both.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."

    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                # EMA copies are stored flat: dots after the first segment removed.
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # Timestep embedding MLP.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    # i2vgen-xl specific conditioning embeddings; two of them are renamed,
    # `context_embedding` and `fps_embedding` keep their names.
    additional_embedding_substrings = [
        "local_image_concat",
        "context_embedding",
        "local_image_embedding",
        "fps_embedding",
    ]
    for k in unet_state_dict:
        if any(substring in k for substring in additional_embedding_substrings):
            diffusers_key = k.replace("local_image_concat", "image_latents_proj_in").replace(
                "local_image_embedding", "image_latents_context_embedding"
            )
            new_checkpoint[diffusers_key] = unet_state_dict[k]

    # temporal encoder.
    new_checkpoint["image_latents_temporal_encoder.norm1.weight"] = unet_state_dict[
        "local_temporal_encoder.layers.0.0.norm.weight"
    ]
    new_checkpoint["image_latents_temporal_encoder.norm1.bias"] = unet_state_dict[
        "local_temporal_encoder.layers.0.0.norm.bias"
    ]

    # attention: the original stores a fused qkv projection; split it into the
    # separate q/k/v projections diffusers expects.
    qkv = unet_state_dict["local_temporal_encoder.layers.0.0.fn.to_qkv.weight"]
    q, k, v = torch.chunk(qkv, 3, dim=0)
    new_checkpoint["image_latents_temporal_encoder.attn1.to_q.weight"] = q
    new_checkpoint["image_latents_temporal_encoder.attn1.to_k.weight"] = k
    new_checkpoint["image_latents_temporal_encoder.attn1.to_v.weight"] = v
    new_checkpoint["image_latents_temporal_encoder.attn1.to_out.0.weight"] = unet_state_dict[
        "local_temporal_encoder.layers.0.0.fn.to_out.0.weight"
    ]
    new_checkpoint["image_latents_temporal_encoder.attn1.to_out.0.bias"] = unet_state_dict[
        "local_temporal_encoder.layers.0.0.fn.to_out.0.bias"
    ]

    # feedforward
    new_checkpoint["image_latents_temporal_encoder.ff.net.0.proj.weight"] = unet_state_dict[
        "local_temporal_encoder.layers.0.1.net.0.0.weight"
    ]
    new_checkpoint["image_latents_temporal_encoder.ff.net.0.proj.bias"] = unet_state_dict[
        "local_temporal_encoder.layers.0.1.net.0.0.bias"
    ]
    new_checkpoint["image_latents_temporal_encoder.ff.net.2.weight"] = unet_state_dict[
        "local_temporal_encoder.layers.0.1.net.2.weight"
    ]
    new_checkpoint["image_latents_temporal_encoder.ff.net.2.bias"] = unet_state_dict[
        "local_temporal_encoder.layers.0.1.net.2.bias"
    ]

    if "class_embed_type" in config:
        if config["class_embed_type"] is None:
            # No parameters to port
            ...
        elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
            new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
            new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
            new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
            new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
        else:
            raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    # The attention directly after conv_in becomes `transformer_in`.
    first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")]
    paths = renew_attention_paths(first_temp_attention)
    meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"}
    assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # Down blocks. Each original input block i (block 0 handled above) maps to
    # (down_block, layer) via layers_per_block + 1 (the +1 is the downsampler slot).
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # Sub-module 0 = resnet (excluding its `op` downsampler), 1 = spatial
        # attention, 2 = temporal attention.
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
        temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key]

        if f"input_blocks.{i}.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        # "temopral_conv" (sic) is the spelling used by the original checkpoints.
        temporal_convs = [key for key in resnets if "temopral_conv" in key]
        paths = renew_temp_conv_paths(temporal_convs)
        meta_path = {
            "old": f"input_blocks.{i}.0.temopral_conv",
            "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}",
        }
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

        if len(temp_attentions):
            paths = renew_attention_paths(temp_attentions)
            meta_path = {
                "old": f"input_blocks.{i}.2",
                "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
            }
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # Mid block: resnet / attention / temp attention / resnet.
    resnet_0 = middle_blocks[0]
    temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key]
    attentions = middle_blocks[1]
    temp_attentions = middle_blocks[2]
    resnet_1 = middle_blocks[3]
    temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"}
    assign_to_checkpoint(
        resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0)
    meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"}
    assign_to_checkpoint(
        temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    resnet_1_paths = renew_resnet_paths(resnet_1)
    meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"}
    assign_to_checkpoint(
        resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1)
    meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"}
    assign_to_checkpoint(
        temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    temp_attentions_paths = renew_attention_paths(temp_attentions)
    meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"}
    assign_to_checkpoint(
        temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    # Up blocks.
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group the shaved key names by their sub-module index within the block.
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
            temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            temporal_convs = [key for key in resnets if "temopral_conv" in key]
            paths = renew_temp_conv_paths(temporal_convs)
            meta_path = {
                "old": f"output_blocks.{i}.0.temopral_conv",
                "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}",
            }
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # An upsampler shows up as a bare conv in one of the sub-module slots.
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            if len(temp_attentions):
                paths = renew_attention_paths(temp_attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.2",
                    "new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Resnet-only output block: assign keys directly.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
                new_checkpoint[new_path] = unet_state_dict[old_path]

            temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l]
            for path in temopral_conv_paths:
                pruned_path = path.split("temopral_conv.")[-1]
                old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path])
                new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path])
                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
446
+
447
+
448
+ if __name__ == "__main__":
449
+ parser = argparse.ArgumentParser()
450
+
451
+ parser.add_argument(
452
+ "--unet_checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
453
+ )
454
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
455
+ parser.add_argument("--push_to_hub", action="store_true")
456
+ args = parser.parse_args()
457
+
458
+ # UNet
459
+ unet_checkpoint = torch.load(args.unet_checkpoint_path, map_location="cpu")
460
+ unet_checkpoint = unet_checkpoint["state_dict"]
461
+ unet = I2VGenXLUNet(sample_size=32)
462
+
463
+ converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config)
464
+
465
+ diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys())
466
+ diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys())
467
+
468
+ assert len(diff_0) == len(diff_1) == 0, "Converted weights don't match"
469
+
470
+ unet.load_state_dict(converted_ckpt, strict=True)
471
+
472
+ # vae
473
+ temp_pipe = StableDiffusionPipeline.from_single_file(
474
+ "https://huggingface.co/ali-vilab/i2vgen-xl/blob/main/models/v2-1_512-ema-pruned.ckpt"
475
+ )
476
+ vae = temp_pipe.vae
477
+ del temp_pipe
478
+
479
+ # text encoder and tokenizer
480
+ text_encoder = CLIPTextModel.from_pretrained(CLIP_ID)
481
+ tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)
482
+
483
+ # image encoder and feature extractor
484
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(CLIP_ID)
485
+ feature_extractor = CLIPImageProcessor.from_pretrained(CLIP_ID)
486
+
487
+ # scheduler
488
+ # https://github.com/ali-vilab/i2vgen-xl/blob/main/configs/i2vgen_xl_train.yaml
489
+ scheduler = DDIMScheduler(
490
+ beta_schedule="squaredcos_cap_v2",
491
+ rescale_betas_zero_snr=True,
492
+ set_alpha_to_one=True,
493
+ clip_sample=False,
494
+ steps_offset=1,
495
+ timestep_spacing="leading",
496
+ prediction_type="v_prediction",
497
+ )
498
+
499
+ # final
500
+ pipeline = I2VGenXLPipeline(
501
+ unet=unet,
502
+ vae=vae,
503
+ image_encoder=image_encoder,
504
+ feature_extractor=feature_extractor,
505
+ text_encoder=text_encoder,
506
+ tokenizer=tokenizer,
507
+ scheduler=scheduler,
508
+ )
509
+
510
+ pipeline.save_pretrained(args.dump_path, push_to_hub=args.push_to_hub)
diffusers/scripts/convert_if.py ADDED
@@ -0,0 +1,1250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import inspect
3
+ import os
4
+
5
+ import numpy as np
6
+ import torch
7
+ import yaml
8
+ from torch.nn import functional as F
9
+ from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer
10
+
11
+ from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet2DConditionModel
12
+ from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker
13
+
14
+
15
def parse_args():
    """Parse the DeepFloyd IF conversion command line and return the namespace."""
    arg_parser = argparse.ArgumentParser()

    # Output locations: each stage is only converted when both its checkpoint
    # and its dump path are supplied.
    arg_parser.add_argument("--dump_path", required=False, default=None, type=str)
    arg_parser.add_argument("--dump_path_stage_2", required=False, default=None, type=str)
    arg_parser.add_argument("--dump_path_stage_3", required=False, default=None, type=str)

    arg_parser.add_argument("--unet_config", required=False, default=None, type=str, help="Path to unet config file")
    arg_parser.add_argument(
        "--unet_checkpoint_path", required=False, default=None, type=str, help="Path to unet checkpoint file"
    )
    arg_parser.add_argument(
        "--unet_checkpoint_path_stage_2",
        required=False,
        default=None,
        type=str,
        help="Path to stage 2 unet checkpoint file",
    )
    arg_parser.add_argument(
        "--unet_checkpoint_path_stage_3",
        required=False,
        default=None,
        type=str,
        help="Path to stage 3 unet checkpoint file",
    )

    # Safety-checker head weights are always required.
    arg_parser.add_argument("--p_head_path", type=str, required=True)
    arg_parser.add_argument("--w_head_path", type=str, required=True)

    return arg_parser.parse_args()
53
+
54
+
55
def main(args):
    """Convert whichever DeepFloyd IF stages were requested on the command line."""
    # Components shared by all stages: T5 text tower, CLIP feature extractor,
    # and the converted safety checker.
    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
    text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")

    feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
    safety_checker = convert_safety_checker(p_head_path=args.p_head_path, w_head_path=args.w_head_path)

    stage_1_requested = (
        args.unet_config is not None and args.unet_checkpoint_path is not None and args.dump_path is not None
    )
    if stage_1_requested:
        convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args)

    if args.unet_checkpoint_path_stage_2 is not None and args.dump_path_stage_2 is not None:
        convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=2)

    if args.unet_checkpoint_path_stage_3 is not None and args.dump_path_stage_3 is not None:
        convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=3)
70
+
71
+
72
def convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args):
    """Build the stage-1 (base text-to-image) IF pipeline and save it to `args.dump_path`."""
    converted_unet = get_stage_1_unet(args.unet_config, args.unet_checkpoint_path)

    # Scheduler settings mirror the original IF stage-1 configuration; note the
    # higher `sample_max_value` compared to the super-resolution stages.
    stage_1_scheduler = DDPMScheduler(
        variance_type="learned_range",
        beta_schedule="squaredcos_cap_v2",
        prediction_type="epsilon",
        thresholding=True,
        dynamic_thresholding_ratio=0.95,
        sample_max_value=1.5,
    )

    IFPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=converted_unet,
        scheduler=stage_1_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=True,
    ).save_pretrained(args.dump_path)
95
+
96
+
97
def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage):
    """
    Build a super-resolution IF pipeline (stage 2 or 3) and save it to disk.

    Args:
        tokenizer / text_encoder: shared T5 text tower.
        feature_extractor / safety_checker: shared CLIP-based safety components.
        args: parsed CLI namespace holding the per-stage checkpoint and dump paths.
        stage: which super-resolution stage to convert; must be 2 or 3.

    Raises:
        ValueError: if `stage` is neither 2 nor 3.
    """
    if stage == 2:
        unet_checkpoint_path = args.unet_checkpoint_path_stage_2
        sample_size = None
        dump_path = args.dump_path_stage_2
    elif stage == 3:
        unet_checkpoint_path = args.unet_checkpoint_path_stage_3
        sample_size = 1024
        dump_path = args.dump_path_stage_3
    else:
        # Explicit error instead of `assert False`: asserts are stripped under
        # `python -O`, which would let a bad stage fall through silently.
        raise ValueError(f"Unsupported super-resolution stage: {stage!r}. Expected 2 or 3.")

    unet = get_super_res_unet(unet_checkpoint_path, verify_param_count=False, sample_size=sample_size)

    # Scheduler used to noise the low-resolution conditioning image.
    image_noising_scheduler = DDPMScheduler(
        beta_schedule="squaredcos_cap_v2",
    )

    # Main denoising scheduler (sample_max_value differs from stage 1's 1.5).
    scheduler = DDPMScheduler(
        variance_type="learned_range",
        beta_schedule="squaredcos_cap_v2",
        prediction_type="epsilon",
        thresholding=True,
        dynamic_thresholding_ratio=0.95,
        sample_max_value=1.0,
    )

    pipe = IFSuperResolutionPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=True,
    )

    pipe.save_pretrained(dump_path)
136
+
137
+
138
def get_stage_1_unet(unet_config, unet_checkpoint_path):
    """
    Load and convert the stage-1 IF UNet.

    Args:
        unet_config: Path to the original (LDM-style) YAML config file.
        unet_checkpoint_path: Path to the original UNet state-dict checkpoint.

    Returns:
        A `UNet2DConditionModel` with the converted weights loaded.
    """
    # `unet_config` is a file path (see the `--unet_config` help text), so the
    # file must be opened before parsing: `yaml.safe_load` applied to the raw
    # path string would just return the path itself, and the subsequent
    # `["params"]` lookup would fail with a TypeError.
    with open(unet_config) as config_file:
        original_unet_config = yaml.safe_load(config_file)
    original_unet_config = original_unet_config["params"]

    unet_diffusers_config = create_unet_diffusers_config(original_unet_config)

    unet = UNet2DConditionModel(**unet_diffusers_config)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    unet_checkpoint = torch.load(unet_checkpoint_path, map_location=device)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )

    unet.load_state_dict(converted_unet_checkpoint)

    return unet
156
+
157
+
158
def _load_head_params(head_path):
    """Load one safety-head numpy archive; return (weight, bias) tensors with a leading output dim."""
    head = np.load(head_path)
    # `unsqueeze(0)` adds the (single) output dimension expected by a linear head.
    weights = torch.from_numpy(head["weights"]).unsqueeze(0)
    biases = torch.from_numpy(head["biases"]).unsqueeze(0)
    return weights, biases


def convert_safety_checker(p_head_path, w_head_path):
    """
    Assemble an `IFSafetyChecker` from the original IF safety-head weights.

    Args:
        p_head_path: Path to the "p" head numpy archive containing `weights` and `biases`.
        w_head_path: Path to the "w" head numpy archive with the same layout.

    Returns:
        An `IFSafetyChecker` whose vision backbone is the pretrained
        `openai/clip-vit-large-patch14` tower and whose two linear heads carry
        the converted weights.
    """
    state_dict = {}

    # Both heads share the same archive layout; load them with a common helper.
    state_dict["p_head.weight"], state_dict["p_head.bias"] = _load_head_params(p_head_path)
    state_dict["w_head.weight"], state_dict["w_head.bias"] = _load_head_params(w_head_path)

    # vision model: reuse the pretrained CLIP vision tower, namespaced under
    # the checker's `vision_model.` prefix.
    vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    for key, value in vision_model.state_dict().items():
        state_dict[f"vision_model.{key}"] = value

    # full model
    config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
    safety_checker = IFSafetyChecker(config)

    safety_checker.load_state_dict(state_dict)

    return safety_checker
208
+
209
+
210
def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
    """
    Map an original IF (guided-diffusion style) UNet config dict onto diffusers
    `UNet2DConditionModel` constructor kwargs.

    `original_unet_config` is the `params` section of the original YAML config;
    `class_embed_type` overrides auto-detection from `num_classes` when given.
    Returns the kwargs dict.
    """
    # `attention_resolutions` are given as downsample factors; convert to the
    # absolute feature-map resolutions at which attention is applied.
    attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
    attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]

    channel_mult = parse_list(original_unet_config["channel_mult"])
    block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]

    down_block_types = []
    resolution = 1

    # NOTE(review): `resolution` here tracks the cumulative downsample factor
    # (1, 2, 4, ...), and is compared against the converted absolute
    # resolutions above -- the two quantities appear to use different units;
    # confirm against the original IF configs that this matches intent.
    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnDownBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetDownsampleBlock2D"
        else:
            block_type = "DownBlock2D"

        down_block_types.append(block_type)

        # No downsampling after the last block.
        if i != len(block_out_channels) - 1:
            resolution *= 2

    # Up path mirrors the down path, walking `resolution` back down.
    up_block_types = []
    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnUpBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetUpsampleBlock2D"
        else:
            block_type = "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    head_dim = original_unet_config["num_head_channels"]

    use_linear_projection = (
        original_unet_config["use_linear_in_transformer"]
        if "use_linear_in_transformer" in original_unet_config
        else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    projection_class_embeddings_input_dim = None

    if class_embed_type is None:
        if "num_classes" in original_unet_config:
            if original_unet_config["num_classes"] == "sequential":
                class_embed_type = "projection"
                assert "adm_in_channels" in original_unet_config
                projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
            else:
                raise NotImplementedError(
                    f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
                )

    config = {
        "sample_size": original_unet_config["image_size"],
        "in_channels": original_unet_config["in_channels"],
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": original_unet_config["num_res_blocks"],
        "cross_attention_dim": original_unet_config["encoder_channels"],
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "out_channels": original_unet_config["out_channels"],
        "up_block_types": tuple(up_block_types),
        "upcast_attention": False,  # TODO: guessing
        "cross_attention_norm": "group_norm",
        "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
        "addition_embed_type": "text",
        "act_fn": "gelu",
    }

    if original_unet_config["use_scale_shift_norm"]:
        config["resnet_time_scale_shift"] = "scale_shift"

    if "encoder_dim" in original_unet_config:
        config["encoder_hid_dim"] = original_unet_config["encoder_dim"]

    return config
296
+
297
+
298
def convert_ldm_unet_checkpoint(unet_state_dict, config, path=None):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Maps the original LDM/IF UNet parameter layout (``time_embed``,
    ``input_blocks`` / ``middle_block`` / ``output_blocks``, ``out``) onto the
    diffusers ``UNet2DConditionModel`` naming scheme. Fused qkv / encoder_kv
    attention projections are split into separate q/k/v linear layers by
    ``assign_attention_to_checkpoint``, which pops the fused keys out of
    ``unet_state_dict`` as it goes (the input dict is mutated).

    Args:
        unet_state_dict: original checkpoint state dict (mutated: some keys
            are popped while converting).
        config: diffusers UNet config dict from the matching
            ``create_unet_diffusers_config`` helper; ``layers_per_block``,
            ``down_block_types`` and ``class_embed_type`` drive the mapping.
        path: unused; kept for signature parity with the other converters.

    Returns:
        dict: converted state dict in diffusers naming, ready to load.
    """
    new_checkpoint = {}

    # Timestep embedding MLP: time_embed.{0,2} -> time_embedding.linear_{1,2}.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] in [None, "identity"]:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        # Class/aug embedding MLP: label_emb.0.{0,2} -> class_embedding.linear_{1,2}.
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    # input_blocks.0.0 is the stem conv; out.{0,2} are the final norm + conv.
    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    # (block index is the second dot-segment of each parameter name).
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }

    # Down path. Block 0 was consumed as conv_in above; each remaining input
    # block maps to (down block, layer-within-block) via layers_per_block.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # "<i>.0" entries are the resnet; "<i>.0.op" is a conv downsampler.
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)

        # TODO need better check than i in [4, 8, 12, 16]
        # (those indices are the resnet-style downsamplers for this
        # architecture's fixed layers_per_block).
        block_type = config["down_block_types"][block_id]
        if (block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D") and i in [
            4,
            8,
            12,
            16,
        ]:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
        else:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}

        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            old_path = f"input_blocks.{i}.1"
            new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"

            # Splits and pops the fused qkv / encoder_kv projections first so
            # the renaming pass below only sees the remaining plain keys.
            assign_attention_to_checkpoint(
                new_checkpoint=new_checkpoint,
                unet_state_dict=unet_state_dict,
                old_path=old_path,
                new_path=new_path,
                config=config,
            )

            paths = renew_attention_paths(attentions)
            meta_path = {"old": old_path, "new": new_path}
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )

    # Mid block: fixed layout resnet / attention / resnet.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    old_path = "middle_block.1"
    new_path = "mid_block.attentions.0"

    assign_attention_to_checkpoint(
        new_checkpoint=new_checkpoint,
        unet_state_dict=unet_state_dict,
        old_path=old_path,
        new_path=new_path,
        config=config,
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    # Up path.
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group this output block's parameters by their sub-layer index.
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        # len(output_block_list) == 1 -> resnet
        # len(output_block_list) == 2 -> resnet, attention
        # len(output_block_list) == 3 -> resnet, attention, upscale resnet

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}

            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # A sub-layer consisting of exactly {conv.bias, conv.weight} is a
            # conv upsampler rather than an attention.
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                old_path = f"output_blocks.{i}.1"
                new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"

                assign_attention_to_checkpoint(
                    new_checkpoint=new_checkpoint,
                    unet_state_dict=unet_state_dict,
                    old_path=old_path,
                    new_path=new_path,
                    config=config,
                )

                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": old_path,
                    "new": new_path,
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            if len(output_block_list) == 3:
                # Third sub-layer is a resnet-style upsampler.
                resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
                paths = renew_resnet_paths(resnets)
                meta_path = {"old": f"output_blocks.{i}.2", "new": f"up_blocks.{block_id}.upsamplers.0"}
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Plain resnet-only output block: rename each parameter directly.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    # Optional text-encoder projection head.
    if "encoder_proj.weight" in unet_state_dict:
        new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict.pop("encoder_proj.weight")
        new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict.pop("encoder_proj.bias")

    # Optional attention-pooled text embedding head:
    # encoder_pooling.{0..3} -> add_embedding.{norm1, pool, proj, norm2}.
    if "encoder_pooling.0.weight" in unet_state_dict:
        new_checkpoint["add_embedding.norm1.weight"] = unet_state_dict.pop("encoder_pooling.0.weight")
        new_checkpoint["add_embedding.norm1.bias"] = unet_state_dict.pop("encoder_pooling.0.bias")

        new_checkpoint["add_embedding.pool.positional_embedding"] = unet_state_dict.pop(
            "encoder_pooling.1.positional_embedding"
        )
        new_checkpoint["add_embedding.pool.k_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.k_proj.weight")
        new_checkpoint["add_embedding.pool.k_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.k_proj.bias")
        new_checkpoint["add_embedding.pool.q_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.q_proj.weight")
        new_checkpoint["add_embedding.pool.q_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.q_proj.bias")
        new_checkpoint["add_embedding.pool.v_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.v_proj.weight")
        new_checkpoint["add_embedding.pool.v_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.v_proj.bias")

        new_checkpoint["add_embedding.proj.weight"] = unet_state_dict.pop("encoder_pooling.2.weight")
        new_checkpoint["add_embedding.proj.bias"] = unet_state_dict.pop("encoder_pooling.2.bias")

        new_checkpoint["add_embedding.norm2.weight"] = unet_state_dict.pop("encoder_pooling.3.weight")
        new_checkpoint["add_embedding.norm2.bias"] = unet_state_dict.pop("encoder_pooling.3.bias")

    return new_checkpoint
537
+
538
+
539
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drop dot-separated segments from *path*.

    A non-negative count removes that many leading segments; a negative
    count removes ``abs(count)`` trailing segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
547
+
548
+
549
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Translate resnet parameter names from the original LDM layout to the
    diffusers layout (local renaming only).

    Returns a list of ``{"old": ..., "new": ...}`` mapping dicts.
    """
    # Applied in order; kept identical to the historical replace chain.
    replacements = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for old_item in old_list:
        new_item = old_item
        for src, dst in replacements:
            new_item = new_item.replace(src, dst)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})

    return mapping
569
+
570
+
571
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Translate attention parameter names from the original LDM layout to the
    diffusers layout (local renaming only).

    Fused ``qkv`` / ``encoder_kv`` entries are intentionally skipped here:
    they are split and assigned separately by the attention converter.
    """
    replacements = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
        ("norm_encoder.weight", "norm_cross.weight"),
        ("norm_encoder.bias", "norm_cross.bias"),
    )

    mapping = []
    for old_item in old_list:
        if "qkv" in old_item or "encoder_kv" in old_item:
            continue

        new_item = old_item
        for src, dst in replacements:
            new_item = new_item.replace(src, dst)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})

    return mapping
599
+
600
+
601
def assign_attention_to_checkpoint(new_checkpoint, unet_state_dict, old_path, new_path, config):
    """
    Split the fused attention projections at ``old_path`` into separate
    diffusers q/k/v (and encoder k/v) linear layers under ``new_path``.

    Mutates both dicts: the fused tensors are popped from ``unet_state_dict``
    (so later renaming passes never see them) and the split tensors are
    written into ``new_checkpoint``.
    """
    # Fused qkv projection is stored as a conv-1D kernel; drop the trailing
    # kernel dim to get a linear weight.
    qkv_weight = unet_state_dict.pop(f"{old_path}.qkv.weight")
    qkv_weight = qkv_weight[:, :, 0]

    qkv_bias = unet_state_dict.pop(f"{old_path}.qkv.bias")

    is_cross_attn_only = "only_cross_attention" in config and config["only_cross_attention"]

    # Cross-attention-only blocks fuse only the query; otherwise q, k and v
    # are interleaved per attention head.
    split = 1 if is_cross_attn_only else 3

    weights, bias = split_attentions(
        weight=qkv_weight,
        bias=qkv_bias,
        split=split,
        chunk_size=config["attention_head_dim"],
    )

    if is_cross_attn_only:
        # split_attentions returned single-element lists; unwrap them.
        query_weight, q_bias = weights, bias
        new_checkpoint[f"{new_path}.to_q.weight"] = query_weight[0]
        new_checkpoint[f"{new_path}.to_q.bias"] = q_bias[0]
    else:
        [query_weight, key_weight, value_weight], [q_bias, k_bias, v_bias] = weights, bias
        new_checkpoint[f"{new_path}.to_q.weight"] = query_weight
        new_checkpoint[f"{new_path}.to_q.bias"] = q_bias
        new_checkpoint[f"{new_path}.to_k.weight"] = key_weight
        new_checkpoint[f"{new_path}.to_k.bias"] = k_bias
        new_checkpoint[f"{new_path}.to_v.weight"] = value_weight
        new_checkpoint[f"{new_path}.to_v.bias"] = v_bias

    # Fused k/v projection over the encoder hidden states, same conv-1D
    # storage as qkv above.
    encoder_kv_weight = unet_state_dict.pop(f"{old_path}.encoder_kv.weight")
    encoder_kv_weight = encoder_kv_weight[:, :, 0]

    encoder_kv_bias = unet_state_dict.pop(f"{old_path}.encoder_kv.bias")

    [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
        weight=encoder_kv_weight,
        bias=encoder_kv_bias,
        split=2,
        chunk_size=config["attention_head_dim"],
    )

    new_checkpoint[f"{new_path}.add_k_proj.weight"] = encoder_k_weight
    new_checkpoint[f"{new_path}.add_k_proj.bias"] = encoder_k_bias
    new_checkpoint[f"{new_path}.add_v_proj.weight"] = encoder_v_weight
    new_checkpoint[f"{new_path}.add_v_proj.bias"] = encoder_v_bias
647
+
648
+
649
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, additional_replacements=None, config=None):
    """
    Final conversion step: apply the global mid-block renames (and any
    caller-supplied replacements) to each locally-converted path, then copy
    the matching weight from *old_checkpoint* into *checkpoint*.

    ``proj_attn.weight`` / ``to_out.0.weight`` tensors are sliced from conv-1D
    form down to linear form on the way through.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    global_renames = (
        ("middle_block.0", "mid_block.resnets.0"),
        ("middle_block.1", "mid_block.attentions.0"),
        ("middle_block.2", "mid_block.resnets.1"),
    )

    for path in paths:
        new_path = path["new"]

        # Global renaming happens here.
        for src, dst in global_renames:
            new_path = new_path.replace(src, dst)

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        value = old_checkpoint[path["old"]]
        if "proj_attn.weight" in new_path or "to_out.0.weight" in new_path:
            # proj_attn.weight has to be converted from conv 1D to linear.
            value = value[:, :, 0]
        checkpoint[new_path] = value
675
+
676
+
677
+ # TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
678
def split_attentions(*, weight, bias, split, chunk_size):
    """
    Round-robin split of a fused attention projection into *split* parts.

    Rows are walked in chunks of *chunk_size*; chunk ``j`` is appended to
    part ``j % split`` (q/k/v rows are interleaved per head in the fused
    tensor). Returns ``(weights, biases)``, each a list of length *split*.
    """
    weights = [None] * split
    biases = [None] * split

    target = 0
    for start in range(0, weight.shape[0], chunk_size):
        rows = torch.arange(start, start + chunk_size)

        weight_chunk = weight[rows, :]
        bias_chunk = bias[rows]

        if weights[target] is None:
            weights[target] = weight_chunk
            biases[target] = bias_chunk
        else:
            weights[target] = torch.concat([weights[target], weight_chunk])
            biases[target] = torch.concat([biases[target], bias_chunk])

        target = (target + 1) % split

    return weights, biases
701
+
702
+
703
def parse_list(value):
    """
    Normalize a config value to a list.

    A comma-separated string is parsed into a list of ints; a list is
    returned unchanged; any other type raises ``ValueError``.
    """
    if isinstance(value, str):
        return [int(part) for part in value.split(",")]
    if isinstance(value, list):
        return value
    raise ValueError(f"Can't parse list for type: {type(value)}")
713
+
714
+
715
+ # below is copy and pasted from original convert_if_stage_2.py script
716
+
717
+
718
def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None):
    """
    Load an IF super-resolution UNet checkpoint and return it converted to a
    diffusers ``UNet2DConditionModel``.

    Args:
        unet_checkpoint_path: directory containing ``config.yml`` and
            ``pytorch_model.bin``.
        verify_param_count: when True, check that the converted
            architecture's parameter count matches the original (slow).
        sample_size: optional override; the second upscaler unet's sample
            size is wrong in its config and must be passed explicitly.

    Returns:
        UNet2DConditionModel with the converted weights loaded.
    """
    orig_path = unet_checkpoint_path

    # Bug fix: yaml.safe_load must be given the file (stream), not the path
    # string — passing the path parsed the path itself as a YAML scalar.
    with open(os.path.join(orig_path, "config.yml")) as f:
        original_unet_config = yaml.safe_load(f)
    original_unet_config = original_unet_config["params"]

    unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)
    unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * int(
        original_unet_config["channel_mult"].split(",")[-1]
    )
    if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]:
        unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
        unet_diffusers_config["class_embed_type"] = "timestep"
        unet_diffusers_config["addition_embed_type"] = "text"

    unet_diffusers_config["time_embedding_act_fn"] = "gelu"
    unet_diffusers_config["resnet_skip_time_act"] = True
    unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
    unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071
    unet_diffusers_config["only_cross_attention"] = (
        bool(original_unet_config["disable_self_attentions"])
        if (
            "disable_self_attentions" in original_unet_config
            and isinstance(original_unet_config["disable_self_attentions"], int)
        )
        else True
    )

    if sample_size is None:
        unet_diffusers_config["sample_size"] = original_unet_config["image_size"]
    else:
        # The second upscaler unet's sample size is incorrectly specified
        # in the config and is instead hardcoded in source
        unet_diffusers_config["sample_size"] = sample_size

    unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu")

    if verify_param_count:
        # Bug fix: the boolean parameter shadows the module-level helper of
        # the same name, so a direct call invoked the bool and raised
        # TypeError. Look the helper up in the module namespace instead.
        # check that architecture matches - is a bit slow
        globals()["verify_param_count"](orig_path, unet_diffusers_config)

    converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )
    converted_keys = converted_unet_checkpoint.keys()

    model = UNet2DConditionModel(**unet_diffusers_config)
    expected_weights = model.state_dict().keys()

    diff_c_e = set(converted_keys) - set(expected_weights)
    diff_e_c = set(expected_weights) - set(converted_keys)

    assert len(diff_e_c) == 0, f"Expected, but not converted: {diff_e_c}"
    assert len(diff_c_e) == 0, f"Converted, but not expected: {diff_c_e}"

    model.load_state_dict(converted_unet_checkpoint)

    return model
776
+
777
+
778
def superres_create_unet_diffusers_config(original_unet_config):
    """
    Map the original IF super-resolution UNet config dict onto diffusers
    ``UNet2DConditionModel`` constructor kwargs.
    """
    # Resolutions at which attention is applied, expressed as feature-map
    # sizes relative to the input image size.
    attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
    attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]

    channel_mult = parse_list(original_unet_config["channel_mult"])
    block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]

    down_block_types = []
    resolution = 1

    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnDownBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetDownsampleBlock2D"
        else:
            block_type = "DownBlock2D"

        down_block_types.append(block_type)

        # The last down block does not downsample.
        if i != len(block_out_channels) - 1:
            resolution *= 2

    # The up path mirrors the down path, walking the resolutions back down.
    up_block_types = []
    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnUpBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetUpsampleBlock2D"
        else:
            block_type = "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    head_dim = original_unet_config["num_head_channels"]
    use_linear_projection = (
        original_unet_config["use_linear_in_transformer"]
        if "use_linear_in_transformer" in original_unet_config
        else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    if "num_classes" in original_unet_config:
        if original_unet_config["num_classes"] == "sequential":
            class_embed_type = "projection"
            assert "adm_in_channels" in original_unet_config
            projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
        else:
            raise NotImplementedError(
                f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
            )

    config = {
        "in_channels": original_unet_config["in_channels"],
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        # Fix: num_res_blocks may arrive as a list or as a comma-separated
        # string; normalize through parse_list like the other list-valued
        # fields above (a bare tuple() on a string would yield its characters).
        "layers_per_block": tuple(parse_list(original_unet_config["num_res_blocks"])),
        "cross_attention_dim": original_unet_config["encoder_channels"],
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "out_channels": original_unet_config["out_channels"],
        "up_block_types": tuple(up_block_types),
        "upcast_attention": False,  # TODO: guessing
        "cross_attention_norm": "group_norm",
        "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
        "act_fn": "gelu",
    }

    if original_unet_config["use_scale_shift_norm"]:
        config["resnet_time_scale_shift"] = "scale_shift"

    return config
858
+
859
+
860
+ def superres_convert_ldm_unet_checkpoint(unet_state_dict, config, path=None, extract_ema=False):
861
+ """
862
+ Takes a state dict and a config, and returns a converted checkpoint.
863
+ """
864
+ new_checkpoint = {}
865
+
866
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
867
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
868
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
869
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
870
+
871
+ if config["class_embed_type"] is None:
872
+ # No parameters to port
873
+ ...
874
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
875
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["aug_proj.0.weight"]
876
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["aug_proj.0.bias"]
877
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["aug_proj.2.weight"]
878
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["aug_proj.2.bias"]
879
+ else:
880
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
881
+
882
+ if "encoder_proj.weight" in unet_state_dict:
883
+ new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict["encoder_proj.weight"]
884
+ new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict["encoder_proj.bias"]
885
+
886
+ if "encoder_pooling.0.weight" in unet_state_dict:
887
+ mapping = {
888
+ "encoder_pooling.0": "add_embedding.norm1",
889
+ "encoder_pooling.1": "add_embedding.pool",
890
+ "encoder_pooling.2": "add_embedding.proj",
891
+ "encoder_pooling.3": "add_embedding.norm2",
892
+ }
893
+ for key in unet_state_dict.keys():
894
+ if key.startswith("encoder_pooling"):
895
+ prefix = key[: len("encoder_pooling.0")]
896
+ new_key = key.replace(prefix, mapping[prefix])
897
+ new_checkpoint[new_key] = unet_state_dict[key]
898
+
899
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
900
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
901
+
902
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
903
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
904
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
905
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
906
+
907
+ # Retrieves the keys for the input blocks only
908
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
909
+ input_blocks = {
910
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
911
+ for layer_id in range(num_input_blocks)
912
+ }
913
+
914
+ # Retrieves the keys for the middle blocks only
915
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
916
+ middle_blocks = {
917
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
918
+ for layer_id in range(num_middle_blocks)
919
+ }
920
+
921
+ # Retrieves the keys for the output blocks only
922
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
923
+ output_blocks = {
924
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
925
+ for layer_id in range(num_output_blocks)
926
+ }
927
+ if not isinstance(config["layers_per_block"], int):
928
+ layers_per_block_list = [e + 1 for e in config["layers_per_block"]]
929
+ layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
930
+ downsampler_ids = layers_per_block_cumsum
931
+ else:
932
+ # TODO need better check than i in [4, 8, 12, 16]
933
+ downsampler_ids = [4, 8, 12, 16]
934
+
935
+ for i in range(1, num_input_blocks):
936
+ if isinstance(config["layers_per_block"], int):
937
+ layers_per_block = config["layers_per_block"]
938
+ block_id = (i - 1) // (layers_per_block + 1)
939
+ layer_in_block_id = (i - 1) % (layers_per_block + 1)
940
+ else:
941
+ block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if (i - 1) < n)
942
+ passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
943
+ layer_in_block_id = (i - 1) - passed_blocks
944
+
945
+ resnets = [
946
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
947
+ ]
948
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
949
+
950
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
951
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
952
+ f"input_blocks.{i}.0.op.weight"
953
+ )
954
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
955
+ f"input_blocks.{i}.0.op.bias"
956
+ )
957
+
958
+ paths = renew_resnet_paths(resnets)
959
+
960
+ block_type = config["down_block_types"][block_id]
961
+ if (
962
+ block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D"
963
+ ) and i in downsampler_ids:
964
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
965
+ else:
966
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
967
+
968
+ assign_to_checkpoint(
969
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
970
+ )
971
+
972
+ if len(attentions):
973
+ old_path = f"input_blocks.{i}.1"
974
+ new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
975
+
976
+ assign_attention_to_checkpoint(
977
+ new_checkpoint=new_checkpoint,
978
+ unet_state_dict=unet_state_dict,
979
+ old_path=old_path,
980
+ new_path=new_path,
981
+ config=config,
982
+ )
983
+
984
+ paths = renew_attention_paths(attentions)
985
+ meta_path = {"old": old_path, "new": new_path}
986
+ assign_to_checkpoint(
987
+ paths,
988
+ new_checkpoint,
989
+ unet_state_dict,
990
+ additional_replacements=[meta_path],
991
+ config=config,
992
+ )
993
+
994
+ resnet_0 = middle_blocks[0]
995
+ attentions = middle_blocks[1]
996
+ resnet_1 = middle_blocks[2]
997
+
998
+ resnet_0_paths = renew_resnet_paths(resnet_0)
999
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
1000
+
1001
+ resnet_1_paths = renew_resnet_paths(resnet_1)
1002
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
1003
+
1004
+ old_path = "middle_block.1"
1005
+ new_path = "mid_block.attentions.0"
1006
+
1007
+ assign_attention_to_checkpoint(
1008
+ new_checkpoint=new_checkpoint,
1009
+ unet_state_dict=unet_state_dict,
1010
+ old_path=old_path,
1011
+ new_path=new_path,
1012
+ config=config,
1013
+ )
1014
+
1015
+ attentions_paths = renew_attention_paths(attentions)
1016
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
1017
+ assign_to_checkpoint(
1018
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
1019
+ )
1020
+ if not isinstance(config["layers_per_block"], int):
1021
+ layers_per_block_list = list(reversed([e + 1 for e in config["layers_per_block"]]))
1022
+ layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
1023
+
1024
+ for i in range(num_output_blocks):
1025
+ if isinstance(config["layers_per_block"], int):
1026
+ layers_per_block = config["layers_per_block"]
1027
+ block_id = i // (layers_per_block + 1)
1028
+ layer_in_block_id = i % (layers_per_block + 1)
1029
+ else:
1030
+ block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if i < n)
1031
+ passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
1032
+ layer_in_block_id = i - passed_blocks
1033
+
1034
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
1035
+ output_block_list = {}
1036
+
1037
+ for layer in output_block_layers:
1038
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
1039
+ if layer_id in output_block_list:
1040
+ output_block_list[layer_id].append(layer_name)
1041
+ else:
1042
+ output_block_list[layer_id] = [layer_name]
1043
+
1044
+ # len(output_block_list) == 1 -> resnet
1045
+ # len(output_block_list) == 2 -> resnet, attention or resnet, upscale resnet
1046
+ # len(output_block_list) == 3 -> resnet, attention, upscale resnet
1047
+
1048
+ if len(output_block_list) > 1:
1049
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
1050
+
1051
+ has_attention = True
1052
+ if len(output_block_list) == 2 and any("in_layers" in k for k in output_block_list["1"]):
1053
+ has_attention = False
1054
+
1055
+ maybe_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
1056
+
1057
+ paths = renew_resnet_paths(resnets)
1058
+
1059
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
1060
+
1061
+ assign_to_checkpoint(
1062
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
1063
+ )
1064
+
1065
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
1066
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
1067
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
1068
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
1069
+ f"output_blocks.{i}.{index}.conv.weight"
1070
+ ]
1071
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
1072
+ f"output_blocks.{i}.{index}.conv.bias"
1073
+ ]
1074
+
1075
+ # this layer was no attention
1076
+ has_attention = False
1077
+ maybe_attentions = []
1078
+
1079
+ if has_attention:
1080
+ old_path = f"output_blocks.{i}.1"
1081
+ new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"
1082
+
1083
+ assign_attention_to_checkpoint(
1084
+ new_checkpoint=new_checkpoint,
1085
+ unet_state_dict=unet_state_dict,
1086
+ old_path=old_path,
1087
+ new_path=new_path,
1088
+ config=config,
1089
+ )
1090
+
1091
+ paths = renew_attention_paths(maybe_attentions)
1092
+ meta_path = {
1093
+ "old": old_path,
1094
+ "new": new_path,
1095
+ }
1096
+ assign_to_checkpoint(
1097
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
1098
+ )
1099
+
1100
+ if len(output_block_list) == 3 or (not has_attention and len(maybe_attentions) > 0):
1101
+ layer_id = len(output_block_list) - 1
1102
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.{layer_id}" in key]
1103
+ paths = renew_resnet_paths(resnets)
1104
+ meta_path = {"old": f"output_blocks.{i}.{layer_id}", "new": f"up_blocks.{block_id}.upsamplers.0"}
1105
+ assign_to_checkpoint(
1106
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
1107
+ )
1108
+ else:
1109
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
1110
+ for path in resnet_0_paths:
1111
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
1112
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
1113
+
1114
+ new_checkpoint[new_path] = unet_state_dict[old_path]
1115
+
1116
+ return new_checkpoint
1117
+
1118
+
1119
def verify_param_count(orig_path, unet_diffusers_config):
    """Sanity-check a converted DeepFloyd IF UNet against the original model.

    Instantiates the original IF stage (II or III, inferred from ``orig_path``)
    and a fresh diffusers ``UNet2DConditionModel`` from ``unet_diffusers_config``,
    then compares parameter counts block by block and for the whole model.

    Args:
        orig_path: Path/name of the original checkpoint; must contain ``-II-``
            or ``-III-`` so the right stage class can be selected.
        unet_diffusers_config: Keyword config for ``UNet2DConditionModel``.

    Raises:
        ValueError: If ``orig_path`` contains neither ``-II-`` nor ``-III-``.
        AssertionError: If any block's parameter count differs.
    """
    if "-II-" in orig_path:
        from deepfloyd_if.modules import IFStageII

        if_II = IFStageII(device="cpu", dir_or_name=orig_path)
    elif "-III-" in orig_path:
        from deepfloyd_if.modules import IFStageIII

        if_II = IFStageIII(device="cpu", dir_or_name=orig_path)
    else:
        # Bug fix: the original code was `assert f"Weird name. ..."`, which
        # asserts a non-empty string — always truthy, so it never raised and
        # `if_II` would be unbound below. Raise explicitly instead.
        raise ValueError(f"Weird name. Should have -II- or -III- in path: {orig_path}")

    unet = UNet2DConditionModel(**unet_diffusers_config)

    # in params
    assert_param_count(unet.time_embedding, if_II.model.time_embed)
    assert_param_count(unet.conv_in, if_II.model.input_blocks[:1])

    # downblocks — the input-block slice boundaries mirror how the original
    # flat `input_blocks` list maps onto diffusers down_blocks.
    assert_param_count(unet.down_blocks[0], if_II.model.input_blocks[1:4])
    assert_param_count(unet.down_blocks[1], if_II.model.input_blocks[4:7])
    assert_param_count(unet.down_blocks[2], if_II.model.input_blocks[7:11])

    if "-II-" in orig_path:
        assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:17])
        assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[17:])
    if "-III-" in orig_path:
        assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:15])
        assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[15:20])
        assert_param_count(unet.down_blocks[5], if_II.model.input_blocks[20:])

    # mid block
    assert_param_count(unet.mid_block, if_II.model.middle_block)

    # up block
    if "-II-" in orig_path:
        assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:6])
        assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[6:12])
        assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[12:16])
        assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[16:19])
        assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[19:])
    if "-III-" in orig_path:
        assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:5])
        assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[5:10])
        assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[10:14])
        assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[14:18])
        assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[18:21])
        assert_param_count(unet.up_blocks[5], if_II.model.output_blocks[21:24])

    # out params
    assert_param_count(unet.conv_norm_out, if_II.model.out[0])
    assert_param_count(unet.conv_out, if_II.model.out[2])

    # make sure all model architecture has same param count
    assert_param_count(unet, if_II.model)
1174
+
1175
+
1176
def assert_param_count(model_1, model_2):
    """Assert that two torch modules hold the same total number of parameters."""

    def _num_params(module):
        # Total element count over every parameter tensor.
        return sum(p.numel() for p in module.parameters())

    n_params_1 = _num_params(model_1)
    n_params_2 = _num_params(model_2)
    assert n_params_1 == n_params_2, (
        f"{model_1.__class__}: {n_params_1} != {model_2.__class__}: {n_params_2}"
    )
1180
+
1181
+
1182
def superres_check_against_original(dump_path, unet_checkpoint_path):
    """Numerically compare a converted super-resolution UNet with the original.

    Loads the converted diffusers model from ``dump_path`` and the original
    DeepFloyd IF stage model from ``unet_checkpoint_path``, runs both on the
    same random inputs, and prints the summed absolute difference of the
    outputs. Requires a CUDA device. Debug/inspection helper — prints, does
    not return a value.
    """
    model_path = dump_path
    model = UNet2DConditionModel.from_pretrained(model_path)
    model.to("cuda")
    orig_path = unet_checkpoint_path

    # Pick the original stage class from the checkpoint name.
    # NOTE(review): if the path contains neither "-II-" nor "-III-",
    # `if_II_model` is never bound and the code below raises NameError.
    if "-II-" in orig_path:
        from deepfloyd_if.modules import IFStageII

        if_II_model = IFStageII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
    elif "-III-" in orig_path:
        from deepfloyd_if.modules import IFStageIII

        if_II_model = IFStageIII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model

    batch_size = 1
    # Half the input channels: the other half is the upscaled conditioning image
    # concatenated below.
    channels = model.config.in_channels // 2
    height = model.config.sample_size
    width = model.config.sample_size
    # NOTE(review): the config-derived size is immediately overwritten with a
    # hard-coded 1024x1024 — presumably intentional for this debug run, but
    # confirm before relying on it.
    height = 1024
    width = 1024

    # Fixed seed so both models see identical random inputs.
    torch.manual_seed(0)

    latents = torch.randn((batch_size, channels, height, width), device=model.device)
    image_small = torch.randn((batch_size, channels, height // 4, width // 4), device=model.device)

    # Older torch versions lack the `antialias` kwarg on F.interpolate; pass it
    # only when supported.
    interpolate_antialias = {}
    if "antialias" in inspect.signature(F.interpolate).parameters:
        interpolate_antialias["antialias"] = True
    image_upscaled = F.interpolate(
        image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
    )

    # Noisy latents + upscaled low-res image concatenated along channels.
    latent_model_input = torch.cat([latents, image_upscaled], dim=1).to(model.dtype)
    t = torch.tensor([5], device=model.device).to(model.dtype)

    seq_len = 64
    encoder_hidden_states = torch.randn((batch_size, seq_len, model.config.encoder_hid_dim), device=model.device).to(
        model.dtype
    )

    # Reused as both the original model's aug_steps and the diffusers
    # class_labels conditioning.
    fake_class_labels = torch.tensor([t], device=model.device).to(model.dtype)

    with torch.no_grad():
        out = if_II_model(latent_model_input, t, aug_steps=fake_class_labels, text_emb=encoder_hidden_states)

    # Free GPU memory held by the original model before running the converted one.
    if_II_model.to("cpu")
    del if_II_model
    import gc

    torch.cuda.empty_cache()
    gc.collect()
    print(50 * "=")

    with torch.no_grad():
        noise_pred = model(
            sample=latent_model_input,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=fake_class_labels,
            timestep=t,
        ).sample

    print("Out shape", noise_pred.shape)
    print("Diff", (out - noise_pred).abs().sum())
1247
+
1248
+
1249
+ if __name__ == "__main__":
1250
+ main(parse_args())
diffusers/scripts/convert_lora_safetensor_to_diffusers.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024, Haofan Wang, Qixun Wang, All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Conversion script for the LoRA's safetensors checkpoints."""
17
+
18
+ import argparse
19
+
20
+ import torch
21
+ from safetensors.torch import load_file
22
+
23
+ from diffusers import StableDiffusionPipeline
24
+
25
+
26
def convert(base_model_path, checkpoint_path, LORA_PREFIX_UNET, LORA_PREFIX_TEXT_ENCODER, alpha):
    """Merge a LoRA safetensors checkpoint directly into a base pipeline's weights.

    For each LoRA up/down pair in the checkpoint, locates the matching layer in
    the base Stable Diffusion pipeline (UNet or text encoder, chosen by key
    prefix) and adds ``alpha * (up @ down)`` to that layer's weight in place.

    Args:
        base_model_path: Path/repo of the base pipeline in diffusers format.
        checkpoint_path: Path to the ``.safetensors`` LoRA file.
        LORA_PREFIX_UNET: Key prefix used for UNet weights (e.g. "lora_unet").
        LORA_PREFIX_TEXT_ENCODER: Key prefix for text-encoder weights.
        alpha: Merging ratio in ``W = W0 + alpha * deltaW``.

    Returns:
        The pipeline with LoRA weights merged in.
    """
    # load base model
    pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)

    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path)

    # Keys already merged (each lora_up/lora_down pair is handled once).
    visited = []

    # directly update weight in diffusers model
    for key in state_dict:
        # it is suggested to print out the key, it usually will be something like below
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        # as we have set the alpha beforehand, so just skip
        if ".alpha" in key or key in visited:
            continue

        # Choose the submodule tree by prefix, then split the remaining
        # underscore-joined path into candidate attribute segments.
        if "text" in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        # Greedy walk: try each segment as an attribute; on failure, glue the
        # next segment back on with "_" (module names themselves may contain
        # underscores) and retry until the full path resolves.
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        # Collect the (up, down) key pair — order matters: up first.
        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            # Conv LoRA: drop the trailing 1x1 spatial dims, matmul, restore them.
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)

        # update visited list
        for item in pair_keys:
            visited.append(item)

    return pipeline
89
+
90
+
91
if __name__ == "__main__":
    # CLI entry point: parse arguments, merge the LoRA into the base pipeline,
    # and save the merged pipeline to disk.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
    )
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
    )
    parser.add_argument(
        "--lora_prefix_text_encoder",
        default="lora_te",
        type=str,
        help="The prefix of text encoder weight in safetensors",
    )
    parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
    parser.add_argument(
        "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
    )
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")

    args = parser.parse_args()

    base_model_path = args.base_model_path
    checkpoint_path = args.checkpoint_path
    dump_path = args.dump_path
    lora_prefix_unet = args.lora_prefix_unet
    lora_prefix_text_encoder = args.lora_prefix_text_encoder
    alpha = args.alpha

    pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)

    pipe = pipe.to(args.device)
    # Consistency fix: the original extracted `dump_path` from args but then
    # used `args.dump_path` directly, leaving the local unused. Use the local.
    pipe.save_pretrained(dump_path, safe_serialization=args.to_safetensors)
diffusers/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Conversion script for the NCSNPP checkpoints."""
16
+
17
+ import argparse
18
+ import json
19
+
20
+ import torch
21
+
22
+ from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
23
+
24
+
25
def convert_ncsnpp_checkpoint(checkpoint, config):
    """
    Convert an original NCSN++ checkpoint (flat ``all_modules.N`` keys) into the
    state dict of a diffusers ``UNet2DModel`` built from ``config``.

    The original checkpoint enumerates modules in forward order, so this walks
    the new architecture in the same order while advancing ``module_index``
    through ``all_modules``. Returns the populated state dict.
    """
    new_model_architecture = UNet2DModel(**config)
    # Module 0 is the Gaussian Fourier time projection; 1-2 the time MLP;
    # 3 the input convolution.
    new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data
    new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data
    new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
    new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data

    new_model_architecture.time_embedding.linear_2.weight.data = checkpoint["all_modules.2.weight"].data
    new_model_architecture.time_embedding.linear_2.bias.data = checkpoint["all_modules.2.bias"].data

    new_model_architecture.conv_in.weight.data = checkpoint["all_modules.3.weight"].data
    new_model_architecture.conv_in.bias.data = checkpoint["all_modules.3.bias"].data

    # Output head weights live in the last four checkpoint entries.
    # NOTE(review): these four assignments are overwritten again at the bottom
    # of this function via `all_modules.{module_index}` — seemingly redundant.
    new_model_architecture.conv_norm_out.weight.data = checkpoint[list(checkpoint.keys())[-4]].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[list(checkpoint.keys())[-3]].data
    new_model_architecture.conv_out.weight.data = checkpoint[list(checkpoint.keys())[-2]].data
    new_model_architecture.conv_out.bias.data = checkpoint[list(checkpoint.keys())[-1]].data

    # Running cursor into the flat `all_modules` list; every copied submodule
    # advances it by one.
    module_index = 4

    def set_attention_weights(new_layer, old_checkpoint, index):
        # Map the NIN (1x1) projections; weights are stored transposed (`.T`).
        new_layer.query.weight.data = old_checkpoint[f"all_modules.{index}.NIN_0.W"].data.T
        new_layer.key.weight.data = old_checkpoint[f"all_modules.{index}.NIN_1.W"].data.T
        new_layer.value.weight.data = old_checkpoint[f"all_modules.{index}.NIN_2.W"].data.T

        new_layer.query.bias.data = old_checkpoint[f"all_modules.{index}.NIN_0.b"].data
        new_layer.key.bias.data = old_checkpoint[f"all_modules.{index}.NIN_1.b"].data
        new_layer.value.bias.data = old_checkpoint[f"all_modules.{index}.NIN_2.b"].data

        new_layer.proj_attn.weight.data = old_checkpoint[f"all_modules.{index}.NIN_3.W"].data.T
        new_layer.proj_attn.bias.data = old_checkpoint[f"all_modules.{index}.NIN_3.b"].data

        new_layer.group_norm.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
        new_layer.group_norm.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data

    def set_resnet_weights(new_layer, old_checkpoint, index):
        # Copy the two conv+norm pairs and the time-embedding projection of a
        # resnet block.
        new_layer.conv1.weight.data = old_checkpoint[f"all_modules.{index}.Conv_0.weight"].data
        new_layer.conv1.bias.data = old_checkpoint[f"all_modules.{index}.Conv_0.bias"].data
        new_layer.norm1.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
        new_layer.norm1.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data

        new_layer.conv2.weight.data = old_checkpoint[f"all_modules.{index}.Conv_1.weight"].data
        new_layer.conv2.bias.data = old_checkpoint[f"all_modules.{index}.Conv_1.bias"].data
        new_layer.norm2.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.weight"].data
        new_layer.norm2.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.bias"].data

        new_layer.time_emb_proj.weight.data = old_checkpoint[f"all_modules.{index}.Dense_0.weight"].data
        new_layer.time_emb_proj.bias.data = old_checkpoint[f"all_modules.{index}.Dense_0.bias"].data

        # A shortcut conv only exists when channel count or resolution changes.
        if new_layer.in_channels != new_layer.out_channels or new_layer.up or new_layer.down:
            new_layer.conv_shortcut.weight.data = old_checkpoint[f"all_modules.{index}.Conv_2.weight"].data
            new_layer.conv_shortcut.bias.data = old_checkpoint[f"all_modules.{index}.Conv_2.bias"].data

    # Down path: resnet (+ optional attention) per layer, then the downsampling
    # resnet and its skip convolution.
    for i, block in enumerate(new_model_architecture.downsample_blocks):
        has_attentions = hasattr(block, "attentions")
        for j in range(len(block.resnets)):
            set_resnet_weights(block.resnets[j], checkpoint, module_index)
            module_index += 1
            if has_attentions:
                set_attention_weights(block.attentions[j], checkpoint, module_index)
                module_index += 1

        if hasattr(block, "downsamplers") and block.downsamplers is not None:
            set_resnet_weights(block.resnet_down, checkpoint, module_index)
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.Conv_0.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
            module_index += 1

    # Mid block: resnet -> attention -> resnet.
    set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
    module_index += 1
    set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
    module_index += 1
    set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
    module_index += 1

    # Up path: resnets, at most one attention per block, then skip norm/conv and
    # the upsampling resnet.
    for i, block in enumerate(new_model_architecture.up_blocks):
        has_attentions = hasattr(block, "attentions")
        for j in range(len(block.resnets)):
            set_resnet_weights(block.resnets[j], checkpoint, module_index)
            module_index += 1
        if has_attentions:
            set_attention_weights(
                block.attentions[0], checkpoint, module_index
            )  # why can there only be a single attention layer for up?
            module_index += 1

        if hasattr(block, "resnet_up") and block.resnet_up is not None:
            block.skip_norm.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_norm.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            set_resnet_weights(block.resnet_up, checkpoint, module_index)
            module_index += 1

    # Output head (final norm + conv), read at the cursor's final position.
    new_model_architecture.conv_norm_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
    module_index += 1
    new_model_architecture.conv_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data

    return new_model_architecture.state_dict()
132
+
133
+
134
if __name__ == "__main__":
    # CLI entry point: load the original NCSN++ checkpoint + config, convert,
    # and save either a full ScoreSdeVe pipeline or just the model.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the checkpoint to convert.",
    )

    parser.add_argument(
        "--config_file",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/config.json",
        type=str,
        required=False,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument(
        "--dump_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model_new.pt",
        type=str,
        required=False,
        help="Path to the output model.",
    )

    args = parser.parse_args()

    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")

    # Idiom fix: json.load(f) instead of json.loads(f.read()).
    with open(args.config_file) as f:
        config = json.load(f)

    converted_checkpoint = convert_ncsnpp_checkpoint(
        checkpoint,
        config,
    )

    # The original config may carry an "sde" section that UNet2DModel does not accept.
    if "sde" in config:
        del config["sde"]

    model = UNet2DModel(**config)
    model.load_state_dict(converted_checkpoint)

    try:
        scheduler = ScoreSdeVeScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))

        pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
        pipe.save_pretrained(args.dump_path)
    except Exception:
        # Fix: narrowed the original bare `except:` so that KeyboardInterrupt /
        # SystemExit still propagate. Best-effort fallback: when no scheduler
        # config is found next to the checkpoint, save the bare model instead.
        model.save_pretrained(args.dump_path)
diffusers/scripts/convert_omnigen_to_diffusers.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from huggingface_hub import snapshot_download
6
+ from safetensors.torch import load_file
7
+ from transformers import AutoTokenizer
8
+
9
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
10
+
11
+
12
def main(args):
    """Convert the original OmniGen checkpoint into a diffusers ``OmniGenPipeline``.

    Downloads the checkpoint from the Hub when ``args.origin_ckpt_path`` is not
    a local directory, remaps its state-dict keys to the diffusers
    ``OmniGenTransformer2DModel`` layout, assembles the full pipeline
    (transformer + VAE + tokenizer + scheduler), and saves it to
    ``args.dump_path``.
    """
    # checkpoint from https://huggingface.co/Shitao/OmniGen-v1

    if not os.path.exists(args.origin_ckpt_path):
        print("Model not found, downloading...")
        cache_folder = os.getenv("HF_HUB_CACHE")
        args.origin_ckpt_path = snapshot_download(
            repo_id=args.origin_ckpt_path,
            cache_dir=cache_folder,
            ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"],
        )
        print(f"Downloaded model to {args.origin_ckpt_path}")

    ckpt = os.path.join(args.origin_ckpt_path, "model.safetensors")
    ckpt = load_file(ckpt, device="cpu")

    # Direct one-to-one renames from original keys to diffusers keys.
    mapping_dict = {
        "pos_embed": "patch_embedding.pos_embed",
        "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
        "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
        "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
        "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
        "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
        "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
        "final_layer.linear.weight": "proj_out.weight",
        "final_layer.linear.bias": "proj_out.bias",
        "time_token.mlp.0.weight": "time_token.linear_1.weight",
        "time_token.mlp.0.bias": "time_token.linear_1.bias",
        "time_token.mlp.2.weight": "time_token.linear_2.weight",
        "time_token.mlp.2.bias": "time_token.linear_2.bias",
        "t_embedder.mlp.0.weight": "t_embedder.linear_1.weight",
        "t_embedder.mlp.0.bias": "t_embedder.linear_1.bias",
        "t_embedder.mlp.2.weight": "t_embedder.linear_2.weight",
        "t_embedder.mlp.2.bias": "t_embedder.linear_2.bias",
        "llm.embed_tokens.weight": "embed_tokens.weight",
    }

    converted_state_dict = {}
    for k, v in ckpt.items():
        if k in mapping_dict:
            converted_state_dict[mapping_dict[k]] = v
        elif "qkv" in k:
            # Fused qkv projection is split into separate q/k/v weights;
            # k.split('.')[2] is the layer index in keys like "llm.layers.<i>...".
            to_q, to_k, to_v = v.chunk(3)
            converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_q.weight"] = to_q
            converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_k.weight"] = to_k
            converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_v.weight"] = to_v
        elif "o_proj" in k:
            converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_out.0.weight"] = v
        else:
            # Remaining keys only need the leading "llm." prefix stripped.
            converted_state_dict[k[4:]] = v

    # rope_scaling factors copied verbatim from the original OmniGen config.
    transformer = OmniGenTransformer2DModel(
        rope_scaling={
            "long_factor": [
                1.0299999713897705,
                1.0499999523162842,
                1.0499999523162842,
                1.0799999237060547,
                1.2299998998641968,
                1.2299998998641968,
                1.2999999523162842,
                1.4499999284744263,
                1.5999999046325684,
                1.6499998569488525,
                1.8999998569488525,
                2.859999895095825,
                3.68999981880188,
                5.419999599456787,
                5.489999771118164,
                5.489999771118164,
                9.09000015258789,
                11.579999923706055,
                15.65999984741211,
                15.769999504089355,
                15.789999961853027,
                18.360000610351562,
                21.989999771118164,
                23.079999923706055,
                30.009998321533203,
                32.35000228881836,
                32.590003967285156,
                35.56000518798828,
                39.95000457763672,
                53.840003967285156,
                56.20000457763672,
                57.95000457763672,
                59.29000473022461,
                59.77000427246094,
                59.920005798339844,
                61.190006256103516,
                61.96000671386719,
                62.50000762939453,
                63.3700065612793,
                63.48000717163086,
                63.48000717163086,
                63.66000747680664,
                63.850006103515625,
                64.08000946044922,
                64.760009765625,
                64.80001068115234,
                64.81001281738281,
                64.81001281738281,
            ],
            "short_factor": [
                1.05,
                1.05,
                1.05,
                1.1,
                1.1,
                1.1,
                1.2500000000000002,
                1.2500000000000002,
                1.4000000000000004,
                1.4500000000000004,
                1.5500000000000005,
                1.8500000000000008,
                1.9000000000000008,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.1000000000000005,
                2.1000000000000005,
                2.2,
                2.3499999999999996,
                2.3499999999999996,
                2.3499999999999996,
                2.3499999999999996,
                2.3999999999999995,
                2.3999999999999995,
                2.6499999999999986,
                2.6999999999999984,
                2.8999999999999977,
                2.9499999999999975,
                3.049999999999997,
                3.049999999999997,
                3.049999999999997,
            ],
            "type": "su",
        },
        patch_size=2,
        in_channels=4,
        pos_embed_max_size=192,
    )
    # strict=True: fail loudly if any key failed to remap.
    transformer.load_state_dict(converted_state_dict, strict=True)
    transformer.to(torch.bfloat16)

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)

    # VAE is kept in float32 while the transformer runs in bfloat16.
    vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)

    pipeline = OmniGenPipeline(tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler)
    pipeline.save_pretrained(args.dump_path)
186
+
187
if __name__ == "__main__":
    # CLI wrapper: both arguments are optional and default to converting the
    # official OmniGen-v1 checkpoint into "OmniGen-v1-diffusers".
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--origin_ckpt_path",
        type=str,
        default="Shitao/OmniGen-v1",
        required=False,
        help="Path to the checkpoint to convert.",
    )
    arg_parser.add_argument(
        "--dump_path",
        type=str,
        default="OmniGen-v1-diffusers",
        required=False,
        help="Path to the output pipeline.",
    )
    main(arg_parser.parse_args())
diffusers/scripts/convert_original_audioldm2_to_diffusers.py ADDED
@@ -0,0 +1,1135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Conversion script for the AudioLDM2 checkpoints."""
16
+
17
+ import argparse
18
+ import re
19
+ from typing import List, Union
20
+
21
+ import torch
22
+ import yaml
23
+ from transformers import (
24
+ AutoFeatureExtractor,
25
+ AutoTokenizer,
26
+ ClapConfig,
27
+ ClapModel,
28
+ GPT2Config,
29
+ GPT2Model,
30
+ SpeechT5HifiGan,
31
+ SpeechT5HifiGanConfig,
32
+ T5Config,
33
+ T5EncoderModel,
34
+ )
35
+
36
+ from diffusers import (
37
+ AudioLDM2Pipeline,
38
+ AudioLDM2ProjectionModel,
39
+ AudioLDM2UNet2DConditionModel,
40
+ AutoencoderKL,
41
+ DDIMScheduler,
42
+ DPMSolverMultistepScheduler,
43
+ EulerAncestralDiscreteScheduler,
44
+ EulerDiscreteScheduler,
45
+ HeunDiscreteScheduler,
46
+ LMSDiscreteScheduler,
47
+ PNDMScheduler,
48
+ )
49
+ from diffusers.utils import is_safetensors_available
50
+ from diffusers.utils.import_utils import BACKENDS_MAPPING
51
+
52
+
53
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drop dot-separated segments from `path`.

    A non-negative count removes that many leading segments; a negative count
    removes that many trailing segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
62
+
63
+
64
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    # Ordered substring renames from the LDM resnet layout to the diffusers one.
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )
    mapping = []
    for old_item in old_list:
        new_item = old_item
        for src, dst in renames:
            new_item = new_item.replace(src, dst)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping
85
+
86
+
87
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        # The only VAE-resnet rename is the shortcut convolution.
        renamed = old_item.replace("nin_shortcut", "conv_shortcut")
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": renamed})
    return mapping
102
+
103
+
104
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths
def renew_attention_paths(old_list):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # The attention key names already match the diffusers scheme, so this is an
    # identity mapping kept for symmetry with the other renew_* helpers.
    return [{"old": old_item, "new": old_item} for old_item in old_list]
124
+
125
+
126
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # Ordered substring renames from the LDM VAE attention layout to diffusers.
    renames = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "to_q.weight"),
        ("q.bias", "to_q.bias"),
        ("k.weight", "to_k.weight"),
        ("k.bias", "to_k.bias"),
        ("v.weight", "to_v.weight"),
        ("v.bias", "to_v.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
    )
    mapping = []
    for old_item in old_list:
        new_item = old_item
        for src, dst in renames:
            new_item = new_item.replace(src, dst)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping
154
+
155
+
156
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Split fused attention tensors into separate query/key/value entries.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            # 3-D tensors keep an explicit channel dimension; others are flattened.
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # Skip keys that were already assigned by the q/k/v split above.
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        for replacement in additional_replacements or []:
            new_path = new_path.replace(replacement["old"], replacement["new"])

        source = old_checkpoint[path["old"]]
        # proj_attn.weight has to be converted from conv 1D to linear
        checkpoint[new_path] = source[:, :, 0] if "proj_attn.weight" in new_path else source
200
+
201
+
202
def conv_attn_to_linear(checkpoint):
    """Squeeze 1x1-conv attention weights down to linear-layer shape, in place."""
    qkv_suffixes = ("to_q.weight", "to_k.weight", "to_v.weight")
    proj_suffix = "to_out.0.weight"
    for key in list(checkpoint.keys()):
        parts = key.split(".")
        matches_qkv = ".".join(parts[-2:]) in qkv_suffixes
        matches_proj = ".".join(parts[-3:]) == proj_suffix
        # Only reshape tensors that still carry trailing conv dimensions.
        if (matches_qkv or matches_proj) and checkpoint[key].ndim > 2:
            checkpoint[key] = checkpoint[key].squeeze()
210
+
211
+
212
def create_unet_diffusers_config(original_config, image_size: int):
    """
    Creates a UNet config for diffusers based on the config of the original AudioLDM2 model.
    """
    unet_params = original_config["model"]["params"]["unet_config"]["params"]
    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]

    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
    num_blocks = len(block_out_channels)
    attn_resolutions = unet_params["attention_resolutions"]

    # Walk down the UNet: the resolution doubles after every block except the last,
    # and blocks whose resolution is in `attention_resolutions` get cross-attention.
    down_block_types = []
    resolution = 1
    for block_idx in range(num_blocks):
        down_block_types.append("CrossAttnDownBlock2D" if resolution in attn_resolutions else "DownBlock2D")
        if block_idx != num_blocks - 1:
            resolution *= 2

    # Walk back up, halving the resolution after each block.
    up_block_types = []
    for _ in range(num_blocks):
        up_block_types.append("CrossAttnUpBlock2D" if resolution in attn_resolutions else "UpBlock2D")
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

    cross_attention_dim = list(unet_params["context_dim"]) if "context_dim" in unet_params else block_out_channels
    if len(cross_attention_dim) > 1:
        # require two or more cross-attention layers per-block, each of different dimension
        cross_attention_dim = [cross_attention_dim for _ in range(num_blocks)]

    return {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params["in_channels"],
        "out_channels": unet_params["out_channels"],
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params["num_res_blocks"],
        "transformer_layers_per_block": unet_params["transformer_depth"],
        "cross_attention_dim": tuple(cross_attention_dim),
    }
255
+
256
+
257
# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config
def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
    """
    Creates a VAE config for diffusers based on the config of the original AudioLDM2 model. Compared to the original
    Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
    """
    model_params = original_config["model"]["params"]
    vae_params = model_params["first_stage_config"]["params"]["ddconfig"]
    # Accessed only to fail fast if the original config is missing `embed_dim`.
    _ = model_params["first_stage_config"]["params"]["embed_dim"]

    block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
    num_blocks = len(block_out_channels)

    # Checkpoints trained with `scale_by_std` carry a learnt scale factor;
    # otherwise fall back to the Stable Diffusion default.
    scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in model_params else 0.18215

    return {
        "sample_size": image_size,
        "in_channels": vae_params["in_channels"],
        "out_channels": vae_params["out_ch"],
        "down_block_types": ("DownEncoderBlock2D",) * num_blocks,
        "up_block_types": ("UpDecoderBlock2D",) * num_blocks,
        "block_out_channels": tuple(block_out_channels),
        "latent_channels": vae_params["z_channels"],
        "layers_per_block": vae_params["num_res_blocks"],
        "scaling_factor": float(scaling_factor),
    }
284
+
285
+
286
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
def create_diffusers_schedular(original_config):
    """Build a DDIM scheduler from the original LDM training hyper-parameters."""
    model_params = original_config["model"]["params"]
    return DDIMScheduler(
        num_train_timesteps=model_params["timesteps"],
        beta_start=model_params["linear_start"],
        beta_end=model_params["linear_end"],
        beta_schedule="scaled_linear",
    )
295
+
296
+
297
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted UNet checkpoint.

    Args:
        checkpoint: full original LDM state dict. NOTE: mutated in place — the
            UNet weights are `pop`ped out of it.
        config: diffusers UNet config dict (as produced by `create_unet_diffusers_config`).
        path: checkpoint path, used only in the informational EMA message.
        extract_ema: when True and the checkpoint contains EMA weights, extract
            the EMA copies instead of the training weights.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."
    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                # EMA weights are stored under flattened names (dots removed from
                # the original key), prefixed with `model_ema.`.
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        # strip the unet prefix from the weight names
        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # Time embedding and the input/output convolutions map one-to-one.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }

    # Check how many Transformer blocks we have per layer
    # NOTE(review): `num_attention_layers` is only bound when `cross_attention_dim`
    # is a list/tuple; any other config would raise NameError below — confirm the
    # configs produced by `create_unet_diffusers_config` always satisfy this.
    if isinstance(config.get("cross_attention_dim"), (list, tuple)):
        if isinstance(config["cross_attention_dim"][0], (list, tuple)):
            # in this case we have multiple cross-attention layers per-block
            num_attention_layers = len(config.get("cross_attention_dim")[0])
        else:
            num_attention_layers = 1

    if config.get("extra_self_attn_layer"):
        num_attention_layers += 1

    for i in range(1, num_input_blocks):
        # Original layout: `layers_per_block` resnets followed by a downsampler.
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # Sub-index 0 holds the resnet (and possibly the `op` downsampler);
        # everything else in the block is attention.
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.0" not in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            # One rename per attention layer: AudioLDM2 interleaves several
            # attentions per resnet position.
            meta_path = [
                {
                    "old": f"input_blocks.{i}.{1 + layer_id}",
                    "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id * num_attention_layers + layer_id}",
                }
                for layer_id in range(num_attention_layers)
            ]
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=meta_path, config=config
            )

    # Middle block: first and last entries are resnets, everything between is attention.
    resnet_0 = middle_blocks[0]
    resnet_1 = middle_blocks[num_middle_blocks - 1]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"}
    assign_to_checkpoint(
        resnet_0_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    resnet_1_paths = renew_resnet_paths(resnet_1)
    meta_path = {"old": f"middle_block.{len(middle_blocks) - 1}", "new": "mid_block.resnets.1"}
    assign_to_checkpoint(
        resnet_1_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(1, num_middle_blocks - 1):
        attentions = middle_blocks[i]
        attentions_paths = renew_attention_paths(attentions)
        meta_path = {"old": f"middle_block.{i}", "new": f"mid_block.attentions.{i - 1}"}
        assign_to_checkpoint(
            attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        # Group this output block's keys by their sub-layer index.
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.0" not in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # A sub-layer holding exactly [conv.bias, conv.weight] is the upsampler.
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                attentions.remove(f"output_blocks.{i}.{index}.conv.bias")
                attentions.remove(f"output_blocks.{i}.{index}.conv.weight")

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = [
                    {
                        "old": f"output_blocks.{i}.{1 + layer_id}",
                        "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id * num_attention_layers + layer_id}",
                    }
                    for layer_id in range(num_attention_layers)
                ]
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=meta_path, config=config
                )
        else:
            # Attention-free output block: only resnet weights to rename.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
498
+
499
+
500
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Convert the original LDM VAE state dict (keys under `first_stage_model.`)
    into the diffusers `AutoencoderKL` layout.

    Args:
        checkpoint: full original state dict (read, not mutated — keys are copied out).
        config: diffusers VAE config dict; forwarded to `assign_to_checkpoint`.
    """
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    # Top-level encoder/decoder convolutions and norms map one-to-one.
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        # Downsampler weights (if present) are popped so they are not treated as resnets.
        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid block: two resnets plus one attention.
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        # Decoder up blocks are stored in reverse order relative to diffusers.
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid block: same structure as the encoder mid block.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
605
+
606
+
607
# Substring renames applied (in insertion order) to every key of the original
# CLAP state dict to match the `transformers` ClapModel naming.
CLAP_KEYS_TO_MODIFY_MAPPING = {
    "text_branch": "text_model",
    "audio_branch": "audio_model.audio_encoder",
    "attn": "attention.self",
    "self.proj": "output.dense",
    "attention.self_mask": "attn_mask",
    "mlp.fc1": "intermediate.dense",
    "mlp.fc2": "output.dense",
    "norm1": "layernorm_before",
    "norm2": "layernorm_after",
    "bn0": "batch_norm",
}

# Keys containing any of these substrings are dropped from the converted
# checkpoint (pre-/post-processing modules and the classification head).
CLAP_KEYS_TO_IGNORE = [
    "text_transform",
    "audio_transform",
    "stft",
    "logmel_extractor",
    "tscam_conv",
    "head",
    "attn_mask",
]

# Keys that `transformers` is expected to report as missing after loading the
# converted state dict; these are buffers the model re-creates itself.
CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"]
631
+
632
+
633
def convert_open_clap_checkpoint(checkpoint):
    """
    Takes a state dict and returns a converted CLAP checkpoint.

    Args:
        checkpoint: combined AudioLDM2 state dict; only keys under the
            `clap.model.` prefix are converted, everything else is ignored.

    Returns:
        dict: state dict with keys renamed to the `transformers` CLAP layout,
        with fused audio-branch `qkv` weights split into query/key/value.
    """
    # extract state dict for CLAP text embedding model, discarding the audio component
    model_state_dict = {}
    model_key = "clap.model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(model_key):
            model_state_dict[key.replace(model_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    sequential_layers_pattern = r".*sequential.(\d+).*"
    text_projection_pattern = r".*_projection.(\d+).*"

    for key, value in model_state_dict.items():
        # check if key should be ignored in mapping - if so map it to a key name that we'll filter out at the end
        for key_to_ignore in CLAP_KEYS_TO_IGNORE:
            if key_to_ignore in key:
                key = "spectrogram"

        # check if any key needs to be modified
        for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        if re.match(sequential_layers_pattern, key):
            # replace sequential layers with list
            sequential_layer = re.match(sequential_layers_pattern, key).group(1)

            key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
        elif re.match(text_projection_pattern, key):
            projecton_layer = int(re.match(text_projection_pattern, key).group(1))

            # Because in CLAP they use `nn.Sequential`...
            transformers_projection_layer = 1 if projecton_layer == 0 else 2

            key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")

        # BUGFIX: the original condition was `if "audio" and "qkv" in key`, where the
        # string literal "audio" is always truthy, so ANY key containing "qkv" was
        # split. Only fused audio-branch attention weights should be split.
        if "audio" in key and "qkv" in key:
            # split qkv into query key and value
            mixed_qkv = value
            qkv_dim = mixed_qkv.size(0) // 3

            query_layer = mixed_qkv[:qkv_dim]
            key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
            value_layer = mixed_qkv[qkv_dim * 2 :]

            new_checkpoint[key.replace("qkv", "query")] = query_layer
            new_checkpoint[key.replace("qkv", "key")] = key_layer
            new_checkpoint[key.replace("qkv", "value")] = value_layer
        elif key != "spectrogram":
            # Ignored keys were renamed to the sentinel "spectrogram" above.
            new_checkpoint[key] = value

    return new_checkpoint
690
+
691
+
692
def create_transformers_vocoder_config(original_config):
    """
    Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
    """
    vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]

    return {
        "model_in_dim": vocoder_params["num_mels"],
        "sampling_rate": vocoder_params["sampling_rate"],
        "upsample_initial_channel": vocoder_params["upsample_initial_channel"],
        "upsample_rates": list(vocoder_params["upsample_rates"]),
        "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
        "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
        # Nested sequences are converted to plain lists for JSON-serializable configs.
        "resblock_dilation_sizes": [list(dilations) for dilations in vocoder_params["resblock_dilation_sizes"]],
        # The original vocoder does not normalize its input spectrogram.
        "normalize_before": False,
    }
712
+
713
+
714
def extract_sub_model(checkpoint, key_prefix):
    """
    Takes a state dict and returns the state dict for a particular sub-model.
    """
    return {
        key.replace(key_prefix, ""): checkpoint.get(key)
        for key in list(checkpoint.keys())
        if key.startswith(key_prefix)
    }
726
+
727
+
728
def convert_hifigan_checkpoint(checkpoint, config):
    """
    Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint.
    """
    # extract state dict for vocoder (sub-model extraction inlined)
    prefix = "first_stage_model.vocoder."
    vocoder_state_dict = {
        key.replace(prefix, ""): checkpoint.get(key) for key in list(checkpoint.keys()) if key.startswith(prefix)
    }

    # fix upsampler keys, everything else is correct already
    for i in range(len(config.upsample_rates)):
        vocoder_state_dict[f"upsampler.{i}.weight"] = vocoder_state_dict.pop(f"ups.{i}.weight")
        vocoder_state_dict[f"upsampler.{i}.bias"] = vocoder_state_dict.pop(f"ups.{i}.bias")

    if not config.normalize_before:
        # if we don't set normalize_before then these variables are unused, so we set them to their initialised values
        vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim)
        vocoder_state_dict["scale"] = torch.ones(config.model_in_dim)

    return vocoder_state_dict
746
+
747
+
748
def convert_projection_checkpoint(checkpoint):
    """
    Convert the AudioLDM2 projection-model weights: the two start/end-of-sequence
    embeddings and the two linear projection layers of the first conditioner.
    """
    conditioner = extract_sub_model(checkpoint, key_prefix="cond_stage_models.0.")

    return {
        # row 0 / row 1 of the learnt SOS / EOS embedding tables
        "sos_embed": conditioner["start_of_sequence_tokens.weight"][0],
        "sos_embed_1": conditioner["start_of_sequence_tokens.weight"][1],
        "eos_embed": conditioner["end_of_sequence_tokens.weight"][0],
        "eos_embed_1": conditioner["end_of_sequence_tokens.weight"][1],
        # per-input-sequence linear projections
        "projection.weight": conditioner["input_sequence_embed_linear.0.weight"],
        "projection.bias": conditioner["input_sequence_embed_linear.0.bias"],
        "projection_1.weight": conditioner["input_sequence_embed_linear.1.weight"],
        "projection_1.bias": conditioner["input_sequence_embed_linear.1.bias"],
    }
765
+
766
+
767
# Adapted from https://github.com/haoheliu/AudioLDM2/blob/81ad2c6ce015c1310387695e2dae975a7d2ed6fd/audioldm2/utils.py#L143
# Fallback configuration used when no `--original_config_file` is supplied.
# Mirrors the base AudioLDM2 training config: diffusion schedule, latent UNet,
# mel-spectrogram VAE, the conditioning stack and the HiFi-GAN vocoder.
DEFAULT_CONFIG = {
    "model": {
        "params": {
            # linear beta schedule for the 1000-step diffusion process
            "linear_start": 0.0015,
            "linear_end": 0.0195,
            "timesteps": 1000,
            "channels": 8,
            "scale_by_std": True,
            # latent-space UNet backbone
            "unet_config": {
                "target": "audioldm2.latent_diffusion.openaimodel.UNetModel",
                "params": {
                    "context_dim": [None, 768, 1024],
                    "in_channels": 8,
                    "out_channels": 8,
                    "model_channels": 128,
                    "attention_resolutions": [8, 4, 2],
                    "num_res_blocks": 2,
                    "channel_mult": [1, 2, 3, 5],
                    "num_head_channels": 32,
                    "transformer_depth": 1,
                },
            },
            # mel-spectrogram VAE (first stage)
            "first_stage_config": {
                "target": "audioldm2.variational_autoencoder.autoencoder.AutoencoderKL",
                "params": {
                    "embed_dim": 8,
                    "ddconfig": {
                        "z_channels": 8,
                        "resolution": 256,
                        "in_channels": 1,
                        "out_ch": 1,
                        "ch": 128,
                        "ch_mult": [1, 2, 4],
                        "num_res_blocks": 2,
                    },
                },
            },
            # AudioMAE-conditioned sequence generator used for cross-attention conditioning
            "cond_stage_config": {
                "crossattn_audiomae_generated": {
                    "target": "audioldm2.latent_diffusion.modules.encoders.modules.SequenceGenAudioMAECond",
                    "params": {
                        "sequence_gen_length": 8,
                        "sequence_input_embed_dim": [512, 1024],
                    },
                }
            },
            # HiFi-GAN vocoder hyper-parameters (mel spectrogram -> waveform)
            "vocoder_config": {
                "target": "audioldm2.first_stage_model.vocoder",
                "params": {
                    "upsample_rates": [5, 4, 2, 2, 2],
                    "upsample_kernel_sizes": [16, 16, 8, 4, 4],
                    "upsample_initial_channel": 1024,
                    "resblock_kernel_sizes": [3, 7, 11],
                    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                    "num_mels": 64,
                    "sampling_rate": 16000,
                },
            },
        },
    },
}
829
+
830
+
831
def load_pipeline_from_original_AudioLDM2_ckpt(
    checkpoint_path: str,
    original_config_file: str = None,
    image_size: int = 1024,
    prediction_type: str = None,
    extract_ema: bool = False,
    scheduler_type: str = "ddim",
    cross_attention_dim: Union[List, List[List]] = None,
    transformer_layers_per_block: int = None,
    device: str = None,
    from_safetensors: bool = False,
) -> AudioLDM2Pipeline:
    """
    Load an AudioLDM2 pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file.

    Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
    global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
    recommended that you override the default values and/or supply an `original_config_file` wherever possible.

    Args:
        checkpoint_path (`str`): Path to `.ckpt` file.
        original_config_file (`str`):
            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
            set to the AudioLDM2 base config.
        image_size (`int`, *optional*, defaults to 1024):
            The image size that the model was trained on.
        prediction_type (`str`, *optional*):
            The prediction type that the model was trained on. If `None`, will be automatically
            inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
        scheduler_type (`str`, *optional*, defaults to 'ddim'):
            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
            "ddim"]`.
        cross_attention_dim (`list`, *optional*, defaults to `None`):
            The dimension of the cross-attention layers. If `None`, the cross-attention dimension will be
            automatically inferred. Set to `[768, 1024]` for the base model, or `[768, 1024, None]` for the large model.
        transformer_layers_per_block (`int`, *optional*, defaults to `None`):
            The number of transformer layers in each transformer block. If `None`, number of layers will be
            automatically inferred. Set to `1` for the base model, or `2` for the large model.
        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
            inference. Non-EMA weights are usually better to continue fine-tuning.
        device (`str`, *optional*, defaults to `None`):
            The device to use. Pass `None` to determine automatically.
        from_safetensors (`str`, *optional*, defaults to `False`):
            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
    return: An AudioLDM2Pipeline object representing the passed-in `.ckpt`/`.safetensors` file.
    """

    # 1. Load the raw checkpoint into a flat state dict.
    if from_safetensors:
        if not is_safetensors_available():
            raise ValueError(BACKENDS_MAPPING["safetensors"][1])

        from safetensors import safe_open

        checkpoint = {}
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
    else:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            checkpoint = torch.load(checkpoint_path, map_location=device)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device)

    # Lightning-style checkpoints nest the weights under "state_dict".
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    # 2. Resolve the model configuration, applying any caller overrides.
    if original_config_file is None:
        original_config = DEFAULT_CONFIG
    else:
        original_config = yaml.safe_load(original_config_file)

    if image_size is not None:
        original_config["model"]["params"]["unet_config"]["params"]["image_size"] = image_size

    if cross_attention_dim is not None:
        original_config["model"]["params"]["unet_config"]["params"]["context_dim"] = cross_attention_dim

    if transformer_layers_per_block is not None:
        original_config["model"]["params"]["unet_config"]["params"]["transformer_depth"] = transformer_layers_per_block

    # v-parameterized models predict velocity; everything else defaults to epsilon.
    if (
        "parameterization" in original_config["model"]["params"]
        and original_config["model"]["params"]["parameterization"] == "v"
    ):
        if prediction_type is None:
            prediction_type = "v_prediction"
    else:
        if prediction_type is None:
            prediction_type = "epsilon"

    # 3. Build the noise scheduler from the original beta schedule.
    num_train_timesteps = original_config["model"]["params"]["timesteps"]
    beta_start = original_config["model"]["params"]["linear_start"]
    beta_end = original_config["model"]["params"]["linear_end"]

    scheduler = DDIMScheduler(
        beta_end=beta_end,
        beta_schedule="scaled_linear",
        beta_start=beta_start,
        num_train_timesteps=num_train_timesteps,
        steps_offset=1,
        clip_sample=False,
        set_alpha_to_one=False,
        prediction_type=prediction_type,
    )
    # make sure scheduler works correctly with DDIM
    scheduler.register_to_config(clip_sample=False)

    # Optionally swap the DDIM scheduler for another type with the same config.
    if scheduler_type == "pndm":
        config = dict(scheduler.config)
        config["skip_prk_steps"] = True
        scheduler = PNDMScheduler.from_config(config)
    elif scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "heun":
        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
    elif scheduler_type == "ddim":
        scheduler = scheduler
    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

    # 4. Convert the UNet2DModel
    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
    unet = AudioLDM2UNet2DConditionModel(**unet_config)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
    )

    unet.load_state_dict(converted_unet_checkpoint)

    # 5. Convert the VAE model
    vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # 6. Convert the joint audio-text encoding (CLAP) model.
    # AudioLDM2 uses a larger audio tower than the stock laion/clap-htsat-unfused
    # checkpoint, so the audio config is widened before loading.
    clap_config = ClapConfig.from_pretrained("laion/clap-htsat-unfused")
    clap_config.audio_config.update(
        {
            "patch_embeds_hidden_size": 128,
            "hidden_size": 1024,
            "depths": [2, 2, 12, 2],
        }
    )
    # AudioLDM2 uses the same tokenizer and feature extractor as the original CLAP model
    clap_tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
    clap_feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused")

    converted_clap_model = convert_open_clap_checkpoint(checkpoint)
    clap_model = ClapModel(clap_config)

    missing_keys, unexpected_keys = clap_model.load_state_dict(converted_clap_model, strict=False)
    # we expect not to have token_type_ids in our original state dict so let's ignore them
    missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS))

    if len(unexpected_keys) > 0:
        raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}")

    if len(missing_keys) > 0:
        raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}")

    # 7. Convert the vocoder model
    vocoder_config = create_transformers_vocoder_config(original_config)
    vocoder_config = SpeechT5HifiGanConfig(**vocoder_config)
    converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config)

    vocoder = SpeechT5HifiGan(vocoder_config)
    vocoder.load_state_dict(converted_vocoder_checkpoint)

    # 8. Convert the Flan-T5 encoder model: AudioLDM2 uses the same configuration and tokenizer as the original Flan-T5 large model
    t5_config = T5Config.from_pretrained("google/flan-t5-large")
    converted_t5_checkpoint = extract_sub_model(checkpoint, key_prefix="cond_stage_models.1.model.")

    t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
    # hard-coded in the original implementation (i.e. not retrievable from the config)
    t5_tokenizer.model_max_length = 128
    t5_model = T5EncoderModel(t5_config)
    t5_model.load_state_dict(converted_t5_checkpoint)

    # 9. Convert the GPT2 encoder model: AudioLDM2 uses the same configuration as the original GPT2 base model
    gpt2_config = GPT2Config.from_pretrained("gpt2")
    gpt2_model = GPT2Model(gpt2_config)
    # the number of generated conditioning tokens comes from the original config
    gpt2_model.config.max_new_tokens = original_config["model"]["params"]["cond_stage_config"][
        "crossattn_audiomae_generated"
    ]["params"]["sequence_gen_length"]

    converted_gpt2_checkpoint = extract_sub_model(checkpoint, key_prefix="cond_stage_models.0.model.")
    gpt2_model.load_state_dict(converted_gpt2_checkpoint)

    # 10. Convert the extra embedding / projection layers
    projection_model = AudioLDM2ProjectionModel(clap_config.projection_dim, t5_config.d_model, gpt2_config.n_embd)

    converted_projection_checkpoint = convert_projection_checkpoint(checkpoint)
    projection_model.load_state_dict(converted_projection_checkpoint)

    # 11. Instantiate the diffusers pipeline
    pipe = AudioLDM2Pipeline(
        vae=vae,
        text_encoder=clap_model,
        text_encoder_2=t5_model,
        projection_model=projection_model,
        language_model=gpt2_model,
        tokenizer=clap_tokenizer,
        tokenizer_2=t5_tokenizer,
        feature_extractor=clap_feature_extractor,
        unet=unet,
        scheduler=scheduler,
        vocoder=vocoder,
    )

    return pipe
1053
+
1054
+
1055
+ if __name__ == "__main__":
1056
+ parser = argparse.ArgumentParser()
1057
+
1058
+ parser.add_argument(
1059
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
1060
+ )
1061
+ parser.add_argument(
1062
+ "--original_config_file",
1063
+ default=None,
1064
+ type=str,
1065
+ help="The YAML config file corresponding to the original architecture.",
1066
+ )
1067
+ parser.add_argument(
1068
+ "--cross_attention_dim",
1069
+ default=None,
1070
+ type=int,
1071
+ nargs="+",
1072
+ help="The dimension of the cross-attention layers. If `None`, the cross-attention dimension will be "
1073
+ "automatically inferred. Set to `768+1024` for the base model, or `768+1024+640` for the large model",
1074
+ )
1075
+ parser.add_argument(
1076
+ "--transformer_layers_per_block",
1077
+ default=None,
1078
+ type=int,
1079
+ help="The number of transformer layers in each transformer block. If `None`, number of layers will be "
1080
+ "automatically inferred. Set to `1` for the base model, or `2` for the large model.",
1081
+ )
1082
+ parser.add_argument(
1083
+ "--scheduler_type",
1084
+ default="ddim",
1085
+ type=str,
1086
+ help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
1087
+ )
1088
+ parser.add_argument(
1089
+ "--image_size",
1090
+ default=1048,
1091
+ type=int,
1092
+ help="The image size that the model was trained on.",
1093
+ )
1094
+ parser.add_argument(
1095
+ "--prediction_type",
1096
+ default=None,
1097
+ type=str,
1098
+ help=("The prediction type that the model was trained on."),
1099
+ )
1100
+ parser.add_argument(
1101
+ "--extract_ema",
1102
+ action="store_true",
1103
+ help=(
1104
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
1105
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
1106
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
1107
+ ),
1108
+ )
1109
+ parser.add_argument(
1110
+ "--from_safetensors",
1111
+ action="store_true",
1112
+ help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
1113
+ )
1114
+ parser.add_argument(
1115
+ "--to_safetensors",
1116
+ action="store_true",
1117
+ help="Whether to store pipeline in safetensors format or not.",
1118
+ )
1119
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
1120
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
1121
+ args = parser.parse_args()
1122
+
1123
+ pipe = load_pipeline_from_original_AudioLDM2_ckpt(
1124
+ checkpoint_path=args.checkpoint_path,
1125
+ original_config_file=args.original_config_file,
1126
+ image_size=args.image_size,
1127
+ prediction_type=args.prediction_type,
1128
+ extract_ema=args.extract_ema,
1129
+ scheduler_type=args.scheduler_type,
1130
+ cross_attention_dim=args.cross_attention_dim,
1131
+ transformer_layers_per_block=args.transformer_layers_per_block,
1132
+ from_safetensors=args.from_safetensors,
1133
+ device=args.device,
1134
+ )
1135
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
diffusers/scripts/convert_original_musicldm_to_diffusers.py ADDED
@@ -0,0 +1,1056 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Conversion script for the MusicLDM checkpoints."""
16
+
17
+ import argparse
18
+ import re
19
+
20
+ import torch
21
+ import yaml
22
+ from transformers import (
23
+ AutoFeatureExtractor,
24
+ AutoTokenizer,
25
+ ClapConfig,
26
+ ClapModel,
27
+ SpeechT5HifiGan,
28
+ SpeechT5HifiGanConfig,
29
+ )
30
+
31
+ from diffusers import (
32
+ AutoencoderKL,
33
+ DDIMScheduler,
34
+ DPMSolverMultistepScheduler,
35
+ EulerAncestralDiscreteScheduler,
36
+ EulerDiscreteScheduler,
37
+ HeunDiscreteScheduler,
38
+ LMSDiscreteScheduler,
39
+ MusicLDMPipeline,
40
+ PNDMScheduler,
41
+ UNet2DConditionModel,
42
+ )
43
+
44
+
45
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
46
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
54
+
55
+
56
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths
57
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    # ordered (old, new) substring substitutions from the LDM to the diffusers layout
    substitutions = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for old_item in old_list:
        renamed = old_item
        for src, dst in substitutions:
            renamed = renamed.replace(src, dst)
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": renamed})

    return mapping
77
+
78
+
79
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths
80
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        # the VAE resnets only differ in the shortcut naming
        renamed = shave_segments(
            old_item.replace("nin_shortcut", "conv_shortcut"),
            n_shave_prefix_segments=n_shave_prefix_segments,
        )
        mapping.append({"old": old_item, "new": renamed})

    return mapping
94
+
95
+
96
+ # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths
97
def renew_attention_paths(old_list):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # The UNet attention keys already match the diffusers naming, so the
    # mapping is currently the identity; it is kept for structural symmetry
    # with the other renew_* helpers.
    return [{"old": item, "new": item} for item in old_list]
116
+
117
+
118
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # ordered (old, new) substring substitutions: group norm, q/k/v projections
    # and the output projection all have new names in diffusers
    substitutions = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "to_q.weight"),
        ("q.bias", "to_q.bias"),
        ("k.weight", "to_k.weight"),
        ("k.bias", "to_k.bias"),
        ("v.weight", "to_v.weight"),
        ("v.bias", "to_v.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
    )

    mapping = []
    for old_item in old_list:
        renamed = old_item
        for src, dst in substitutions:
            renamed = renamed.replace(src, dst)
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": renamed})

    return mapping
146
+
147
+
148
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            # the fused qkv tensor has 3 * channels rows
            channels = old_tensor.shape[0] // 3

            # 3-D tensors (conv weights) flatten to (-1, channels); 1-D biases to (-1,)
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            # regroup per head so that q/k/v can be sliced apart along dim 1
            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
198
+
199
+
200
def conv_attn_to_linear(checkpoint):
    """Squeeze 1x1-conv attention weights down to 2-D linear weights, in place."""
    qkv_suffixes = ("to_q.weight", "to_k.weight", "to_v.weight")
    proj_suffix = "to_out.0.weight"
    for name in list(checkpoint):
        parts = name.split(".")
        is_qkv = ".".join(parts[-2:]) in qkv_suffixes
        is_proj = ".".join(parts[-3:]) == proj_suffix
        if is_qkv or is_proj:
            tensor = checkpoint[name]
            if tensor.ndim > 2:
                checkpoint[name] = tensor.squeeze()
208
+
209
+
210
def create_unet_diffusers_config(original_config, image_size: int):
    """
    Creates a UNet config for diffusers based on the config of the original MusicLDM model.
    """
    unet_params = original_config["model"]["params"]["unet_config"]["params"]
    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]

    base_channels = unet_params["model_channels"]
    block_out_channels = [base_channels * factor for factor in unet_params["channel_mult"]]
    attention_resolutions = unet_params["attention_resolutions"]

    # Walk down the UNet: the resolution doubles after every level except the last,
    # and levels whose resolution appears in `attention_resolutions` get cross-attention.
    down_block_types = []
    resolution = 1
    for level in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            down_block_types.append("CrossAttnDownBlock2D")
        else:
            down_block_types.append("DownBlock2D")
        if level != len(block_out_channels) - 1:
            resolution *= 2

    # Walk back up, mirroring the resolutions visited on the way down.
    up_block_types = []
    for _ in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            up_block_types.append("CrossAttnUpBlock2D")
        else:
            up_block_types.append("UpBlock2D")
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

    if "cross_attention_dim" in unet_params:
        cross_attention_dim = unet_params["cross_attention_dim"]
    else:
        cross_attention_dim = block_out_channels

    # FiLM conditioning maps onto the diffusers "simple_projection" class embedding.
    has_film = "extra_film_condition_dim" in unet_params
    class_embed_type = "simple_projection" if has_film else None
    projection_class_embeddings_input_dim = unet_params["extra_film_condition_dim"] if has_film else None
    class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None

    return {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params["in_channels"],
        "out_channels": unet_params["out_channels"],
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params["num_res_blocks"],
        "cross_attention_dim": cross_attention_dim,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "class_embeddings_concat": class_embeddings_concat,
    }
260
+
261
+
262
# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config
def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
    """
    Creates a VAE config for diffusers based on the config of the original MusicLDM model. Compared to the original
    Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
    """
    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
    # embed_dim is read for parity with the original script but not used in the diffusers config
    _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]

    block_out_channels = [vae_params["ch"] * factor for factor in vae_params["ch_mult"]]
    num_blocks = len(block_out_channels)

    # checkpoints trained with scale_by_std carry their learnt scaling factor;
    # otherwise fall back to the standard Stable Diffusion value
    if "scale_by_std" in original_config["model"]["params"]:
        scaling_factor = checkpoint["scale_factor"]
    else:
        scaling_factor = 0.18215

    return {
        "sample_size": image_size,
        "in_channels": vae_params["in_channels"],
        "out_channels": vae_params["out_ch"],
        "down_block_types": ("DownEncoderBlock2D",) * num_blocks,
        "up_block_types": ("UpDecoderBlock2D",) * num_blocks,
        "block_out_channels": tuple(block_out_channels),
        "latent_channels": vae_params["z_channels"],
        "layers_per_block": vae_params["num_res_blocks"],
        "scaling_factor": float(scaling_factor),
    }
289
+
290
+
291
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
def create_diffusers_schedular(original_config):
    """Build a DDIM scheduler from the beta-schedule hyper-parameters of the original config.

    NOTE: the name keeps the upstream misspelling ("schedular") so the copy stays
    consistent with the function it was copied from.
    """
    schedular = DDIMScheduler(
        num_train_timesteps=original_config["model"]["params"]["timesteps"],
        beta_start=original_config["model"]["params"]["linear_start"],
        beta_end=original_config["model"]["params"]["linear_end"],
        beta_schedule="scaled_linear",
    )
    return schedular
300
+
301
+
302
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint. Compared to the original Stable Diffusion
    conversion, this function additionally converts the learnt film embedding linear layer.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."
    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                # EMA weights are stored flat: the dots of the original parameter name are removed
                # after the leading "model." segment, and prefixed with "model_ema.".
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # Time-embedding MLP: two linear layers map directly onto diffusers' names.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    # The learnt FiLM conditioning layer is mapped to diffusers' class embedding slot.
    new_checkpoint["class_embedding.weight"] = unet_state_dict["film_emb.weight"]
    new_checkpoint["class_embedding.bias"] = unet_state_dict["film_emb.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # Block 0 is the stem conv handled above, so the down blocks start at index 1.
    for i in range(1, num_input_blocks):
        # Each diffusers down block contains `layers_per_block` resnets plus one downsampler.
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # `op` is the original name of the strided downsampling convolution.
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # The middle block is always resnet -> attention -> resnet.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group the remaining key suffixes by their sub-layer index within this output block.
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # Sorting makes the upsampler signature ["conv.bias", "conv.weight"] deterministic.
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Block holds only resnet weights; rename them one by one.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
477
+
478
+
479
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config):
    """Rename the original LDM `first_stage_model.` VAE weights to the diffusers AutoencoderKL layout."""
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    # Stem / head layers map one-to-one; only the norm layer is renamed.
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid block: two resnets (1-indexed in the original) followed by one attention.
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        # The decoder's up blocks are stored in reverse order relative to diffusers.
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid block: same resnet/attention structure as the encoder's.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
585
+
586
+
587
# Mapping from original CLAP checkpoint key fragments to their transformers equivalents.
CLAP_KEYS_TO_MODIFY_MAPPING = {
    "text_branch": "text_model",
    "audio_branch": "audio_model.audio_encoder",
    "attn": "attention.self",
    "self.proj": "output.dense",
    "attention.self_mask": "attn_mask",
    "mlp.fc1": "intermediate.dense",
    "mlp.fc2": "output.dense",
    "norm1": "layernorm_before",
    "norm2": "layernorm_after",
    "bn0": "batch_norm",
}

# Key fragments with no counterpart in the transformers ClapModel; weights containing any of
# these substrings are dropped during conversion.
CLAP_KEYS_TO_IGNORE = [
    "text_transform",
    "audio_transform",
    "stft",
    "logmel_extractor",
    "tscam_conv",
    "head",
    "attn_mask",
]

# Keys that are legitimately absent from the original checkpoint when loading into ClapModel.
CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"]


def convert_open_clap_checkpoint(checkpoint):
    """
    Takes a state dict and returns a converted CLAP checkpoint.

    Only the `cond_stage_model.model.` weights are kept. Keys are renamed to the transformers
    `ClapModel` layout, keys listed in `CLAP_KEYS_TO_IGNORE` are discarded, and fused qkv
    projections in the audio tower are split into separate query/key/value tensors.
    """
    # extract state dict for CLAP text embedding model, discarding the audio component
    model_state_dict = {}
    model_key = "cond_stage_model.model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(model_key):
            model_state_dict[key.replace(model_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    sequential_layers_pattern = r".*sequential.(\d+).*"
    text_projection_pattern = r".*_projection.(\d+).*"

    for key, value in model_state_dict.items():
        # check if key should be ignored in mapping - if so map it to a key name that we'll filter out at the end
        for key_to_ignore in CLAP_KEYS_TO_IGNORE:
            if key_to_ignore in key:
                key = "spectrogram"

        # check if any key needs to be modified
        for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        if re.match(sequential_layers_pattern, key):
            # replace sequential layers with list (three original entries collapse to one layer)
            sequential_layer = re.match(sequential_layers_pattern, key).group(1)

            key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
        elif re.match(text_projection_pattern, key):
            projecton_layer = int(re.match(text_projection_pattern, key).group(1))

            # Because in CLAP they use `nn.Sequential`...
            transformers_projection_layer = 1 if projecton_layer == 0 else 2

            key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")

        # Bug fix: the original condition was `if "audio" and "qkv" in key`, where the non-empty
        # literal "audio" is always truthy, so the audio-tower check was a no-op and ANY qkv key
        # would have been split. Test both substrings explicitly so only audio qkv weights split.
        if "audio" in key and "qkv" in key:
            # split qkv into query key and value
            mixed_qkv = value
            qkv_dim = mixed_qkv.size(0) // 3

            query_layer = mixed_qkv[:qkv_dim]
            key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
            value_layer = mixed_qkv[qkv_dim * 2 :]

            new_checkpoint[key.replace("qkv", "query")] = query_layer
            new_checkpoint[key.replace("qkv", "key")] = key_layer
            new_checkpoint[key.replace("qkv", "value")] = value_layer
        elif key != "spectrogram":
            new_checkpoint[key] = value

    return new_checkpoint
670
+
671
+
672
def create_transformers_vocoder_config(original_config):
    """
    Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
    """
    params = original_config["model"]["params"]["vocoder_config"]["params"]

    # Nested dilation sizes are copied list-by-list so the result is independent of the input.
    dilation_sizes = [list(dilations) for dilations in params["resblock_dilation_sizes"]]

    return {
        "model_in_dim": params["num_mels"],
        "sampling_rate": params["sampling_rate"],
        "upsample_initial_channel": params["upsample_initial_channel"],
        "upsample_rates": list(params["upsample_rates"]),
        "upsample_kernel_sizes": list(params["upsample_kernel_sizes"]),
        "resblock_kernel_sizes": list(params["resblock_kernel_sizes"]),
        "resblock_dilation_sizes": dilation_sizes,
        "normalize_before": False,
    }
692
+
693
+
694
def convert_hifigan_checkpoint(checkpoint, config):
    """
    Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint.
    """
    # Keep only the vocoder weights, with the pipeline prefix stripped.
    prefix = "first_stage_model.vocoder."
    vocoder_state_dict = {
        name.replace(prefix, ""): checkpoint.get(name) for name in list(checkpoint.keys()) if name.startswith(prefix)
    }

    # Rename the upsampler convolutions; every other key already matches the transformers layout.
    for idx in range(len(config.upsample_rates)):
        vocoder_state_dict[f"upsampler.{idx}.weight"] = vocoder_state_dict.pop(f"ups.{idx}.weight")
        vocoder_state_dict[f"upsampler.{idx}.bias"] = vocoder_state_dict.pop(f"ups.{idx}.bias")

    if not config.normalize_before:
        # With normalize_before disabled these buffers are unused, so supply identity values.
        vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim)
        vocoder_state_dict["scale"] = torch.ones(config.model_in_dim)

    return vocoder_state_dict
717
+
718
+
719
# Adapted from https://huggingface.co/spaces/haoheliu/MusicLDM-text-to-audio-generation/blob/84a0384742a22bd80c44e903e241f0623e874f1d/MusicLDM/utils.py#L72-L73
# Fallback original config used when no `--original_config_file` is supplied. Mirrors the
# MusicLDM training YAML: diffusion schedule, UNet, VAE, and HiFi-GAN vocoder hyper-parameters.
DEFAULT_CONFIG = {
    "model": {
        "params": {
            # Noise-schedule endpoints and number of training timesteps (consumed by the DDIM scheduler).
            "linear_start": 0.0015,
            "linear_end": 0.0195,
            "timesteps": 1000,
            "channels": 8,
            # When True, the VAE scaling factor is read from the checkpoint's `scale_factor` key.
            "scale_by_std": True,
            "unet_config": {
                "target": "MusicLDM.latent_diffusion.openaimodel.UNetModel",
                "params": {
                    # FiLM conditioning dimension (mapped to diffusers' class embedding).
                    "extra_film_condition_dim": 512,
                    "extra_film_use_concat": True,
                    "in_channels": 8,
                    "out_channels": 8,
                    "model_channels": 128,
                    "attention_resolutions": [8, 4, 2],
                    "num_res_blocks": 2,
                    "channel_mult": [1, 2, 3, 5],
                    "num_head_channels": 32,
                },
            },
            "first_stage_config": {
                "target": "MusicLDM.variational_autoencoder.autoencoder.AutoencoderKL",
                "params": {
                    "embed_dim": 8,
                    "ddconfig": {
                        "z_channels": 8,
                        "resolution": 256,
                        "in_channels": 1,
                        "out_ch": 1,
                        "ch": 128,
                        "ch_mult": [1, 2, 4],
                        "num_res_blocks": 2,
                    },
                },
            },
            "vocoder_config": {
                "target": "MusicLDM.first_stage_model.vocoder",
                "params": {
                    # HiFi-GAN upsampling stack (see create_transformers_vocoder_config).
                    "upsample_rates": [5, 4, 2, 2, 2],
                    "upsample_kernel_sizes": [16, 16, 8, 4, 4],
                    "upsample_initial_channel": 1024,
                    "resblock_kernel_sizes": [3, 7, 11],
                    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                    "num_mels": 64,
                    "sampling_rate": 16000,
                },
            },
        },
    },
}
772
+
773
+
774
def load_pipeline_from_original_MusicLDM_ckpt(
    checkpoint_path: str,
    original_config_file: str = None,
    image_size: int = 1024,
    prediction_type: str = None,
    extract_ema: bool = False,
    scheduler_type: str = "ddim",
    num_in_channels: int = None,
    model_channels: int = None,
    num_head_channels: int = None,
    device: str = None,
    from_safetensors: bool = False,
) -> MusicLDMPipeline:
    """
    Load an MusicLDM pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file.

    Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
    global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
    recommended that you override the default values and/or supply an `original_config_file` wherever possible.

    Args:
        checkpoint_path (`str`): Path to `.ckpt` file.
        original_config_file (`str`):
            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
            set to the MusicLDM-s-full-v2 config.
        image_size (`int`, *optional*, defaults to 1024):
            The image size that the model was trained on.
        prediction_type (`str`, *optional*):
            The prediction type that the model was trained on. If `None`, will be automatically
            inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
        num_in_channels (`int`, *optional*, defaults to None):
            The number of UNet input channels. If `None`, it will be automatically inferred from the config.
        model_channels (`int`, *optional*, defaults to None):
            The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override
            to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.
        num_head_channels (`int`, *optional*, defaults to None):
            The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override
            to 32 for the small and medium checkpoints, and 64 for the large.
        scheduler_type (`str`, *optional*, defaults to 'pndm'):
            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
            "ddim"]`.
        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
            inference. Non-EMA weights are usually better to continue fine-tuning.
        device (`str`, *optional*, defaults to `None`):
            The device to use. Pass `None` to determine automatically.
        from_safetensors (`str`, *optional*, defaults to `False`):
            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
    return: An MusicLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
    """
    # --- Load the raw state dict, either via safetensors or torch.load ---
    if from_safetensors:
        from safetensors import safe_open

        checkpoint = {}
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
    else:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            checkpoint = torch.load(checkpoint_path, map_location=device)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device)

    # Lightning checkpoints nest the weights under "state_dict".
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if original_config_file is None:
        original_config = DEFAULT_CONFIG
    else:
        original_config = yaml.safe_load(original_config_file)

    # --- Apply CLI overrides onto the original config before any conversion ---
    if num_in_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    if model_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels

    if num_head_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels

    # v-prediction checkpoints mark themselves with `parameterization: v` in the config.
    if (
        "parameterization" in original_config["model"]["params"]
        and original_config["model"]["params"]["parameterization"] == "v"
    ):
        if prediction_type is None:
            prediction_type = "v_prediction"
    else:
        if prediction_type is None:
            prediction_type = "epsilon"

    if image_size is None:
        image_size = 512

    num_train_timesteps = original_config["model"]["params"]["timesteps"]
    beta_start = original_config["model"]["params"]["linear_start"]
    beta_end = original_config["model"]["params"]["linear_end"]

    # A DDIM scheduler is always built first; other scheduler types are derived from its config.
    scheduler = DDIMScheduler(
        beta_end=beta_end,
        beta_schedule="scaled_linear",
        beta_start=beta_start,
        num_train_timesteps=num_train_timesteps,
        steps_offset=1,
        clip_sample=False,
        set_alpha_to_one=False,
        prediction_type=prediction_type,
    )
    # make sure scheduler works correctly with DDIM
    scheduler.register_to_config(clip_sample=False)

    if scheduler_type == "pndm":
        config = dict(scheduler.config)
        config["skip_prk_steps"] = True
        scheduler = PNDMScheduler.from_config(config)
    elif scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "heun":
        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
    elif scheduler_type == "ddim":
        scheduler = scheduler
    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

    # Convert the UNet2DModel
    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
    unet = UNet2DConditionModel(**unet_config)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
    )

    unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model
    vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the text model
    # MusicLDM uses the same tokenizer as the original CLAP model, but a slightly different configuration
    config = ClapConfig.from_pretrained("laion/clap-htsat-unfused")
    config.audio_config.update(
        {
            "patch_embeds_hidden_size": 128,
            "hidden_size": 1024,
            "depths": [2, 2, 12, 2],
        }
    )
    tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
    feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused")

    converted_text_model = convert_open_clap_checkpoint(checkpoint)
    text_model = ClapModel(config)

    missing_keys, unexpected_keys = text_model.load_state_dict(converted_text_model, strict=False)
    # we expect not to have token_type_ids in our original state dict so let's ignore them
    missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS))

    if len(unexpected_keys) > 0:
        raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}")

    if len(missing_keys) > 0:
        raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}")

    # Convert the vocoder model
    vocoder_config = create_transformers_vocoder_config(original_config)
    vocoder_config = SpeechT5HifiGanConfig(**vocoder_config)
    converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config)

    vocoder = SpeechT5HifiGan(vocoder_config)
    vocoder.load_state_dict(converted_vocoder_checkpoint)

    # Instantiate the diffusers pipeline
    pipe = MusicLDMPipeline(
        vae=vae,
        text_encoder=text_model,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        vocoder=vocoder,
        feature_extractor=feature_extractor,
    )

    return pipe
968
+
969
+
970
+ if __name__ == "__main__":
971
+ parser = argparse.ArgumentParser()
972
+
973
+ parser.add_argument(
974
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
975
+ )
976
+ parser.add_argument(
977
+ "--original_config_file",
978
+ default=None,
979
+ type=str,
980
+ help="The YAML config file corresponding to the original architecture.",
981
+ )
982
+ parser.add_argument(
983
+ "--num_in_channels",
984
+ default=None,
985
+ type=int,
986
+ help="The number of input channels. If `None` number of input channels will be automatically inferred.",
987
+ )
988
+ parser.add_argument(
989
+ "--model_channels",
990
+ default=None,
991
+ type=int,
992
+ help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
993
+ " to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
994
+ )
995
+ parser.add_argument(
996
+ "--num_head_channels",
997
+ default=None,
998
+ type=int,
999
+ help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
1000
+ " to 32 for the small and medium checkpoints, and 64 for the large.",
1001
+ )
1002
+ parser.add_argument(
1003
+ "--scheduler_type",
1004
+ default="ddim",
1005
+ type=str,
1006
+ help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
1007
+ )
1008
+ parser.add_argument(
1009
+ "--image_size",
1010
+ default=None,
1011
+ type=int,
1012
+ help=("The image size that the model was trained on."),
1013
+ )
1014
+ parser.add_argument(
1015
+ "--prediction_type",
1016
+ default=None,
1017
+ type=str,
1018
+ help=("The prediction type that the model was trained on."),
1019
+ )
1020
+ parser.add_argument(
1021
+ "--extract_ema",
1022
+ action="store_true",
1023
+ help=(
1024
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
1025
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
1026
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
1027
+ ),
1028
+ )
1029
+ parser.add_argument(
1030
+ "--from_safetensors",
1031
+ action="store_true",
1032
+ help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
1033
+ )
1034
+ parser.add_argument(
1035
+ "--to_safetensors",
1036
+ action="store_true",
1037
+ help="Whether to store pipeline in safetensors format or not.",
1038
+ )
1039
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
1040
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
1041
+ args = parser.parse_args()
1042
+
1043
+ pipe = load_pipeline_from_original_MusicLDM_ckpt(
1044
+ checkpoint_path=args.checkpoint_path,
1045
+ original_config_file=args.original_config_file,
1046
+ image_size=args.image_size,
1047
+ prediction_type=args.prediction_type,
1048
+ extract_ema=args.extract_ema,
1049
+ scheduler_type=args.scheduler_type,
1050
+ num_in_channels=args.num_in_channels,
1051
+ model_channels=args.model_channels,
1052
+ num_head_channels=args.num_head_channels,
1053
+ from_safetensors=args.from_safetensors,
1054
+ device=args.device,
1055
+ )
1056
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
diffusers/scripts/convert_pixart_sigma_to_diffusers.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from transformers import T5EncoderModel, T5Tokenizer
6
+
7
+ from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtSigmaPipeline, Transformer2DModel
8
+
9
+
10
# Hub organization that hosts the reference PixArt-Sigma VAE/T5 components.
ckpt_id = "PixArt-alpha"
# Positional-embedding interpolation scale per training resolution, mirroring
# https://github.com/PixArt-alpha/PixArt-sigma/blob/dd087141864e30ec44f12cb7448dd654be065e88/scripts/inference.py#L158
interpolation_scale = {256: 0.5, 512: 1, 1024: 2, 2048: 4}
13
+
14
+
15
def main(args):
    """Convert an original PixArt-Sigma checkpoint into diffusers format.

    Remaps the original state-dict keys onto ``Transformer2DModel`` names and
    saves either the bare transformer (``--only_transformer``, the default) or
    a full ``PixArtSigmaPipeline`` with the reference VAE/T5 components.

    Args:
        args: parsed CLI namespace (``orig_ckpt_path``, ``image_size``,
            ``micro_condition``, ``qk_norm``, ``only_transformer``, ``dump_path``).
    """
    all_state_dict = torch.load(args.orig_ckpt_path)
    state_dict = all_state_dict.pop("state_dict")
    converted_state_dict = {}

    # Patch embeddings.
    converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
    converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

    # Caption projection.
    converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
    converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
    converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
    converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

    # AdaLN-single LN
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
        "t_embedder.mlp.0.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
        "t_embedder.mlp.2.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")

    if args.micro_condition:
        # Resolution conditioning (only present in PixArtMS-style checkpoints).
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop(
            "csize_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop(
            "csize_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop(
            "csize_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop(
            "csize_embedder.mlp.2.bias"
        )
        # Aspect ratio conditioning.
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop(
            "ar_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop(
            "ar_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop(
            "ar_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop(
            "ar_embedder.mlp.2.bias"
        )
    # Shared norm.
    converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
    converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")

    for depth in range(28):
        # Transformer blocks.
        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
            f"blocks.{depth}.scale_shift_table"
        )
        # Attention is all you need 🤘

        # Self attention: the original fuses q/k/v into one projection.
        q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
        # Projection.
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.bias"
        )
        if args.qk_norm:
            converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.q_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.bias"] = state_dict.pop(
                f"blocks.{depth}.attn.q_norm.bias"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.k_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.bias"] = state_dict.pop(
                f"blocks.{depth}.attn.k_norm.bias"
            )

        # Feed-forward.
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.bias"
        )

        # Cross-attention: q is separate, k/v share one fused projection.
        q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
        q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
        k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.bias"
        )

    # Final block.
    converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
    converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")

    # PixArt XL/2 architecture (28 layers, 16 heads x 72 dims).
    transformer = Transformer2DModel(
        sample_size=args.image_size // 8,
        num_layers=28,
        attention_head_dim=72,
        in_channels=4,
        out_channels=8,
        patch_size=2,
        attention_bias=True,
        num_attention_heads=16,
        cross_attention_dim=1152,
        activation_fn="gelu-approximate",
        num_embeds_ada_norm=1000,
        norm_type="ada_norm_single",
        norm_elementwise_affine=False,
        norm_eps=1e-6,
        caption_channels=4096,
        interpolation_scale=interpolation_scale[args.image_size],
        use_additional_conditions=args.micro_condition,
    )
    transformer.load_state_dict(converted_state_dict, strict=True)

    assert transformer.pos_embed.pos_embed is not None
    # Keys with no diffusers counterpart. Pop each one individually so that a
    # missing key does not prevent removal of the others (the original
    # try/except stopped at the first KeyError, leaving later keys behind and
    # tripping the emptiness assert below).
    for unused_key in ("y_embedder.y_embedding", "pos_embed"):
        if state_dict.pop(unused_key, None) is None:
            print(f"Skipping missing key {unused_key!r}")
    assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    if args.only_transformer:
        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
    else:
        # pixart-Sigma vae link: https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae
        vae = AutoencoderKL.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="vae")

        scheduler = DPMSolverMultistepScheduler()

        tokenizer = T5Tokenizer.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(
            f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="text_encoder"
        )

        pipeline = PixArtSigmaPipeline(
            tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
        )

        pipeline.save_pretrained(args.dump_path)
199
+
200
+
201
+ if __name__ == "__main__":
202
+ parser = argparse.ArgumentParser()
203
+
204
+ parser.add_argument(
205
+ "--micro_condition", action="store_true", help="If use Micro-condition in PixArtMS structure during training."
206
+ )
207
+ parser.add_argument("--qk_norm", action="store_true", help="If use qk norm during training.")
208
+ parser.add_argument(
209
+ "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
210
+ )
211
+ parser.add_argument(
212
+ "--image_size",
213
+ default=1024,
214
+ type=int,
215
+ choices=[256, 512, 1024, 2048],
216
+ required=False,
217
+ help="Image size of pretrained model, 256, 512, 1024, or 2048.",
218
+ )
219
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
220
+ parser.add_argument("--only_transformer", default=True, type=bool, required=True)
221
+
222
+ args = parser.parse_args()
223
+ main(args)
diffusers/scripts/convert_sana_to_diffusers.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import os
6
+ from contextlib import nullcontext
7
+
8
+ import torch
9
+ from accelerate import init_empty_weights
10
+ from huggingface_hub import hf_hub_download, snapshot_download
11
+ from termcolor import colored
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+
14
+ from diffusers import (
15
+ AutoencoderDC,
16
+ DPMSolverMultistepScheduler,
17
+ FlowMatchEulerDiscreteScheduler,
18
+ SanaPipeline,
19
+ SanaSprintPipeline,
20
+ SanaTransformer2DModel,
21
+ SCMScheduler,
22
+ )
23
+ from diffusers.models.modeling_utils import load_model_dict_into_meta
24
+ from diffusers.utils.import_utils import is_accelerate_available
25
+
26
+
27
# `init_empty_weights` lets us instantiate the transformer on the meta device
# (no real allocations); fall back to a no-op context manager when accelerate
# is not installed.
# NOTE: the availability check must be *called* — the bare function object is
# always truthy, which would wrongly select `init_empty_weights` even when
# accelerate is missing.
CTX = init_empty_weights if is_accelerate_available() else nullcontext

# Known Hub checkpoint paths ("org/repo/path/inside/repo"); the first entry is
# the default when no --orig_ckpt_path is given.
# NOTE: the first two entries were missing trailing commas, so Python's
# implicit string concatenation silently fused three paths into one broken
# entry (and made the default checkpoint unusable).
ckpt_ids = [
    "Efficient-Large-Model/Sana_Sprint_0.6B_1024px/checkpoints/Sana_Sprint_0.6B_1024px.pth",
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth",
    "Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
    "Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth",
    "Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
    "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
    "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
    "Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
    "Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
    "Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
    "Efficient-Large-Model/Sana_1600M_512px/checkpoints/Sana_1600M_512px.pth",
    "Efficient-Large-Model/Sana_600M_1024px/checkpoints/Sana_600M_1024px_MultiLing.pth",
    "Efficient-Large-Model/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth",
]
# https://github.com/NVlabs/Sana/blob/main/scripts/inference.py
45
+
46
+
47
def main(args):
    """Convert an original Sana / Sana-Sprint checkpoint into diffusers format.

    Resolves the checkpoint (local path or Hub download), remaps the original
    state-dict keys onto ``SanaTransformer2DModel`` names, then saves either
    the bare transformer or a full ``SanaPipeline``/``SanaSprintPipeline``
    depending on ``--save_full_pipeline``.

    Relies on the module-level globals ``model_kwargs`` and ``weight_dtype``
    populated in the ``__main__`` section.
    """
    cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")

    if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
        # Not a local file: fetch the (default) checkpoint from the Hub.
        ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
        snapshot_download(
            repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
            cache_dir=cache_dir_path,
            repo_type="model",
        )
        file_path = hf_hub_download(
            repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
            filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
            cache_dir=cache_dir_path,
            repo_type="model",
        )
    else:
        file_path = args.orig_ckpt_path

    print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
    all_state_dict = torch.load(file_path, weights_only=True)
    state_dict = all_state_dict.pop("state_dict")
    converted_state_dict = {}

    # Patch embeddings.
    converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
    converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

    # Caption projection.
    converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
    converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
    converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
    converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

    # Handle different time embedding structure based on model type.
    if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
        # Sana Sprint: flat time embedder plus a guidance (cfg) embedder.
        converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
            "t_embedder.mlp.0.weight"
        )
        converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
        converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop(
            "t_embedder.mlp.2.weight"
        )
        converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")

        # Guidance embedder for Sana Sprint.
        converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop(
            "cfg_embedder.mlp.0.weight"
        )
        converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias")
        converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop(
            "cfg_embedder.mlp.2.weight"
        )
        converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias")
    else:
        # Original Sana: time embedder nested one level deeper (`.emb.`).
        converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
            "t_embedder.mlp.0.weight"
        )
        converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop(
            "t_embedder.mlp.0.bias"
        )
        converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
            "t_embedder.mlp.2.weight"
        )
        converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop(
            "t_embedder.mlp.2.bias"
        )

    # Shared norm.
    converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
    converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")

    # y norm
    converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")

    # Scheduler flow shift: the 4K model uses a larger shift.
    if args.image_size == 4096:
        flow_shift = 6.0
    else:
        flow_shift = 3.0

    # Model depth per variant.
    if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]:
        layer_num = 20
    elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]:
        layer_num = 28
    elif args.model_type in ["SanaMS_4800M_P1_D60", "SanaMS1.5_4800M_P1_D60"]:
        # The CLI exposes the 4.8B variant as "SanaMS1.5_4800M_P1_D60"; the
        # original equality check against "SanaMS_4800M_P1_D60" alone never
        # matched it, so the 4.8B model fell through to the ValueError below.
        layer_num = 60
    else:
        raise ValueError(f"{args.model_type} is not supported.")
    # Positional embedding interpolation scale.
    interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
    # Sana-1.5 and Sana-Sprint checkpoints carry Q/K RMS-norm weights.
    qk_norm = (
        "rms_norm_across_heads"
        if args.model_type
        in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"]
        else None
    )

    for depth in range(layer_num):
        # Transformer blocks.
        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
            f"blocks.{depth}.scale_shift_table"
        )

        # Linear Attention is all you need 🤘
        # Self attention (fused qkv, no bias).
        q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        if qk_norm is not None:
            # Q/K normalization for self-attention (attn1) - Sana-Sprint and Sana-1.5.
            converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.q_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.k_norm.weight"
            )
        # Projection.
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.bias"
        )

        # Feed-forward (GLU-style mix of inverted/depthwise/pointwise convs).
        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.inverted_conv.conv.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.inverted_conv.conv.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.depth_conv.conv.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.depth_conv.conv.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.point_conv.conv.weight"
        )

        # Cross-attention: q is separate, k/v share one fused projection.
        q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
        q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
        k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
        if qk_norm is not None:
            # Q/K normalization for cross-attention (attn2) - Sana-Sprint and Sana-1.5.
            converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
                f"blocks.{depth}.cross_attn.q_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
                f"blocks.{depth}.cross_attn.k_norm.weight"
            )

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.bias"
        )

    # Final block.
    converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
    converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")

    # Transformer: build on the meta device when accelerate is available.
    with CTX():
        transformer_kwargs = {
            "in_channels": 32,
            "out_channels": 32,
            "num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"],
            "attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"],
            "num_layers": model_kwargs[args.model_type]["num_layers"],
            "num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"],
            "cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"],
            "cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"],
            "caption_channels": 2304,
            "mlp_ratio": 2.5,
            "attention_bias": False,
            "sample_size": args.image_size // 32,
            "patch_size": 1,
            "norm_elementwise_affine": False,
            "norm_eps": 1e-6,
            "interpolation_scale": interpolation_scale[args.image_size],
        }

        # qk_norm / guidance embeds for Sana-1.5 and Sana Sprint variants.
        if args.model_type in [
            "SanaMS1.5_1600M_P1_D20",
            "SanaMS1.5_4800M_P1_D60",
            "SanaSprint_600M_P1_D28",
            "SanaSprint_1600M_P1_D20",
        ]:
            transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
        if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
            transformer_kwargs["guidance_embeds"] = True

        transformer = SanaTransformer2DModel(**transformer_kwargs)

    if is_accelerate_available():
        load_model_dict_into_meta(transformer, converted_state_dict)
    else:
        transformer.load_state_dict(converted_state_dict, strict=True, assign=True)

    # Keys with no diffusers counterpart. Pop each one individually so that a
    # missing key does not prevent removal of the others (the original
    # try/except aborted at the first KeyError, leaving later keys behind and
    # tripping the emptiness assert below).
    for unused_key in ("y_embedder.y_embedding", "pos_embed", "logvar_linear.weight", "logvar_linear.bias"):
        if state_dict.pop(unused_key, None) is None:
            print(f"{unused_key} not found in the state_dict")

    assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    transformer = transformer.to(weight_dtype)

    if not args.save_full_pipeline:
        print(
            colored(
                f"Only saving transformer model of {args.model_type}. "
                f"Set --save_full_pipeline to save the whole Pipeline",
                "green",
                attrs=["bold"],
            )
        )
        transformer.save_pretrained(
            os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
        )
    else:
        print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
        # VAE
        ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)

        # Text Encoder
        text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
        tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
        tokenizer.padding_side = "right"
        text_encoder = AutoModelForCausalLM.from_pretrained(
            text_encoder_model_path, torch_dtype=torch.bfloat16
        ).get_decoder()

        # Choose the appropriate pipeline and scheduler based on model type.
        if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
            # Force SCM Scheduler for Sana Sprint regardless of scheduler_type.
            if args.scheduler_type != "scm":
                print(
                    colored(
                        f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model",
                        "yellow",
                        attrs=["bold"],
                    )
                )

            # SCM Scheduler for Sana Sprint
            scheduler_config = {
                "prediction_type": "trigflow",
                "sigma_data": 0.5,
            }
            scheduler = SCMScheduler(**scheduler_config)
            pipe = SanaSprintPipeline(
                tokenizer=tokenizer,
                text_encoder=text_encoder,
                transformer=transformer,
                vae=ae,
                scheduler=scheduler,
            )
        else:
            # Original Sana scheduler
            if args.scheduler_type == "flow-dpm_solver":
                scheduler = DPMSolverMultistepScheduler(
                    flow_shift=flow_shift,
                    use_flow_sigmas=True,
                    prediction_type="flow_prediction",
                )
            elif args.scheduler_type == "flow-euler":
                scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
            else:
                raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")

            pipe = SanaPipeline(
                tokenizer=tokenizer,
                text_encoder=text_encoder,
                transformer=transformer,
                vae=ae,
                scheduler=scheduler,
            )

        pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
353
+
354
+
355
# CLI --dtype flag -> torch dtype used for the exported transformer weights.
DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--image_size",
        default=1024,
        type=int,
        choices=[512, 1024, 2048, 4096],
        required=False,
        help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
    )
    parser.add_argument(
        "--model_type",
        default="SanaMS_1600M_P1_D20",
        type=str,
        choices=[
            "SanaMS_1600M_P1_D20",
            "SanaMS_600M_P1_D28",
            "SanaMS1.5_1600M_P1_D20",
            "SanaMS1.5_4800M_P1_D60",
            "SanaSprint_1600M_P1_D20",
            "SanaSprint_600M_P1_D28",
        ],
    )
    parser.add_argument(
        "--scheduler_type",
        default="flow-dpm_solver",
        type=str,
        choices=["flow-dpm_solver", "flow-euler", "scm"],
        help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
    parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
    parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

    args = parser.parse_args()

    # Per-variant transformer hyper-parameters; `main` reads this module-level
    # global (as well as `weight_dtype` below).
    model_kwargs = {
        "SanaMS_1600M_P1_D20": {
            "num_attention_heads": 70,
            "attention_head_dim": 32,
            "num_cross_attention_heads": 20,
            "cross_attention_head_dim": 112,
            "cross_attention_dim": 2240,
            "num_layers": 20,
        },
        "SanaMS_600M_P1_D28": {
            "num_attention_heads": 36,
            "attention_head_dim": 32,
            "num_cross_attention_heads": 16,
            "cross_attention_head_dim": 72,
            "cross_attention_dim": 1152,
            "num_layers": 28,
        },
        "SanaMS1.5_1600M_P1_D20": {
            "num_attention_heads": 70,
            "attention_head_dim": 32,
            "num_cross_attention_heads": 20,
            "cross_attention_head_dim": 112,
            "cross_attention_dim": 2240,
            "num_layers": 20,
        },
        "SanaMS1.5_4800M_P1_D60": {
            "num_attention_heads": 70,
            "attention_head_dim": 32,
            "num_cross_attention_heads": 20,
            "cross_attention_head_dim": 112,
            "cross_attention_dim": 2240,
            "num_layers": 60,
        },
        "SanaSprint_600M_P1_D28": {
            "num_attention_heads": 36,
            "attention_head_dim": 32,
            "num_cross_attention_heads": 16,
            "cross_attention_head_dim": 72,
            "cross_attention_dim": 1152,
            "num_layers": 28,
        },
        "SanaSprint_1600M_P1_D20": {
            "num_attention_heads": 70,
            "attention_head_dim": 32,
            "num_cross_attention_heads": 20,
            "cross_attention_head_dim": 112,
            "cross_attention_dim": 2240,
            "num_layers": 20,
        },
    }

    # NOTE: the original also computed `device = "cuda" if ...`, but nothing
    # ever read it — removed as dead code.
    weight_dtype = DTYPE_MAPPING[args.dtype]

    main(args)
diffusers/scripts/convert_stable_cascade.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
from contextlib import nullcontext

import torch
from safetensors.torch import load_file
from transformers import (
    AutoTokenizer,
    CLIPConfig,
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    DDPMWuerstchenScheduler,
    StableCascadeCombinedPipeline,
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available


if is_accelerate_available():
    from accelerate import init_empty_weights

parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
    "--prior_output_path", default="stable-cascade-prior", type=str, help="Hub organization to save the pipelines to"
)
parser.add_argument(
    "--decoder_output_path",
    type=str,
    default="stable-cascade-decoder",
    help="Hub organization to save the pipelines to",
)
parser.add_argument(
    "--combined_output_path",
    type=str,
    default="stable-cascade-combined",
    help="Hub organization to save the pipelines to",
)
parser.add_argument("--save_combined", action="store_true")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")

args = parser.parse_args()

# Validate the stage-selection flags before any heavy downloading/loading.
if args.skip_stage_b and args.skip_stage_c:
    raise ValueError("At least one stage should be converted")
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
    raise ValueError("Cannot skip stages when creating a combined pipeline")

model_path = args.model_path

device = "cpu"
dtype = torch.bfloat16 if args.variant == "bf16" else torch.float32

# Checkpoint locations for the two stages.
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"


def _read_checkpoint(path):
    # Load a raw state dict from disk, honoring the --use_safetensors flag.
    if args.use_safetensors:
        return load_file(path, device=device)
    return torch.load(path, map_location=device)


def _populate_weights(model, state_dict):
    # Copy converted weights into the model; use the meta-device loader when
    # accelerate is installed (the model was then built under init_empty_weights).
    if is_accelerate_available():
        load_model_dict_into_meta(model, state_dict)
    else:
        model.load_state_dict(state_dict)


# CLIP text encoder and tokenizer, shared by the prior and decoder pipelines.
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# Image processor and encoder used by the prior for image conditioning.
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# Scheduler shared by prior and decoder.
scheduler = DDPMWuerstchenScheduler()
ctx = init_empty_weights if is_accelerate_available() else nullcontext

if not args.skip_stage_c:
    # Stage C (prior).
    prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(
        _read_checkpoint(prior_checkpoint_path)
    )

    with ctx():
        prior_model = StableCascadeUNet(
            in_channels=16,
            out_channels=16,
            timestep_ratio_embedding_dim=64,
            patch_size=1,
            conditioning_dim=2048,
            block_out_channels=[2048, 2048],
            num_attention_heads=[32, 32],
            down_num_layers_per_block=[8, 24],
            up_num_layers_per_block=[24, 8],
            down_blocks_repeat_mappers=[1, 1],
            up_blocks_repeat_mappers=[1, 1],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_in_channels=1280,
            clip_text_pooled_in_channels=1280,
            clip_image_in_channels=768,
            clip_seq=4,
            kernel_size=3,
            dropout=[0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca", "crp"],
            switch_level=[False],
        )
    _populate_weights(prior_model, prior_state_dict)

    # Prior pipeline
    prior_pipeline = StableCascadePriorPipeline(
        prior=prior_model,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
    )
    prior_pipeline.to(dtype).save_pretrained(
        args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )

if not args.skip_stage_b:
    # Stage B (decoder).
    decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(
        _read_checkpoint(decoder_checkpoint_path)
    )
    with ctx():
        decoder = StableCascadeUNet(
            in_channels=4,
            out_channels=4,
            timestep_ratio_embedding_dim=64,
            patch_size=2,
            conditioning_dim=1280,
            block_out_channels=[320, 640, 1280, 1280],
            down_num_layers_per_block=[2, 6, 28, 6],
            up_num_layers_per_block=[6, 28, 6, 2],
            down_blocks_repeat_mappers=[1, 1, 1, 1],
            up_blocks_repeat_mappers=[3, 3, 2, 2],
            num_attention_heads=[0, 0, 20, 20],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_pooled_in_channels=1280,
            clip_seq=4,
            effnet_in_channels=16,
            pixel_mapper_in_channels=3,
            kernel_size=3,
            dropout=[0, 0, 0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca"],
        )
    _populate_weights(decoder, decoder_state_dict)

    # VQGAN reused from Wuerstchen-V2.
    vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

    # Decoder pipeline
    decoder_pipeline = StableCascadeDecoderPipeline(
        decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
    )
    decoder_pipeline.to(dtype).save_pretrained(
        args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )

if args.save_combined:
    # Stable Cascade combined pipeline (requires both stages converted above;
    # guaranteed by the flag validation at the top of the script).
    stable_cascade_pipeline = StableCascadeCombinedPipeline(
        # Decoder
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        decoder=decoder,
        scheduler=scheduler,
        vqgan=vqmodel,
        # Prior
        prior_text_encoder=text_encoder,
        prior_tokenizer=tokenizer,
        prior_prior=prior_model,
        prior_scheduler=scheduler,
        prior_image_encoder=image_encoder,
        prior_feature_extractor=feature_extractor,
    )
    stable_cascade_pipeline.to(dtype).save_pretrained(
        args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )
diffusers/scripts/convert_vae_pt_to_diffusers.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import io
3
+
4
+ import requests
5
+ import torch
6
+ import yaml
7
+
8
+ from diffusers import AutoencoderKL
9
+ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
10
+ assign_to_checkpoint,
11
+ conv_attn_to_linear,
12
+ create_vae_diffusers_config,
13
+ renew_vae_attention_paths,
14
+ renew_vae_resnet_paths,
15
+ )
16
+ from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
17
+
18
+
19
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
    """Remap an LDM-style VAE state dict to diffusers ``AutoencoderKL`` key names.

    Args:
        checkpoint: Original VAE state dict (LDM ``encoder.down.*`` /
            ``decoder.up.*`` naming). It is mutated in place: downsampler keys
            are popped, and ``assign_to_checkpoint`` consumes further entries.
        config: diffusers VAE config dict (as produced by
            ``create_vae_diffusers_config``), forwarded to ``assign_to_checkpoint``.

    Returns:
        A new state dict whose keys follow the diffusers layout
        (``down_blocks.*``, ``up_blocks.*``, ``mid_block.*``).
    """
    vae_state_dict = checkpoint

    new_checkpoint = {}

    # Top-level encoder convolutions map 1:1; only norm_out is renamed.
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    # Same direct mapping for the decoder side.
    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        # Resnet keys for this block: everything except downsamplers and attention.
        resnets = [
            key
            for key in down_blocks[i]
            if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key
        ]
        attentions = [key for key in down_blocks[i] in key] if False else [key for key in down_blocks[i] if f"down.{i}.attn" in key]

        # Downsampler keys are popped so assign_to_checkpoint won't see them again.
        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

        paths = renew_vae_attention_paths(attentions)
        meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid block: two resnets (block_1 / block_2 in the original naming).
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    # Mid-block attention projections are stored as 1x1 convs; flatten to linear.
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        # Decoder up blocks are stored in reverse order relative to diffusers.
        block_id = num_up_blocks - 1 - i
        resnets = [
            key
            for key in up_blocks[block_id]
            if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
        ]
        attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

        paths = renew_vae_attention_paths(attentions)
        meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid block: same structure and renaming as the encoder mid block.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
134
+
135
+
136
def vae_pt_to_vae_diffuser(
    checkpoint_path: str,
    output_path: str,
):
    """Convert a Stable Diffusion v1 VAE checkpoint to diffusers format.

    Downloads the reference v1 inference YAML config, remaps the checkpoint
    keys with ``custom_convert_ldm_vae_checkpoint``, and saves the resulting
    ``AutoencoderKL`` to ``output_path``.

    Args:
        checkpoint_path: Path to a ``.pt``/``.ckpt`` (expects a ``state_dict``
            entry) or a ``.safetensors`` VAE checkpoint.
        output_path: Directory where the converted diffusers VAE is saved.
    """
    # Only support V1
    # Fix: the URL previously had a stray leading space and only worked
    # because requests strips leading whitespace from URLs.
    r = requests.get(
        "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
        timeout=DIFFUSERS_REQUEST_TIMEOUT,
    )
    io_obj = io.BytesIO(r.content)

    original_config = yaml.safe_load(io_obj)
    image_size = 512
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if checkpoint_path.endswith("safetensors"):
        from safetensors import safe_open

        # safetensors files are read tensor-by-tensor into a plain dict.
        checkpoint = {}
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
    converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    vae.save_pretrained(output_path)
167
+
168
+
169
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.")
    # Fix: help text was copy-pasted from --vae_pt_path; this flag is the output location.
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to save the converted VAE to.")

    args = parser.parse_args()

    vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path)
diffusers/scripts/convert_wuerstchen.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/
import os

import torch
from transformers import AutoTokenizer, CLIPTextModel
from vqgan import VQModel

from diffusers import (
    DDPMWuerstchenScheduler,
    WuerstchenCombinedPipeline,
    WuerstchenDecoderPipeline,
    WuerstchenPriorPipeline,
)
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior


def _convert_attn_keys(orig_state_dict):
    """Remap fused ``nn.MultiheadAttention`` projections to diffusers attention keys.

    The original checkpoints fuse q/k/v into ``attn.in_proj_weight``/``bias``;
    diffusers expects separate ``to_q``/``to_k``/``to_v`` tensors and
    ``to_out.0`` for the output projection. All other keys pass through
    unchanged. (Previously this loop was duplicated for stage B and stage C.)
    """
    state_dict = {}
    for key in orig_state_dict.keys():
        if key.endswith("in_proj_weight"):
            weights = orig_state_dict[key].chunk(3, 0)
            state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
            state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
            state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
        elif key.endswith("in_proj_bias"):
            weights = orig_state_dict[key].chunk(3, 0)
            state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
            state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
            state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
        elif key.endswith("out_proj.weight"):
            state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = orig_state_dict[key]
        elif key.endswith("out_proj.bias"):
            state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = orig_state_dict[key]
        else:
            state_dict[key] = orig_state_dict[key]
    return state_dict


model_path = "models/"
device = "cpu"

# VQGAN: load the original Paella VQModel first (original key names), then
# rename the codebook entry for the diffusers PaellaVQModel.
paella_vqmodel = VQModel()
state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"]
paella_vqmodel.load_state_dict(state_dict)

state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"]
state_dict.pop("vquantizer.codebook.weight")
vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent)
vqmodel.load_state_dict(state_dict)

# Clip Text encoder and tokenizer (prior).
text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# Generator (decoder) text encoder and tokenizer.
gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu")
gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

# Decoder (stage B).
orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"]
decoder = WuerstchenDiffNeXt()
decoder.load_state_dict(_convert_attn_keys(orig_state_dict))

# Prior (stage C) — note: the EMA weights are used for this stage.
orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"]
prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device)
prior_model.load_state_dict(_convert_attn_keys(orig_state_dict))

# scheduler
scheduler = DDPMWuerstchenScheduler()

# Prior pipeline
prior_pipeline = WuerstchenPriorPipeline(
    prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)

prior_pipeline.save_pretrained("warp-ai/wuerstchen-prior")

decoder_pipeline = WuerstchenDecoderPipeline(
    text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=decoder, scheduler=scheduler
)
decoder_pipeline.save_pretrained("warp-ai/wuerstchen")

# Wuerstchen pipeline
wuerstchen_pipeline = WuerstchenCombinedPipeline(
    # Decoder
    text_encoder=gen_text_encoder,
    tokenizer=gen_tokenizer,
    decoder=decoder,
    scheduler=scheduler,
    vqgan=vqmodel,
    # Prior
    prior_tokenizer=tokenizer,
    prior_text_encoder=text_encoder,
    prior=prior_model,
    prior_scheduler=scheduler,
)
wuerstchen_pipeline.save_pretrained("warp-ai/WuerstchenCombinedPipeline")
illustrious_generated/low_quality_images.json ADDED
The diff for this file is too large to render. See raw diff
 
illustrious_generated/natural_caption_generation_report.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ === Natural Caption Generation Report ===
3
+
4
+ Processing Statistics:
5
+ - Total images processed: 9618
6
+ - Successfully captioned: 9618
7
+ - Errors encountered: 0
8
+ - Success rate: 100.0%
9
+
10
+ Time Statistics:
11
+ - Total processing time: 533.1 minutes
12
+ - Average time per image: 3.33 seconds
13
+
14
+ Completion time: 2025-07-29 20:42:34
illustrious_generated/optimization_final_results.json ADDED
The diff for this file is too large to render. See raw diff
 
illustrious_generated/optimization_summary_report.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ === 图像质量优化总结报告 ===
3
+
4
+ 处理统计:
5
+ - 总图像数: 9618
6
+ - 检测到低质量图像: 181
7
+ - 重新生成处理: 100
8
+ - 成功改善质量: 17
9
+ - 改善成功率: 17.0%
10
+
11
+ 质量提升:
12
+ - 平均质量提升: 2.3 分
13
+ - 改善图像保存位置: /home/ubuntu/lyl/QwenIllustrious/illustrious_generated/improved
14
+
15
+ 详细结果文件:
16
+ - 低质量图像记录: low_quality_images.json
17
+ - 重新生成结果: regeneration_results.json
18
+ - 最终优化结果: optimization_final_results.json
19
+
20
+ 优化完成时间: 2025-07-29 08:24:41
illustrious_generated/regeneration_results.json ADDED
The diff for this file is too large to render. See raw diff
 
peft/.gitignore ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # VSCode
132
+ .vscode
133
+
134
+ # IntelliJ
135
+ .idea
136
+
137
+ # Mac .DS_Store
138
+ .DS_Store
139
+
140
+ # More test things
141
+ wandb
142
+
143
+ # method_comparison logs
144
+ method_comparison/MetaMathQA/cancelled_results/
145
+ method_comparison/MetaMathQA/temporary_results/
peft/.pre-commit-config.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/astral-sh/ruff-pre-commit
3
+ rev: v0.9.2
4
+ hooks:
5
+ - id: ruff
6
+ args:
7
+ - --fix
8
+ - id: ruff-format
9
+ - repo: https://github.com/pre-commit/pre-commit-hooks
10
+ rev: v4.6.0
11
+ hooks:
12
+ - id: check-merge-conflict
13
+ - id: check-yaml
peft/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
peft/Makefile ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: quality style test docs
2
+
3
+ check_dirs := src tests examples docs scripts docker
4
+
5
+ # Check that source code meets quality standards
6
+
7
+ # this target runs checks on all files
8
+ quality:
9
+ ruff check $(check_dirs)
10
+ ruff format --check $(check_dirs)
11
+ doc-builder style src/peft tests docs/source --max_len 119 --check_only
12
+
13
+ # Format source code automatically and check if there are any problems left that need manual fixing
14
+ style:
15
+ ruff check --fix $(check_dirs)
16
+ ruff format $(check_dirs)
17
+ doc-builder style src/peft tests docs/source --max_len 119
18
+
19
+ test:
20
+ python -m pytest -n 3 tests/ $(if $(IS_GITHUB_CI),--report-log "ci_tests.log",)
21
+
22
+ tests_examples_multi_gpu:
23
+ python -m pytest -m multi_gpu_tests tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "multi_gpu_examples.log",)
24
+
25
+ tests_examples_single_gpu:
26
+ python -m pytest -m single_gpu_tests tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "single_gpu_examples.log",)
27
+
28
+ tests_core_multi_gpu:
29
+ python -m pytest -m multi_gpu_tests tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_multi_gpu.log",)
30
+
31
+ tests_core_single_gpu:
32
+ python -m pytest -m single_gpu_tests tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_single_gpu.log",)
33
+
34
+ # exclude gemma tests, as generation fails with torch.compile, these failures
35
+ # trigger side effects that make other tests fail with 'RuntimeError: Offset
36
+ # increment outside graph capture encountered unexpectedly.'
37
+ # TODO re-enable gemma once/if it is fixed
38
+ tests_common_gpu:
39
+ python -m pytest tests/test_decoder_models.py -k "not gemma" $(if $(IS_GITHUB_CI),--report-log "common_decoder.log",)
40
+ python -m pytest tests/test_encoder_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_encoder_decoder.log",)
41
+ python -m pytest tests/test_gptqmodel.py $(if $(IS_GITHUB_CI),--report-log "gptqmodel_gpu.log",)
42
+
43
+ tests_examples_multi_gpu_bnb:
44
+ python -m pytest -m "multi_gpu_tests and bitsandbytes" tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "multi_gpu_examples.log",)
45
+
46
+ tests_examples_single_gpu_bnb:
47
+ python -m pytest -m "single_gpu_tests and bitsandbytes" tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "single_gpu_examples.log",)
48
+
49
+ tests_core_multi_gpu_bnb:
50
+ python -m pytest -m "multi_gpu_tests and bitsandbytes" tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_multi_gpu.log",)
51
+
52
+ tests_core_single_gpu_bnb:
53
+ python -m pytest -m "single_gpu_tests and bitsandbytes" tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_single_gpu.log",)
54
+
55
+ tests_gpu_bnb_regression:
56
+ python -m pytest tests/bnb/test_bnb_regression.py $(if $(IS_GITHUB_CI),--report-log "bnb_regression_gpu.log",)
57
+
58
+ # For testing transformers tests for bnb runners
59
+ transformers_tests:
60
+ RUN_SLOW=1 python -m pytest transformers-clone/tests/quantization/bnb $(if $(IS_GITHUB_CI),--report-log "transformers_tests.log",)
61
+
62
+ tests_regression:
63
+ python -m pytest -s --regression tests/regression/ $(if $(IS_GITHUB_CI),--report-log "regression_tests.log",)
64
+
65
+ tests_torch_compile:
66
+ python -m pytest tests/test_torch_compile.py $(if $(IS_GITHUB_CI),--report-log "compile_tests.log",)
peft/README.md ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!---
2
+ Copyright 2023 The HuggingFace Team. All rights reserved.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ -->
16
+
17
+ <h1 align="center"> <p>🤗 PEFT</p></h1>
18
+ <h3 align="center">
19
+ <p>State-of-the-art Parameter-Efficient Fine-Tuning (PEFT) methods</p>
20
+ </h3>
21
+
22
+ Fine-tuning large pretrained models is often prohibitively costly due to their scale. Parameter-Efficient Fine-Tuning (PEFT) methods enable efficient adaptation of large pretrained models to various downstream applications by only fine-tuning a small number of (extra) model parameters instead of all the model's parameters. This significantly decreases the computational and storage costs. Recent state-of-the-art PEFT techniques achieve performance comparable to fully fine-tuned models.
23
+
24
+ PEFT is integrated with Transformers for easy model training and inference, Diffusers for conveniently managing different adapters, and Accelerate for distributed training and inference for really big models.
25
+
26
+ > [!TIP]
27
+ > Visit the [PEFT](https://huggingface.co/PEFT) organization to read about the PEFT methods implemented in the library and to see notebooks demonstrating how to apply these methods to a variety of downstream tasks. Click the "Watch repos" button on the organization page to be notified of newly implemented methods and notebooks!
28
+
29
+ Check the PEFT Adapters API Reference section for a list of supported PEFT methods, and read the [Adapters](https://huggingface.co/docs/peft/en/conceptual_guides/adapter), [Soft prompts](https://huggingface.co/docs/peft/en/conceptual_guides/prompting), and [IA3](https://huggingface.co/docs/peft/en/conceptual_guides/ia3) conceptual guides to learn more about how these methods work.
30
+
31
+ ## Quickstart
32
+
33
+ Install PEFT from pip:
34
+
35
+ ```bash
36
+ pip install peft
37
+ ```
38
+
39
+ Prepare a model for training with a PEFT method such as LoRA by wrapping the base model and PEFT configuration with `get_peft_model`. For the bigscience/mt0-large model, you're only training 0.19% of the parameters!
40
+
41
+ ```python
42
+ from transformers import AutoModelForCausalLM
43
+ from peft import LoraConfig, TaskType, get_peft_model
44
+
45
+ device = "cuda"
46
+ model_id = "Qwen/Qwen2.5-3B-Instruct"
47
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
48
+ peft_config = LoraConfig(
49
+ r=16,
50
+ lora_alpha=32,
51
+ task_type=TaskType.CAUSAL_LM,
52
+ # target_modules=["q_proj", "v_proj", ...] # optionally indicate target modules
53
+ )
54
+ model = get_peft_model(model, peft_config)
55
+ model.print_trainable_parameters()
56
+ # prints: trainable params: 3,686,400 || all params: 3,089,625,088 || trainable%: 0.1193
57
+
58
+ # now perform training on your dataset, e.g. using transformers Trainer, then save the model
59
+ model.save_pretrained("qwen2.5-3b-lora")
60
+ ```
61
+
62
+ To load a PEFT model for inference:
63
+
64
+ ```python
65
+ from transformers import AutoModelForCausalLM, AutoTokenizer
66
+ from peft import PeftModel
67
+
68
+ device = "cuda"
69
+ model_id = "Qwen/Qwen2.5-3B-Instruct"
70
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
71
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
72
+ model = PeftModel.from_pretrained(model, "qwen2.5-3b-lora")
73
+
74
+ inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors="pt")
75
+ outputs = model.generate(**inputs.to(device), max_new_tokens=50)
76
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
77
+
78
+ # prints something like: Preheat the oven to 350 degrees and place the cookie dough in a baking dish [...]
79
+ ```
80
+
81
+ ## Why you should use PEFT
82
+
83
+ There are many benefits of using PEFT but the main one is the huge savings in compute and storage, making PEFT applicable to many different use cases.
84
+
85
+ ### High performance on consumer hardware
86
+
87
+ Consider the memory requirements for training the following models on the [ought/raft/twitter_complaints](https://huggingface.co/datasets/ought/raft/viewer/twitter_complaints) dataset with an A100 80GB GPU with more than 64GB of CPU RAM.
88
+
89
+ | Model | Full Finetuning | PEFT-LoRA PyTorch | PEFT-LoRA DeepSpeed with CPU Offloading |
90
+ | --------- | ---- | ---- | ---- |
91
+ | bigscience/T0_3B (3B params) | 47.14GB GPU / 2.96GB CPU | 14.4GB GPU / 2.96GB CPU | 9.8GB GPU / 17.8GB CPU |
92
+ | bigscience/mt0-xxl (12B params) | OOM GPU | 56GB GPU / 3GB CPU | 22GB GPU / 52GB CPU |
93
+ | bigscience/bloomz-7b1 (7B params) | OOM GPU | 32GB GPU / 3.8GB CPU | 18.1GB GPU / 35GB CPU |
94
+
95
+ With LoRA you can fully finetune a 12B parameter model that would've otherwise run out of memory on the 80GB GPU, and comfortably fit and train a 3B parameter model. When you look at the 3B parameter model's performance, it is comparable to a fully finetuned model at a fraction of the GPU memory.
96
+
97
+ | Submission Name | Accuracy |
98
+ | --------- | ---- |
99
+ | Human baseline (crowdsourced) | 0.897 |
100
+ | Flan-T5 | 0.892 |
101
+ | lora-t0-3b | 0.863 |
102
+
103
+ > [!TIP]
104
+ > The bigscience/T0_3B model performance isn't optimized in the table above. You can squeeze even more performance out of it by playing around with the input instruction templates, LoRA hyperparameters, and other training related hyperparameters. The final checkpoint size of this model is just 19MB compared to 11GB of the full bigscience/T0_3B model. Learn more about the advantages of finetuning with PEFT in this [blog post](https://www.philschmid.de/fine-tune-flan-t5-peft).
105
+
106
+ ### Quantization
107
+
108
+ Quantization is another method for reducing the memory requirements of a model by representing the data in a lower precision. It can be combined with PEFT methods to make it even easier to train and load LLMs for inference.
109
+
110
+ * Learn how to finetune [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) with QLoRA and the [TRL](https://huggingface.co/docs/trl/index) library on a 16GB GPU in the [Finetune LLMs on your own consumer hardware using tools from PyTorch and Hugging Face ecosystem](https://pytorch.org/blog/finetune-llms/) blog post.
111
+ * Learn how to finetune a [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2) model for multilingual automatic speech recognition with LoRA and 8-bit quantization in this [notebook](https://colab.research.google.com/drive/1DOkD_5OUjFa0r5Ik3SgywJLJtEo2qLxO?usp=sharing) (see this [notebook](https://colab.research.google.com/drive/1vhF8yueFqha3Y3CpTHN6q9EVcII9EYzs?usp=sharing) instead for an example of streaming a dataset).
112
+
113
+ ### Save compute and storage
114
+
115
+ PEFT can help you save storage by avoiding full finetuning of models for each downstream task or dataset. In many cases, you're only finetuning a very small fraction of a model's parameters and each checkpoint is only a few MBs in size (instead of GBs). These smaller PEFT adapters demonstrate performance comparable to a fully finetuned model. If you have many datasets, you can save a lot of storage with a PEFT model and not have to worry about catastrophic forgetting or overfitting the backbone or base model.
116
+
117
+ ## PEFT integrations
118
+
119
+ PEFT is widely supported across the Hugging Face ecosystem because of the massive efficiency it brings to training and inference.
120
+
121
+ ### Diffusers
122
+
123
+ The iterative diffusion process consumes a lot of memory which can make it difficult to train. PEFT can help reduce the memory requirements and reduce the storage size of the final model checkpoint. For example, consider the memory required for training a Stable Diffusion model with LoRA on an A100 80GB GPU with more than 64GB of CPU RAM. The final model checkpoint size is only 8.8MB!
124
+
125
+ | Model | Full Finetuning | PEFT-LoRA | PEFT-LoRA with Gradient Checkpointing |
126
+ | --------- | ---- | ---- | ---- |
127
+ | CompVis/stable-diffusion-v1-4 | 27.5GB GPU / 3.97GB CPU | 15.5GB GPU / 3.84GB CPU | 8.12GB GPU / 3.77GB CPU |
128
+
129
+ > [!TIP]
130
+ > Take a look at the [examples/lora_dreambooth/train_dreambooth.py](examples/lora_dreambooth/train_dreambooth.py) training script to try training your own Stable Diffusion model with LoRA, and play around with the [smangrul/peft-lora-sd-dreambooth](https://huggingface.co/spaces/smangrul/peft-lora-sd-dreambooth) Space which is running on a T4 instance. Learn more about the PEFT integration in Diffusers in this [tutorial](https://huggingface.co/docs/peft/main/en/tutorial/peft_integrations#diffusers).
131
+
132
+ ### Transformers
133
+
134
+ PEFT is directly integrated with [Transformers](https://huggingface.co/docs/transformers/main/en/peft). After loading a model, call `add_adapter` to add a new PEFT adapter to the model:
135
+
136
+ ```python
137
+ from peft import LoraConfig
138
+ model = ... # transformers model
139
+ peft_config = LoraConfig(...)
140
+ model.add_adapter(peft_config, adapter_name="lora_1")
141
+ ```
142
+
143
+ To load a trained PEFT adapter, call `load_adapter`:
144
+
145
+ ```python
146
+ model = ... # transformers model
147
+ model.load_adapter(<path-to-adapter>, adapter_name="lora_1")
148
+ ```
149
+
150
+ And to switch between different adapters, call `set_adapter`:
151
+
152
+ ```python
153
+ model.set_adapter("lora_2")
154
+ ```
155
+
156
+ The Transformers integration doesn't include all the functionalities offered in PEFT, such as methods for merging the adapter into the base model.
157
+
158
+ ### Accelerate
159
+
160
+ [Accelerate](https://huggingface.co/docs/accelerate/index) is a library for distributed training and inference on various training setups and hardware (GPUs, TPUs, Apple Silicon, etc.). PEFT models work with Accelerate out of the box, making it really convenient to train really large models or use them for inference on consumer hardware with limited resources.
161
+
162
+ ### TRL
163
+
164
+ PEFT can also be applied to training LLMs with RLHF components such as the ranker and policy. Get started by reading:
165
+
166
+ * [Fine-tune a Mistral-7b model with Direct Preference Optimization](https://towardsdatascience.com/fine-tune-a-mistral-7b-model-with-direct-preference-optimization-708042745aac) with PEFT and the [TRL](https://huggingface.co/docs/trl/index) library to learn more about the Direct Preference Optimization (DPO) method and how to apply it to an LLM.
167
+ * [Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU](https://huggingface.co/blog/trl-peft) with PEFT and the [TRL](https://huggingface.co/docs/trl/index) library, and then try out the [gpt2-sentiment_peft.ipynb](https://github.com/huggingface/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) notebook to optimize GPT2 to generate positive movie reviews.
168
+ * [StackLLaMA: A hands-on guide to train LLaMA with RLHF](https://huggingface.co/blog/stackllama) with PEFT, and then try out the [stack_llama/scripts](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama/scripts) for supervised finetuning, reward modeling, and RL finetuning.
169
+
170
+ ## Model support
171
+
172
+ Use this [Space](https://stevhliu-peft-methods.hf.space) or check out the [docs](https://huggingface.co/docs/peft/main/en/index) to find which models officially support a PEFT method out of the box. Even if you don't see a model listed below, you can manually configure the model config to enable PEFT for a model. Read the [New transformers architecture](https://huggingface.co/docs/peft/main/en/developer_guides/custom_models#new-transformers-architectures) guide to learn how.
173
+
174
+ ## Contribute
175
+
176
+ If you would like to contribute to PEFT, please check out our [contribution guide](https://huggingface.co/docs/peft/developer_guides/contributing).
177
+
178
+ ## Citing 🤗 PEFT
179
+
180
+ To use 🤗 PEFT in your publication, please cite it by using the following BibTeX entry.
181
+
182
+ ```bibtex
183
+ @Misc{peft,
184
+ title = {PEFT: State-of-the-art Parameter-Efficient Fine-Tuning methods},
185
+ author = {Sourab Mangrulkar and Sylvain Gugger and Lysandre Debut and Younes Belkada and Sayak Paul and Benjamin Bossan},
186
+ howpublished = {\url{https://github.com/huggingface/peft}},
187
+ year = {2022}
188
+ }
189
+ ```
peft/pyproject.toml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ # Only used by `hf-doc-builder`.
3
+ line-length = 119
4
+ target-version = ['py38']
5
+
6
+ [tool.ruff]
7
+ target-version = "py39"
8
+ line-length = 119
9
+ extend-exclude = ["*.ipynb"]
10
+
11
+ [tool.ruff.lint]
12
+ preview = true
13
+ explicit-preview-rules = true
14
+ extend-select = [
15
+ "C", # Complexity
16
+ "E", # PEP8 errors
17
+ "F", # PEP8 formatting
18
+ "I", # Import sorting
19
+ "UP", # Pyupgrade upgrades
20
+ "W", # PEP8 warnings
21
+ "PT009", # Pytest assertions
22
+ "RUF022", # Sorting of __all__
23
+ ]
24
+ ignore = [
25
+ "C901", # Function too complex
26
+ "E501", # Line length (handled by ruff-format)
27
+ "F841", # unused variable
28
+ "UP007", # X | Y style Unions
29
+ "C420", # dict.fromkeys
30
+ ]
31
+
32
+ [tool.ruff.lint.isort]
33
+ lines-after-imports = 2
34
+ known-first-party = ["peft"]
35
+
36
+ [tool.pytest]
37
+ doctest_optionflags = [
38
+ "NORMALIZE_WHITESPACE",
39
+ "ELLIPSIS",
40
+ "NUMBER",
41
+ ]
42
+
43
+ [tool.pytest.ini_options]
44
+ addopts = "--cov=src/peft --cov-report=term-missing --durations=10"
45
+ markers = [
46
+ "single_gpu_tests: tests that run on a single GPU",
47
+ "multi_gpu_tests: tests that run on multiple GPUs",
48
+ "regression: whether to run regression suite test",
49
+ "bitsandbytes: select bitsandbytes integration tests"
50
+ ]
peft/requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ torch
3
+ safetensors
4
+ bitsandbytes
5
+ scipy
6
+ peft
7
+ transformers
8
+ tqdm
9
+ packaging
10
+ pytest
11
+ numpy
12
+ pyyaml
13
+ datasets
14
+ psutil
15
+ setuptools
peft/setup.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from setuptools import find_packages, setup
16
+
17
+
18
+ VERSION = "0.16.1.dev0"
19
+
20
+ extras = {}
21
+ extras["quality"] = [
22
+ "black", # doc-builder has an implicit dependency on Black, see huggingface/doc-builder#434
23
+ "hf-doc-builder",
24
+ "ruff~=0.9.2",
25
+ ]
26
+ extras["docs_specific"] = [
27
+ "black", # doc-builder has an implicit dependency on Black, see huggingface/doc-builder#434
28
+ "hf-doc-builder",
29
+ ]
30
+ extras["dev"] = extras["quality"] + extras["docs_specific"]
31
+ extras["test"] = extras["dev"] + [
32
+ "pytest",
33
+ "pytest-cov",
34
+ "pytest-xdist",
35
+ "parameterized",
36
+ "datasets",
37
+ "diffusers",
38
+ "scipy",
39
+ "protobuf",
40
+ "sentencepiece",
41
+ ]
42
+
43
+ setup(
44
+ name="peft",
45
+ version=VERSION,
46
+ description="Parameter-Efficient Fine-Tuning (PEFT)",
47
+ license_files=["LICENSE"],
48
+ long_description=open("README.md", encoding="utf-8").read(),
49
+ long_description_content_type="text/markdown",
50
+ keywords="deep learning",
51
+ license="Apache",
52
+ author="The HuggingFace team",
53
+ author_email="benjamin@huggingface.co",
54
+ url="https://github.com/huggingface/peft",
55
+ package_dir={"": "src"},
56
+ packages=find_packages("src"),
57
+ package_data={"peft": ["py.typed", "tuners/boft/fbd/fbd_cuda.cpp", "tuners/boft/fbd/fbd_cuda_kernel.cu"]},
58
+ entry_points={},
59
+ python_requires=">=3.9.0",
60
+ install_requires=[
61
+ "numpy>=1.17",
62
+ "packaging>=20.0",
63
+ "psutil",
64
+ "pyyaml",
65
+ "torch>=1.13.0",
66
+ "transformers",
67
+ "tqdm",
68
+ "accelerate>=0.21.0",
69
+ "safetensors",
70
+ "huggingface_hub>=0.25.0",
71
+ ],
72
+ extras_require=extras,
73
+ classifiers=[
74
+ "Development Status :: 5 - Production/Stable",
75
+ "Intended Audience :: Developers",
76
+ "Intended Audience :: Education",
77
+ "Intended Audience :: Science/Research",
78
+ "License :: OSI Approved :: Apache Software License",
79
+ "Operating System :: OS Independent",
80
+ "Programming Language :: Python :: 3",
81
+ "Programming Language :: Python :: 3.9",
82
+ "Programming Language :: Python :: 3.10",
83
+ "Programming Language :: Python :: 3.11",
84
+ "Programming Language :: Python :: 3.12",
85
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
86
+ ],
87
+ )
88
+
89
+ # Release checklist
90
+ # 1. Change the version in __init__.py and setup.py to the release version, e.g. from "0.6.1.dev0" to "0.7.0"
91
+ # 2. Check if there are any deprecations that need to be addressed for this release by searching for "# TODO" in the code
92
+ # 3. Commit these changes with the message: "Release: VERSION", create a PR and merge it.
93
+ # 4. Add a tag in git to mark the release: "git tag -a v<VERSION> -m 'Adds tag <VERSION> for pypi' "
94
+ # Push the tag to git:
95
+ # git push --tags origin main
96
+ # It is necessary to work on the original repository, not on a fork.
97
+ # 5. Run the following commands in the top-level directory:
98
+ # python setup.py bdist_wheel
99
+ # python setup.py sdist
100
+ # Ensure that you are on the clean and up-to-date main branch (git status --untracked-files=no should not list any
101
+ # files and show the main branch)
102
+ # 6. Upload the package to the pypi test server first:
103
+ # twine upload dist/* -r pypitest
104
+ # 7. Check that you can install it in a virtualenv by running:
105
+ # pip install -i https://testpypi.python.org/pypi --extra-index-url https://pypi.org/simple peft
106
+ # 8. Upload the final version to actual pypi:
107
+ # twine upload dist/* -r pypi
108
+ # 9. Add release notes to the tag on https://github.com/huggingface/peft/releases once everything is looking hunky-dory.
109
+ # Check the notes here: https://docs.google.com/document/d/1k-sOIfykuKjWcOIALqjhFKz4amFEp-myeJUJEzNgjoU/edit?usp=sharing
110
+ # 10. Update the version in __init__.py, setup.py to the bumped patch version + ".dev0" (e.g. from "0.7.0" to "0.7.1.dev0")
sentence-transformers/.gitignore ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Distribution / packaging
2
+ .Python
3
+ build/
4
+ develop-eggs/
5
+ dist/
6
+ downloads/
7
+ eggs/
8
+ .eggs/
9
+ lib/
10
+ lib64/
11
+ parts/
12
+ sdist/
13
+ var/
14
+ wheels/
15
+ share/python-wheels/
16
+ *.egg-info/
17
+ .installed.cfg
18
+ *.egg
19
+ MANIFEST
20
+
21
+ # Docs
22
+ /docs/_build/
23
+ /docs/make.bat
24
+
25
+ # Editors
26
+ .idea
27
+ .vscode
28
+
29
+ # Coverage
30
+ htmlcov
31
+ .coverage*
32
+ coverage.xml
33
+
34
+ # Examples
35
+ /examples/**/output/*
36
+ /examples/datasets/
37
+ /examples/embeddings/
38
+ /examples/sentence_transformer/training/quora_duplicate_questions/quora-IR-dataset/
39
+ examples/datasets/*/
40
+
41
+
42
+ # Specific files and folders
43
+ /pretrained-models/
44
+ /cheatsheet.txt
45
+ /testsuite.txt
46
+ /TODO.txt
47
+
48
+ # Virtual environments
49
+ .env
50
+ .venv
51
+ env/
52
+ venv/
53
+
54
+ # Database
55
+ /qdrant_storage
56
+ /elastic-start-local
57
+
58
+ # Others
59
+ *.pyc
60
+ *.gz
61
+ *.tsv
62
+ tmp_*.py
63
+ nr_*/
64
+ wandb
65
+ checkpoints
66
+ tmp
67
+ .DS_Store
68
+ /runs
69
+ /tmp_trainer/