Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/.gitignore +1 -0
- .venv/CACHEDIR.TAG +1 -0
- .venv/pyvenv.cfg +6 -0
- arch/README.md +278 -0
- arch/__init__.py +59 -0
- arch/adapter.py +86 -0
- arch/data_loader.py +658 -0
- arch/example_train.py +377 -0
- arch/model_loader.py +325 -0
- arch/pipeline.py +348 -0
- arch/text_encoder.py +155 -0
- arch/training.py +307 -0
- diffusers/.github/PULL_REQUEST_TEMPLATE.md +61 -0
- diffusers/docs/README.md +268 -0
- diffusers/docs/TRANSLATING.md +69 -0
- diffusers/scripts/conversion_ldm_uncond.py +56 -0
- diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py +69 -0
- diffusers/scripts/convert_cogvideox_to_diffusers.py +346 -0
- diffusers/scripts/convert_consistency_decoder.py +1128 -0
- diffusers/scripts/convert_dance_diffusion_to_diffusers.py +346 -0
- diffusers/scripts/convert_dcae_to_diffusers.py +323 -0
- diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py +56 -0
- diffusers/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py +97 -0
- diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py +241 -0
- diffusers/scripts/convert_i2vgen_to_diffusers.py +510 -0
- diffusers/scripts/convert_if.py +1250 -0
- diffusers/scripts/convert_lora_safetensor_to_diffusers.py +128 -0
- diffusers/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py +185 -0
- diffusers/scripts/convert_omnigen_to_diffusers.py +203 -0
- diffusers/scripts/convert_original_audioldm2_to_diffusers.py +1135 -0
- diffusers/scripts/convert_original_musicldm_to_diffusers.py +1056 -0
- diffusers/scripts/convert_pixart_sigma_to_diffusers.py +223 -0
- diffusers/scripts/convert_sana_to_diffusers.py +456 -0
- diffusers/scripts/convert_stable_cascade.py +218 -0
- diffusers/scripts/convert_vae_pt_to_diffusers.py +177 -0
- diffusers/scripts/convert_wuerstchen.py +115 -0
- illustrious_generated/low_quality_images.json +0 -0
- illustrious_generated/natural_caption_generation_report.txt +14 -0
- illustrious_generated/optimization_final_results.json +0 -0
- illustrious_generated/optimization_summary_report.txt +20 -0
- illustrious_generated/regeneration_results.json +0 -0
- peft/.gitignore +145 -0
- peft/.pre-commit-config.yaml +13 -0
- peft/LICENSE +201 -0
- peft/Makefile +66 -0
- peft/README.md +189 -0
- peft/pyproject.toml +50 -0
- peft/requirements.txt +15 -0
- peft/setup.py +110 -0
- sentence-transformers/.gitignore +69 -0
.venv/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*
|
.venv/CACHEDIR.TAG
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
.venv/pyvenv.cfg
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
home = /home/ubuntu/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/bin
|
| 2 |
+
implementation = CPython
|
| 3 |
+
uv = 0.7.3
|
| 4 |
+
version_info = 3.12.10
|
| 5 |
+
include-system-site-packages = false
|
| 6 |
+
prompt = QwenIllustrious
|
arch/README.md
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Qwen-SDXL Architecture Components
|
| 2 |
+
|
| 3 |
+
本目录包含了 Qwen-SDXL 项目的解耦架构组件,便于训练和推理的模块化使用。
|
| 4 |
+
|
| 5 |
+
## 📁 组件结构
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
arch/
|
| 9 |
+
├── __init__.py # 组件导入和导出
|
| 10 |
+
├── adapter.py # Qwen 嵌入适配器
|
| 11 |
+
├── text_encoder.py # Qwen 文本编码器
|
| 12 |
+
├── model_loader.py # 模型加载工具
|
| 13 |
+
├── pipeline.py # 推理管道
|
| 14 |
+
├── training.py # 训练工具和损失函数
|
| 15 |
+
├── data_loader.py # 数据加载器
|
| 16 |
+
├── example_train.py # 示例训练脚本
|
| 17 |
+
└── README.md # 本文件
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
## 🧩 核心组件
|
| 21 |
+
|
| 22 |
+
### 1. QwenEmbeddingAdapter (`adapter.py`)
|
| 23 |
+
将 Qwen3 的 1024 维嵌入投影到 SDXL 兼容的维度:
|
| 24 |
+
- 文本嵌入: 1024 → 2048 (用于 encoder_hidden_states)
|
| 25 |
+
- 池化嵌入: 1024 → 1280 (用于 text_embeds)
|
| 26 |
+
|
| 27 |
+
```python
|
| 28 |
+
from arch import QwenEmbeddingAdapter
|
| 29 |
+
|
| 30 |
+
adapter = QwenEmbeddingAdapter()
|
| 31 |
+
text_embeddings = adapter.forward_text_embeddings(qwen_embeddings)
|
| 32 |
+
pooled_embeddings = adapter.forward_pooled_embeddings(qwen_pooled)
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### 2. QwenTextEncoder (`text_encoder.py`)
|
| 36 |
+
封装 Qwen3 模型的文本编码功能:
|
| 37 |
+
|
| 38 |
+
```python
|
| 39 |
+
from arch import QwenTextEncoder
|
| 40 |
+
|
| 41 |
+
text_encoder = QwenTextEncoder(model_path="path/to/qwen3")
|
| 42 |
+
text_emb, pooled_emb = text_encoder.encode_prompts(
|
| 43 |
+
["a beautiful landscape"],
|
| 44 |
+
["low quality"]
|
| 45 |
+
)
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
### 3. 模型加载器 (`model_loader.py`)
|
| 49 |
+
提供各种模型组件的加载功能:
|
| 50 |
+
|
| 51 |
+
```python
|
| 52 |
+
from arch import load_unet_from_safetensors, load_vae_from_safetensors
|
| 53 |
+
|
| 54 |
+
unet = load_unet_from_safetensors("unet.safetensors", "unet_config.json")
|
| 55 |
+
vae = load_vae_from_safetensors("vae.safetensors", "vae_config.json")
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### 4. 推理管道 (`pipeline.py`)
|
| 59 |
+
完整的 Qwen-SDXL 推理管道:
|
| 60 |
+
|
| 61 |
+
```python
|
| 62 |
+
from arch import QwenIllustriousInference

pipeline = QwenIllustriousInference(
    qwen_model_path="path/to/qwen3",
    adapter_path="path/to/trained/adapter.safetensors"  # 可选
)
|
| 68 |
+
|
| 69 |
+
images = pipeline.generate(
|
| 70 |
+
prompt="a beautiful landscape",
|
| 71 |
+
height=1024,
|
| 72 |
+
width=1024
|
| 73 |
+
)
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
## 🎓 训练组件
|
| 77 |
+
|
| 78 |
+
### 1. DiffusionLoss (`training.py`)
|
| 79 |
+
扩散训练损失函数,支持多种损失类型和 SNR 加权:
|
| 80 |
+
|
| 81 |
+
```python
|
| 82 |
+
from arch import DiffusionLoss
|
| 83 |
+
|
| 84 |
+
loss_fn = DiffusionLoss(
|
| 85 |
+
noise_scheduler=scheduler,
|
| 86 |
+
loss_type="mse",
|
| 87 |
+
snr_gamma=5.0 # Min-SNR weighting
|
| 88 |
+
)
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
### 2. AdapterTrainingStep (`training.py`)
|
| 92 |
+
适配器训练步骤,自动处理前向传播和损失计算:
|
| 93 |
+
|
| 94 |
+
```python
|
| 95 |
+
from arch import AdapterTrainingStep
|
| 96 |
+
|
| 97 |
+
training_step = AdapterTrainingStep(
|
| 98 |
+
unet=unet,
|
| 99 |
+
vae=vae,
|
| 100 |
+
text_encoder=text_encoder,
|
| 101 |
+
adapter=adapter,
|
| 102 |
+
noise_scheduler=scheduler,
|
| 103 |
+
loss_fn=loss_fn
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
result = training_step.training_step(images, prompts)
|
| 107 |
+
loss = result["loss"]
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
### 3. 数据加载器 (`data_loader.py`)
|
| 111 |
+
|
| 112 |
+
#### 基础数据集
|
| 113 |
+
```python
|
| 114 |
+
from arch import ImageCaptionDataset, create_dataloader
|
| 115 |
+
|
| 116 |
+
dataset = ImageCaptionDataset(
|
| 117 |
+
data_root="/path/to/images",
|
| 118 |
+
annotations_file="captions.jsonl"
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
dataloader = create_dataloader(dataset, batch_size=4)
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
#### 多长宽比数据集
|
| 125 |
+
```python
|
| 126 |
+
from arch import MultiAspectDataset
|
| 127 |
+
|
| 128 |
+
dataset = MultiAspectDataset(
|
| 129 |
+
data_root="/path/to/images",
|
| 130 |
+
annotations_file="captions.jsonl",
|
| 131 |
+
aspect_ratios=[(1024, 1024), (1152, 896), (896, 1152)]
|
| 132 |
+
)
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
## 🚀 使用示例
|
| 136 |
+
|
| 137 |
+
### 1. 快速推理
|
| 138 |
+
|
| 139 |
+
```python
|
| 140 |
+
from arch import QwenIllustriousInference

# 初始化管道
pipeline = QwenIllustriousInference()
|
| 144 |
+
|
| 145 |
+
# 生成图像
|
| 146 |
+
images = pipeline.generate(
|
| 147 |
+
prompt="a serene mountain landscape at sunset",
|
| 148 |
+
negative_prompt="low quality, blurry",
|
| 149 |
+
height=1024,
|
| 150 |
+
width=1024,
|
| 151 |
+
num_inference_steps=50,
|
| 152 |
+
guidance_scale=7.5
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# 保存图像
|
| 156 |
+
images[0].save("generated_image.png")
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
### 2. 训练适配器
|
| 160 |
+
|
| 161 |
+
参考 `example_train.py` 中的完整训练脚本:
|
| 162 |
+
|
| 163 |
+
```bash
|
| 164 |
+
python arch/example_train.py \
|
| 165 |
+
--data_root /path/to/images \
|
| 166 |
+
--annotations_file /path/to/captions.jsonl \
|
| 167 |
+
--batch_size 4 \
|
| 168 |
+
--learning_rate 1e-4 \
|
| 169 |
+
--num_epochs 10 \
|
| 170 |
+
--output_dir ./checkpoints \
|
| 171 |
+
--use_wandb
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
### 3. 使用训练好的适配器
|
| 175 |
+
|
| 176 |
+
```python
|
| 177 |
+
from arch import QwenIllustriousInference

# 加载带有训练好的适配器的管道
pipeline = QwenIllustriousInference(
    adapter_path="checkpoints/adapter_epoch_10_step_5000.safetensors"
)
|
| 183 |
+
|
| 184 |
+
images = pipeline.generate("your prompt here")
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
## 📊 训练配置
|
| 188 |
+
|
| 189 |
+
### 推荐的训练设置
|
| 190 |
+
|
| 191 |
+
- **学习率**: 1e-4 (AdamW)
|
| 192 |
+
- **批量大小**: 4-8 (根据 GPU 内存调整)
|
| 193 |
+
- **梯度累积**: 如果内存不足可使用
|
| 194 |
+
- **损失函数**: MSE 或 Huber
|
| 195 |
+
- **SNR 加权**: gamma=5.0 (Min-SNR)
|
| 196 |
+
- **EMA**: decay=0.9999
|
| 197 |
+
- **学习率调度**: 余弦退火 + 预热
|
| 198 |
+
|
| 199 |
+
### 数据格式
|
| 200 |
+
|
| 201 |
+
支持 JSON 和 JSONL 格式的标注文件:
|
| 202 |
+
|
| 203 |
+
```jsonl
|
| 204 |
+
{"image": "image1.jpg", "caption": "A beautiful landscape"}
|
| 205 |
+
{"image": "image2.jpg", "caption": "A cute cat"}
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
## 🔧 自定义和扩展
|
| 209 |
+
|
| 210 |
+
### 1. 自定义适配器架构
|
| 211 |
+
|
| 212 |
+
修改 `adapter.py` 中的 `QwenEmbeddingAdapter` 类:
|
| 213 |
+
|
| 214 |
+
```python
|
| 215 |
+
class CustomAdapter(QwenEmbeddingAdapter):
|
| 216 |
+
def __init__(self, qwen_dim=1024, sdxl_text_dim=2048, sdxl_pooled_dim=1280):
|
| 217 |
+
super().__init__(qwen_dim, sdxl_text_dim, sdxl_pooled_dim)
|
| 218 |
+
# 添加自定义层
|
| 219 |
+
self.custom_layer = nn.Linear(sdxl_text_dim, sdxl_text_dim)
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
### 2. 自定义损失函数
|
| 223 |
+
|
| 224 |
+
继承 `DiffusionLoss` 类:
|
| 225 |
+
|
| 226 |
+
```python
|
| 227 |
+
class CustomLoss(DiffusionLoss):
|
| 228 |
+
def forward(self, model_pred, target, timesteps, mask=None):
|
| 229 |
+
# 自定义损失计算
|
| 230 |
+
base_loss = super().forward(model_pred, target, timesteps, mask)
|
| 231 |
+
# 添加额外的正则化项
|
| 232 |
+
return base_loss + custom_regularization
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
### 3. 自定义数据加载
|
| 236 |
+
|
| 237 |
+
继承数据集类:
|
| 238 |
+
|
| 239 |
+
```python
|
| 240 |
+
class CustomDataset(ImageCaptionDataset):
|
| 241 |
+
def __getitem__(self, idx):
|
| 242 |
+
sample = super().__getitem__(idx)
|
| 243 |
+
# 添加自定义数据增强或预处理
|
| 244 |
+
return sample
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
## ⚠️ 注意事项
|
| 248 |
+
|
| 249 |
+
1. **内存管理**: 使用大模型时注意 GPU 内存使用
|
| 250 |
+
2. **数据类型**: 推荐使用 bfloat16 以平衡性能和精度
|
| 251 |
+
3. **检查点保存**: 定期保存检查点以防止训练中断
|
| 252 |
+
4. **验证集**: 使用验证集监控训练进度
|
| 253 |
+
5. **梯度裁剪**: 防止梯度爆炸
|
| 254 |
+
|
| 255 |
+
## 🐛 故障排除
|
| 256 |
+
|
| 257 |
+
### 常见问题
|
| 258 |
+
|
| 259 |
+
1. **CUDA 内存不足**: 减小批量大小或使用梯度累积
|
| 260 |
+
2. **模型加载失败**: 检查模型路径和配置文件
|
| 261 |
+
3. **生成图像质量差**: 调整学习率、损失函数或训练更多步数
|
| 262 |
+
4. **训练不稳定**: 使用梯度裁剪和 EMA
|
| 263 |
+
|
| 264 |
+
### 调试模式
|
| 265 |
+
|
| 266 |
+
可以在各个组件中启用调试输出:
|
| 267 |
+
|
| 268 |
+
```python
|
| 269 |
+
# 启用详细日志
|
| 270 |
+
import logging
|
| 271 |
+
logging.basicConfig(level=logging.DEBUG)
|
| 272 |
+
```
|
| 273 |
+
|
| 274 |
+
## 📚 参考资料
|
| 275 |
+
|
| 276 |
+
- [Stable Diffusion XL](https://arxiv.org/abs/2307.01952)
|
| 277 |
+
- [Qwen3 Embedding](https://github.com/QwenLM/Qwen)
|
| 278 |
+
- [Diffusers Library](https://github.com/huggingface/diffusers)
|
arch/__init__.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Architecture components for Qwen-SDXL.

Re-exports the public API of the ``arch`` package: the embedding adapter,
text encoder, model-loading helpers, the inference pipeline, training
utilities, and data-loading components. Import from ``arch`` directly,
e.g. ``from arch import QwenEmbeddingAdapter``.
"""

# Core components
from .adapter import QwenEmbeddingAdapter
from .text_encoder import QwenTextEncoder, encode_text_with_qwen
from .model_loader import (
    load_qwen_model,
    load_unet_from_safetensors,
    load_vae_from_safetensors,
    create_scheduler,
    save_model_components,
    load_checkpoint
)
from .pipeline import QwenIllustriousInference

# Training components
from .training import (
    DiffusionLoss,
    AdapterTrainingStep,
    get_cosine_schedule_with_warmup,
    EMAModel
)

# Data loading components
from .data_loader import (
    ImageCaptionDataset,
    MultiAspectDataset,
    collate_fn,
    create_dataloader
)

# Explicit public API; keep in sync with the imports above.
__all__ = [
    # Core components
    "QwenEmbeddingAdapter",
    "QwenTextEncoder",
    "encode_text_with_qwen",
    "load_qwen_model",
    "load_unet_from_safetensors",
    "load_vae_from_safetensors",
    "create_scheduler",
    "save_model_components",
    "load_checkpoint",
    "QwenIllustriousInference",

    # Training components
    "DiffusionLoss",
    "AdapterTrainingStep",
    "get_cosine_schedule_with_warmup",
    "EMAModel",

    # Data loading components
    "ImageCaptionDataset",
    "MultiAspectDataset",
    "collate_fn",
    "create_dataloader"
]
|
arch/adapter.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Qwen Embedding Adapter
|
| 3 |
+
Qwen 嵌入适配器 - 将 Qwen3 嵌入维度投影到 SDXL 兼容维度
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class QwenEmbeddingAdapter(nn.Module):
    """Project Qwen3 embeddings into SDXL-compatible conditioning tensors.

    Two learned projections are provided:

    * per-token text embeddings: ``qwen_dim`` (1024) -> ``sdxl_text_dim``
      (2048), consumed by the SDXL UNet as ``encoder_hidden_states``;
    * pooled sentence embedding: ``qwen_dim`` (1024) -> ``sdxl_pooled_dim``
      (1280), consumed as ``text_embeds`` inside ``added_cond_kwargs``.

    Args:
        qwen_dim: Dimensionality of the incoming Qwen3 embeddings.
        sdxl_text_dim: Target width for the per-token projection.
        sdxl_pooled_dim: Target width for the pooled-embedding MLP.
    """

    def __init__(self, qwen_dim=1024, sdxl_text_dim=2048, sdxl_pooled_dim=1280):
        super().__init__()
        self.qwen_dim = qwen_dim
        self.sdxl_text_dim = sdxl_text_dim
        self.sdxl_pooled_dim = sdxl_pooled_dim

        # Per-token path: Linear -> GELU -> LayerNorm (applied in forward).
        self.text_projection = nn.Linear(qwen_dim, sdxl_text_dim)
        self.text_layer_norm = nn.LayerNorm(sdxl_text_dim)
        self.text_activation = nn.GELU()

        # Pooled path: two-layer MLP with dropout, normalized at the output.
        self.pooled_mlp = nn.Sequential(
            nn.Linear(qwen_dim, qwen_dim * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(qwen_dim * 2, sdxl_pooled_dim),
            nn.LayerNorm(sdxl_pooled_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize every linear layer and zero its bias."""
        linear_layers = [self.text_projection]
        linear_layers.extend(m for m in self.pooled_mlp if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward_text_embeddings(self, qwen_embeddings):
        """Project per-token embeddings for ``encoder_hidden_states``.

        Args:
            qwen_embeddings: Tensor of shape [batch, seq_len, qwen_dim].

        Returns:
            Tensor of shape [batch, seq_len, sdxl_text_dim].
        """
        hidden = self.text_activation(self.text_projection(qwen_embeddings))
        return self.text_layer_norm(hidden)

    def forward_pooled_embeddings(self, qwen_embeddings):
        """Project the pooled embedding for ``text_embeds``.

        Args:
            qwen_embeddings: Tensor of shape [batch, qwen_dim].

        Returns:
            Tensor of shape [batch, sdxl_pooled_dim].
        """
        return self.pooled_mlp(qwen_embeddings)

    def forward(self, text_embeddings, pooled_embeddings):
        """Project both embedding streams in one call.

        Args:
            text_embeddings: Tensor of shape [batch, seq_len, qwen_dim].
            pooled_embeddings: Tensor of shape [batch, qwen_dim].

        Returns:
            Tuple ``(projected_text, projected_pooled)``.
        """
        return (
            self.forward_text_embeddings(text_embeddings),
            self.forward_pooled_embeddings(pooled_embeddings),
        )
|
arch/data_loader.py
ADDED
|
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data Loading Utilities for QwenIllustrious
|
| 3 |
+
数据加载工具 - 处理训练数据的加载和预处理,支持预计算嵌入加速
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import torchvision.transforms as transforms
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Dict, List, Optional, Tuple
|
| 14 |
+
import pickle
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class QwenIllustriousDataset(Dataset):
|
| 19 |
+
"""
|
| 20 |
+
Dataset for QwenIllustrious training
|
| 21 |
+
支持以下功能:
|
| 22 |
+
- 从 metadata.json 文件加载图像和标注
|
| 23 |
+
- 图像预处理和增强
|
| 24 |
+
- Qwen 文本编码缓存
|
| 25 |
+
- VAE 潜在空间编码缓存
|
| 26 |
+
- 训练时的预计算加速
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
dataset_path: str,
|
| 32 |
+
qwen_text_encoder=None,
|
| 33 |
+
vae=None,
|
| 34 |
+
image_size: int = 1024,
|
| 35 |
+
cache_dir: Optional[str] = None,
|
| 36 |
+
precompute_embeddings: bool = False
|
| 37 |
+
):
|
| 38 |
+
self.dataset_path = Path(dataset_path)
|
| 39 |
+
self.qwen_text_encoder = qwen_text_encoder
|
| 40 |
+
self.vae = vae
|
| 41 |
+
self.image_size = image_size
|
| 42 |
+
self.cache_dir = Path(cache_dir) if cache_dir else None
|
| 43 |
+
self.precompute_embeddings = precompute_embeddings
|
| 44 |
+
|
| 45 |
+
# Setup image transforms
|
| 46 |
+
self.image_transforms = transforms.Compose([
|
| 47 |
+
transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
|
| 48 |
+
transforms.ToTensor(),
|
| 49 |
+
transforms.Normalize([0.5], [0.5]) # Normalize to [-1, 1]
|
| 50 |
+
])
|
| 51 |
+
|
| 52 |
+
# Load metadata
|
| 53 |
+
self.metadata = self._load_metadata()
|
| 54 |
+
|
| 55 |
+
# Setup cache directories
|
| 56 |
+
if self.cache_dir:
|
| 57 |
+
self.cache_dir.mkdir(exist_ok=True)
|
| 58 |
+
self.text_cache_dir = self.cache_dir / "text_embeddings"
|
| 59 |
+
self.vae_cache_dir = self.cache_dir / "vae_latents"
|
| 60 |
+
self.text_cache_dir.mkdir(exist_ok=True)
|
| 61 |
+
self.vae_cache_dir.mkdir(exist_ok=True)
|
| 62 |
+
|
| 63 |
+
# Precomputed data storage
|
| 64 |
+
self.precomputed_data = {}
|
| 65 |
+
|
| 66 |
+
def _load_metadata(self) -> List[Dict]:
|
| 67 |
+
"""Load all metadata files"""
|
| 68 |
+
metadata_dir = self.dataset_path / "metadata"
|
| 69 |
+
if not metadata_dir.exists():
|
| 70 |
+
raise ValueError(f"Metadata directory not found: {metadata_dir}")
|
| 71 |
+
|
| 72 |
+
metadata_files = list(metadata_dir.glob("*.json"))
|
| 73 |
+
|
| 74 |
+
metadata_list = []
|
| 75 |
+
print(f"Loading metadata from {len(metadata_files)} files...")
|
| 76 |
+
|
| 77 |
+
for file_path in tqdm(metadata_files, desc="Loading metadata"):
|
| 78 |
+
try:
|
| 79 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 80 |
+
data = json.load(f)
|
| 81 |
+
# Add file path info
|
| 82 |
+
data['metadata_file'] = str(file_path)
|
| 83 |
+
data['image_file'] = str(self.dataset_path / f"{data['filename_hash']}.png")
|
| 84 |
+
metadata_list.append(data)
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"Error loading {file_path}: {e}")
|
| 87 |
+
continue
|
| 88 |
+
|
| 89 |
+
print(f"Successfully loaded {len(metadata_list)} metadata files")
|
| 90 |
+
return metadata_list
|
| 91 |
+
|
| 92 |
+
def _get_text_cache_path(self, filename_hash: str) -> Path:
|
| 93 |
+
"""Get path for cached text embeddings"""
|
| 94 |
+
return self.text_cache_dir / f"{filename_hash}_text.pt"
|
| 95 |
+
|
| 96 |
+
def _get_vae_cache_path(self, filename_hash: str) -> Path:
|
| 97 |
+
"""Get path for cached VAE latents"""
|
| 98 |
+
return self.vae_cache_dir / f"{filename_hash}_vae.pt"
|
| 99 |
+
|
| 100 |
+
    def _compute_text_embeddings(self, prompt: str, device='cpu') -> Dict[str, torch.Tensor]:
        """Encode one prompt with the Qwen text encoder.

        Args:
            prompt: Caption text to encode.
            device: Device to run the encoder on; the encoder is moved back
                to its original device afterwards.

        Returns:
            Dict with 'text_embeddings' and 'pooled_embeddings' tensors on CPU.
            When no encoder is configured, zero-filled placeholders are
            returned instead (shapes (1, 2048) and (1, 1280)).
        """
        if not self.qwen_text_encoder:
            # No encoder configured: return dummy embeddings so the dataset
            # remains usable (e.g. for shape debugging without the model).
            # NOTE(review): the real per-token output presumably has a
            # sequence dimension ([1, seq, 2048]); this placeholder does not —
            # confirm downstream code tolerates the mismatch.
            return {
                'text_embeddings': torch.zeros(1, 2048),  # SDXL text embedding size
                'pooled_embeddings': torch.zeros(1, 1280)  # SDXL pooled embedding size
            }

        with torch.no_grad():
            # Move to device temporarily for computation
            original_device = next(self.qwen_text_encoder.parameters()).device
            self.qwen_text_encoder.to(device)

            # encode_prompts is expected to return (text_emb, pooled_emb);
            # the len() guard below falls back to text_emb if only one
            # tensor comes back.
            embeddings = self.qwen_text_encoder.encode_prompts([prompt])

            # Move back to original device
            self.qwen_text_encoder.to(original_device)

            # Detach to CPU so cached tensors don't pin GPU memory.
            return {
                'text_embeddings': embeddings[0].cpu(),
                'pooled_embeddings': embeddings[1].cpu() if len(embeddings) > 1 else embeddings[0].cpu()
            }
|
| 123 |
+
|
| 124 |
+
    def _compute_vae_latents(self, image: torch.Tensor, device='cpu') -> torch.Tensor:
        """Encode a preprocessed image tensor into scaled VAE latents.

        Args:
            image: Image tensor of shape (3, H, W) or (1, 3, H, W), already
                normalized to [-1, 1] by ``image_transforms``.
            device: Device to run the VAE on; the VAE is moved back to its
                original device afterwards.

        Returns:
            Latent tensor on CPU, scaled by ``vae.config.scaling_factor``.
            When no VAE is configured, a zero tensor of shape
            (1, 4, image_size // 8, image_size // 8) is returned instead.
        """
        if not self.vae:
            # No VAE configured: return dummy latents (SDXL VAE downsamples
            # by a factor of 8 and has 4 latent channels).
            return torch.zeros(1, 4, self.image_size // 8, self.image_size // 8)

        with torch.no_grad():
            # Move to device temporarily for computation
            original_device = next(self.vae.parameters()).device
            self.vae.to(device)

            # Add batch dimension if needed
            if image.dim() == 3:
                image = image.unsqueeze(0)

            # Match the VAE's dtype (e.g. fp16/bf16) before encoding.
            image = image.to(device).to(self.vae.dtype)
            latents = self.vae.encode(image).latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor

            # Move back to original device
            self.vae.to(original_device)

            # Keep cached latents on CPU to avoid pinning GPU memory.
            return latents.cpu()
|
| 147 |
+
|
| 148 |
+
def _load_or_compute_text_embeddings(self, prompt: str, filename_hash: str, device='cpu') -> Dict[str, torch.Tensor]:
|
| 149 |
+
"""Load cached text embeddings or compute new ones"""
|
| 150 |
+
if self.cache_dir:
|
| 151 |
+
cache_path = self._get_text_cache_path(filename_hash)
|
| 152 |
+
|
| 153 |
+
# Try to load from cache
|
| 154 |
+
if cache_path.exists():
|
| 155 |
+
try:
|
| 156 |
+
return torch.load(cache_path, map_location='cpu')
|
| 157 |
+
except Exception as e:
|
| 158 |
+
print(f"Error loading cached text embeddings {cache_path}: {e}")
|
| 159 |
+
|
| 160 |
+
# Compute new embeddings
|
| 161 |
+
embeddings = self._compute_text_embeddings(prompt, device)
|
| 162 |
+
|
| 163 |
+
# Cache the embeddings
|
| 164 |
+
if self.cache_dir:
|
| 165 |
+
try:
|
| 166 |
+
torch.save(embeddings, cache_path)
|
| 167 |
+
except Exception as e:
|
| 168 |
+
print(f"Error saving text embeddings cache {cache_path}: {e}")
|
| 169 |
+
|
| 170 |
+
return embeddings
|
| 171 |
+
|
| 172 |
+
def _load_or_compute_vae_latents(self, image_path: str, filename_hash: str, device='cpu') -> torch.Tensor:
|
| 173 |
+
"""Load cached VAE latents or compute new ones"""
|
| 174 |
+
if self.cache_dir:
|
| 175 |
+
cache_path = self._get_vae_cache_path(filename_hash)
|
| 176 |
+
|
| 177 |
+
# Try to load from cache
|
| 178 |
+
if cache_path.exists():
|
| 179 |
+
try:
|
| 180 |
+
return torch.load(cache_path, map_location='cpu')
|
| 181 |
+
except Exception as e:
|
| 182 |
+
print(f"Error loading cached VAE latents {cache_path}: {e}")
|
| 183 |
+
|
| 184 |
+
# Load and process image
|
| 185 |
+
try:
|
| 186 |
+
image = Image.open(image_path).convert('RGB')
|
| 187 |
+
image = self.image_transforms(image)
|
| 188 |
+
except Exception as e:
|
| 189 |
+
print(f"Error loading image {image_path}: {e}")
|
| 190 |
+
image = torch.zeros(3, self.image_size, self.image_size)
|
| 191 |
+
|
| 192 |
+
# Compute latents
|
| 193 |
+
latents = self._compute_vae_latents(image, device)
|
| 194 |
+
|
| 195 |
+
# Cache the latents
|
| 196 |
+
if self.cache_dir:
|
| 197 |
+
try:
|
| 198 |
+
torch.save(latents, cache_path)
|
| 199 |
+
except Exception as e:
|
| 200 |
+
print(f"Error saving VAE latents cache {cache_path}: {e}")
|
| 201 |
+
|
| 202 |
+
return latents
|
| 203 |
+
|
| 204 |
+
    def precompute_all(self, device='cuda'):
        """Precompute all embeddings and latents for faster training.

        Iterates over every metadata entry, computes (or loads from cache)
        the text embeddings and VAE latents, and stores the squeezed results
        in ``self.precomputed_data`` keyed by filename hash so ``__getitem__``
        can serve them without any model forward passes.

        Args:
            device: Device used for the text-encoder / VAE forward passes.
        """
        print("Precomputing all embeddings and latents...")

        for idx in tqdm(range(len(self.metadata)), desc="Precomputing"):
            metadata = self.metadata[idx]
            filename_hash = metadata['filename_hash']

            # Get prompt: prefer the natural caption, fall back to the
            # original positive prompt.
            prompt = metadata.get('natural_caption_data', {}).get('natural_caption', '')
            if not prompt:
                prompt = metadata.get('original_prompt_data', {}).get('positive_prompt', '')

            # Precompute text embeddings
            text_embeddings = self._load_or_compute_text_embeddings(prompt, filename_hash, device)

            # Precompute VAE latents
            vae_latents = self._load_or_compute_vae_latents(metadata['image_file'], filename_hash, device)

            # Store in memory for fast access (drop the leading batch dim).
            self.precomputed_data[filename_hash] = {
                'text_embeddings': text_embeddings['text_embeddings'].squeeze(0),
                'pooled_embeddings': text_embeddings['pooled_embeddings'].squeeze(0),
                'latents': vae_latents.squeeze(0),
                'prompt': prompt
            }

        print(f"Precomputation completed for {len(self.precomputed_data)} items")
|
| 232 |
+
|
| 233 |
+
    def __len__(self):
        """Return the number of metadata entries (one sample each)."""
        return len(self.metadata)
|
| 235 |
+
|
| 236 |
+
    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return one training sample.

        If precomputation populated ``self.precomputed_data`` for this item,
        the cached latents/embeddings are returned directly. Otherwise the
        image is loaded from disk and embeddings are computed on the fly.

        NOTE(review): the two branches return different keys — the
        precomputed branch yields 'latents' while the on-the-fly branch
        yields 'images'; the module-level collate_fn dispatches on the
        presence of 'latents'.
        """
        metadata = self.metadata[idx]
        filename_hash = metadata['filename_hash']

        if self.precompute_embeddings and filename_hash in self.precomputed_data:
            # Use precomputed data
            data = self.precomputed_data[filename_hash]
            return {
                'text_embeddings': data['text_embeddings'],
                'pooled_embeddings': data['pooled_embeddings'],
                'latents': data['latents'],
                'prompts': data['prompt'],
                'filename_hash': filename_hash,
                'metadata': metadata
            }
        else:
            # Load data on-the-fly

            # Load image
            image_path = metadata['image_file']
            try:
                image = Image.open(image_path).convert('RGB')
                image = self.image_transforms(image)
            except Exception as e:
                print(f"Error loading image {image_path}: {e}")
                # Black-image fallback keeps batch shapes consistent.
                image = torch.zeros(3, self.image_size, self.image_size)

            # Get prompt: prefer the natural caption, fall back to the
            # original positive prompt.
            prompt = metadata.get('natural_caption_data', {}).get('natural_caption', '')
            if not prompt:
                prompt = metadata.get('original_prompt_data', {}).get('positive_prompt', '')

            # Get text embeddings (will use cache if available)
            text_embeddings = self._load_or_compute_text_embeddings(prompt, filename_hash)

            return {
                'images': image,
                'prompts': prompt,
                'text_embeddings': text_embeddings['text_embeddings'].squeeze(0),
                'pooled_embeddings': text_embeddings['pooled_embeddings'].squeeze(0),
                'filename_hash': filename_hash,
                'metadata': metadata
            }
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def collate_fn(examples: List[Dict]) -> Dict[str, torch.Tensor]:
    """Custom collate function for DataLoader.

    Stacks tensor fields across the batch and gathers string/dict fields
    into plain lists. Dispatches on the presence of 'latents' in the first
    sample: precomputed samples carry 'latents', on-the-fly samples carry
    raw 'images'.
    """
    def stacked(field):
        # Stack one tensor field across all samples in the batch.
        return torch.stack([sample[field] for sample in examples])

    batch = {}

    if 'latents' in examples[0]:
        # Precomputed format - embeddings and latents are already computed
        tensor_fields = ('latents', 'text_embeddings', 'pooled_embeddings')
    else:
        # On-the-fly format - need to handle images
        tensor_fields = ('images', 'text_embeddings', 'pooled_embeddings')

    for field in tensor_fields:
        batch[field] = stacked(field)

    # Non-tensor fields (strings / metadata dicts) are kept as lists.
    for field in ('prompts', 'filename_hash', 'metadata'):
        batch[field] = [sample[field] for sample in examples]

    return batch
|
| 303 |
+
|
| 304 |
+
import torch
|
| 305 |
+
from torch.utils.data import Dataset, DataLoader
|
| 306 |
+
from PIL import Image
|
| 307 |
+
import json
|
| 308 |
+
import os
|
| 309 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
| 310 |
+
import torchvision.transforms as transforms
|
| 311 |
+
import random
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class ImageCaptionDataset(Dataset):
    """
    Dataset for image-caption pairs.

    Loads (image, caption) records from a JSON/JSONL annotations file,
    applies resize/crop/flip/normalize transforms, and yields dicts with
    "images" (tensor scaled to [-1, 1]), "captions" (character-truncated
    string) and "image_paths".
    """

    def __init__(
        self,
        data_root: str,
        annotations_file: str,
        image_size: int = 1024,
        center_crop: bool = True,
        random_flip: bool = True,
        caption_column: str = "caption",
        image_column: str = "image",
        max_caption_length: int = 512
    ):
        # Root directory that image paths in the annotations are relative to.
        self.data_root = data_root
        self.image_size = image_size
        self.caption_column = caption_column
        self.image_column = image_column
        # Captions longer than this are truncated (by characters) in __getitem__.
        self.max_caption_length = max_caption_length

        # Load annotations
        self.annotations = self._load_annotations(annotations_file)

        # Setup image transforms
        self.image_transforms = self._setup_transforms(image_size, center_crop, random_flip)

        print(f"📚 数据集加载完成: {len(self.annotations)} 个样本")

    def _load_annotations(self, annotations_file: str) -> List[Dict]:
        """Load annotations from a .json (list) or .jsonl (one object per
        line) file and keep only entries with a non-empty string caption and
        both required columns present."""
        if annotations_file.endswith('.json'):
            with open(annotations_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        elif annotations_file.endswith('.jsonl'):
            data = []
            with open(annotations_file, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data.append(json.loads(line))
        else:
            raise ValueError(f"Unsupported annotation file format: {annotations_file}")

        # Filter valid samples
        valid_data = []
        for item in data:
            if self.caption_column in item and self.image_column in item:
                if isinstance(item[self.caption_column], str) and item[self.caption_column].strip():
                    valid_data.append(item)

        print(f"📋 有效样本数: {len(valid_data)} / {len(data)}")
        return valid_data

    def _setup_transforms(self, size: int, center_crop: bool, random_flip: bool):
        """Build the preprocessing pipeline.

        center_crop=True resizes the short edge then crops a square;
        otherwise the image is stretched to (size, size). Output tensors are
        normalized to [-1, 1].
        """
        transform_list = []

        # Resize
        if center_crop:
            transform_list.extend([
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size)
            ])
        else:
            transform_list.append(
                transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR)
            )

        # Random horizontal flip
        if random_flip:
            transform_list.append(transforms.RandomHorizontalFlip(p=0.5))

        # Convert to tensor and normalize
        transform_list.extend([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # Scale to [-1, 1]
        ])

        return transforms.Compose(transform_list)

    def __len__(self):
        """Return the number of valid annotation records."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get a single sample; falls back to a black image on load failure."""
        annotation = self.annotations[idx]

        # Load image
        image_path = os.path.join(self.data_root, annotation[self.image_column])
        try:
            image = Image.open(image_path)
            if image.mode != 'RGB':
                image = image.convert('RGB')
        except Exception as e:
            print(f"⚠️ 加载图像失败 {image_path}: {e}")
            # Return a black image as fallback
            image = Image.new('RGB', (self.image_size, self.image_size), (0, 0, 0))

        # Apply transforms
        image = self.image_transforms(image)

        # Get caption (truncated by character count, not tokens).
        caption = annotation[self.caption_column]
        if len(caption) > self.max_caption_length:
            caption = caption[:self.max_caption_length]

        return {
            "images": image,
            "captions": caption,
            "image_paths": image_path
        }
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
class MultiAspectDataset(Dataset):
    """
    Dataset that supports multiple aspect ratios.

    Each sample is assigned at load time to the aspect-ratio "bucket"
    closest to its true width/height ratio; __getitem__ resizes the image
    to its bucket resolution so batches of one bucket share a shape.
    """

    def __init__(
        self,
        data_root: str,
        annotations_file: str,
        base_size: int = 1024,
        aspect_ratios: List[Tuple[int, int]] = None,
        bucket_tolerance: float = 0.1,
        caption_column: str = "caption",
        image_column: str = "image",
        max_caption_length: int = 512
    ):
        # Root directory that image paths in the annotations are relative to.
        self.data_root = data_root
        self.base_size = base_size
        self.caption_column = caption_column
        self.image_column = image_column
        # Captions longer than this are truncated (by characters) in __getitem__.
        self.max_caption_length = max_caption_length

        # Default aspect ratios for SDXL
        if aspect_ratios is None:
            aspect_ratios = [
                (1024, 1024),  # 1:1
                (1152, 896),   # 9:7
                (896, 1152),   # 7:9
                (1216, 832),   # 3:2
                (832, 1216),   # 2:3
                (1344, 768),   # 7:4
                (768, 1344),   # 4:7
                (1536, 640),   # 12:5
                (640, 1536),   # 5:12
            ]

        self.aspect_ratios = aspect_ratios
        self.bucket_tolerance = bucket_tolerance

        # Load and bucket annotations
        self.annotations = self._load_and_bucket_annotations(annotations_file)

        print(f"📚 多长宽比数据集加载完成: {len(self.annotations)} 个样本")
        self._print_bucket_stats()

    def _load_and_bucket_annotations(self, annotations_file: str) -> List[Dict]:
        """Load annotations and assign each entry to an aspect-ratio bucket.

        Supports .json (list of objects) and .jsonl (one object per line).
        Entries missing either column or with an empty caption are dropped;
        images whose dimensions cannot be read fall back to the 1024x1024
        bucket rather than being discarded.

        Raises:
            ValueError: If the annotations file has an unsupported extension.
        """
        # Load annotations
        if annotations_file.endswith('.json'):
            with open(annotations_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        elif annotations_file.endswith('.jsonl'):
            data = []
            with open(annotations_file, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data.append(json.loads(line))
        else:
            # BUG FIX: previously there was no else branch, so an unsupported
            # extension crashed with NameError on `data` below. Raise the same
            # ValueError as ImageCaptionDataset._load_annotations for
            # consistency.
            raise ValueError(f"Unsupported annotation file format: {annotations_file}")

        bucketed_data = []

        for item in data:
            if self.caption_column not in item or self.image_column not in item:
                continue

            caption = item[self.caption_column]
            if not isinstance(caption, str) or not caption.strip():
                continue

            # Try to get image dimensions to assign bucket
            image_path = os.path.join(self.data_root, item[self.image_column])
            try:
                with Image.open(image_path) as img:
                    width, height = img.size
                    aspect_ratio = width / height

                    # Find best matching bucket
                    best_bucket = self._find_best_bucket(aspect_ratio)

                    item_copy = item.copy()
                    item_copy['bucket_width'] = best_bucket[0]
                    item_copy['bucket_height'] = best_bucket[1]
                    item_copy['original_width'] = width
                    item_copy['original_height'] = height

                    bucketed_data.append(item_copy)

            except Exception as e:
                print(f"⚠️ 无法获取图像尺寸 {image_path}: {e}")
                # Use default 1:1 bucket
                item_copy = item.copy()
                item_copy['bucket_width'] = 1024
                item_copy['bucket_height'] = 1024
                item_copy['original_width'] = 1024
                item_copy['original_height'] = 1024
                bucketed_data.append(item_copy)

        return bucketed_data

    def _find_best_bucket(self, aspect_ratio: float) -> Tuple[int, int]:
        """Return the (width, height) bucket whose ratio is closest to
        ``aspect_ratio`` (minimum absolute difference, first match wins)."""
        best_bucket = self.aspect_ratios[0]
        best_diff = float('inf')

        for bucket_w, bucket_h in self.aspect_ratios:
            bucket_ratio = bucket_w / bucket_h
            diff = abs(aspect_ratio - bucket_ratio)

            if diff < best_diff:
                best_diff = diff
                best_bucket = (bucket_w, bucket_h)

        return best_bucket

    def _print_bucket_stats(self):
        """Print how many samples landed in each aspect-ratio bucket."""
        bucket_counts = {}
        for item in self.annotations:
            bucket = (item['bucket_width'], item['bucket_height'])
            bucket_counts[bucket] = bucket_counts.get(bucket, 0) + 1

        print("📊 长宽比分布:")
        for bucket, count in sorted(bucket_counts.items()):
            ratio = bucket[0] / bucket[1]
            print(f"  {bucket[0]}×{bucket[1]} (比例 {ratio:.2f}): {count} 个样本")

    def _get_transforms(self, target_width: int, target_height: int):
        """Build the resize/flip/normalize pipeline for one bucket size.

        Note this stretches to the exact bucket resolution (no crop), so
        slight aspect distortion is possible within a bucket's tolerance.
        """
        return transforms.Compose([
            transforms.Resize((target_height, target_width), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        """Return the number of bucketed annotation records."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get a single sample resized to its bucket resolution."""
        annotation = self.annotations[idx]

        # Get target dimensions from bucket
        target_width = annotation['bucket_width']
        target_height = annotation['bucket_height']

        # Load and transform image
        image_path = os.path.join(self.data_root, annotation[self.image_column])
        try:
            image = Image.open(image_path)
            if image.mode != 'RGB':
                image = image.convert('RGB')
        except Exception as e:
            print(f"⚠️ 加载图像失败 {image_path}: {e}")
            # Black-image fallback keeps the sample usable.
            image = Image.new('RGB', (target_width, target_height), (0, 0, 0))

        # Apply transforms
        transforms_fn = self._get_transforms(target_width, target_height)
        image = transforms_fn(image)

        # Get caption (truncated by character count).
        caption = annotation[self.caption_column]
        if len(caption) > self.max_caption_length:
            caption = caption[:self.max_caption_length]

        return {
            "images": image,
            "captions": caption,
            "image_paths": image_path,
            "width": target_width,
            "height": target_height
        }
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
# def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 604 |
+
# """
|
| 605 |
+
# Custom collate function for batching
|
| 606 |
+
# 自定义批处理整理函数
|
| 607 |
+
# """
|
| 608 |
+
# # Check if all images have the same size
|
| 609 |
+
# sizes = [(item["images"].shape[-2], item["images"].shape[-1]) for item in batch]
|
| 610 |
+
# if len(set(sizes)) == 1:
|
| 611 |
+
# # All same size, can batch normally
|
| 612 |
+
# images = torch.stack([item["images"] for item in batch])
|
| 613 |
+
# captions = [item["captions"] for item in batch]
|
| 614 |
+
|
| 615 |
+
# result = {
|
| 616 |
+
# "images": images,
|
| 617 |
+
# "captions": captions,
|
| 618 |
+
# "image_paths": [item["image_paths"] for item in batch]
|
| 619 |
+
# }
|
| 620 |
+
|
| 621 |
+
# # Add width/height if available
|
| 622 |
+
# if "width" in batch[0]:
|
| 623 |
+
# result["widths"] = [item["width"] for item in batch]
|
| 624 |
+
# result["heights"] = [item["height"] for item in batch]
|
| 625 |
+
|
| 626 |
+
# return result
|
| 627 |
+
# else:
|
| 628 |
+
# # Different sizes, return as list
|
| 629 |
+
# return {
|
| 630 |
+
# "images": [item["images"] for item in batch],
|
| 631 |
+
# "captions": [item["captions"] for item in batch],
|
| 632 |
+
# "image_paths": [item["image_paths"] for item in batch],
|
| 633 |
+
# "widths": [item.get("width", item["images"].shape[-1]) for item in batch],
|
| 634 |
+
# "heights": [item.get("height", item["images"].shape[-2]) for item in batch]
|
| 635 |
+
# }
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def create_dataloader(
    dataset: Dataset,
    batch_size: int = 4,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    drop_last: bool = True
) -> DataLoader:
    """Create a DataLoader wired to the module's custom collate_fn.

    All arguments are forwarded to torch.utils.data.DataLoader unchanged.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        collate_fn=collate_fn,
    )
    return DataLoader(dataset, **loader_options)
|
arch/example_train.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Example Training Script using Arch Components
|
| 3 |
+
使用架构组件的示例训练脚本
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
import os
|
| 10 |
+
import argparse
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import wandb
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
# Import arch components
|
| 16 |
+
from arch import (
|
| 17 |
+
QwenTextEncoder,
|
| 18 |
+
QwenEmbeddingAdapter,
|
| 19 |
+
load_unet_from_safetensors,
|
| 20 |
+
load_vae_from_safetensors,
|
| 21 |
+
create_scheduler,
|
| 22 |
+
DiffusionLoss,
|
| 23 |
+
AdapterTrainingStep,
|
| 24 |
+
get_cosine_schedule_with_warmup,
|
| 25 |
+
EMAModel,
|
| 26 |
+
ImageCaptionDataset,
|
| 27 |
+
MultiAspectDataset,
|
| 28 |
+
create_dataloader
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def parse_args():
    """Build and parse the command-line arguments for adapter training.

    Returns:
        argparse.Namespace: Parsed arguments with documented defaults.
    """
    p = argparse.ArgumentParser(description="Train Qwen-SDXL Adapter")

    # --- Model paths ---
    p.add_argument("--qwen_model_path", type=str, default="models/Qwen3-Embedding-0.6B")
    p.add_argument("--unet_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_unet.safetensors")
    p.add_argument("--unet_config_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_unet_config.json")
    p.add_argument("--vae_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors")
    p.add_argument("--vae_config_path", type=str, default="models/extracted_components/waiNSFWIllustrious_v140_vae_config.json")

    # --- Data ---
    p.add_argument("--data_root", type=str, required=True, help="Root directory of training images")
    p.add_argument("--annotations_file", type=str, required=True, help="Path to annotations file (JSON/JSONL)")
    p.add_argument("--caption_column", type=str, default="caption")
    p.add_argument("--image_column", type=str, default="image")
    p.add_argument("--use_multi_aspect", action="store_true", help="Use multi-aspect ratio dataset")

    # --- Training schedule ---
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--num_epochs", type=int, default=10)
    p.add_argument("--warmup_steps", type=int, default=500)
    p.add_argument("--gradient_accumulation_steps", type=int, default=1)
    p.add_argument("--max_grad_norm", type=float, default=1.0)

    # --- Loss ---
    p.add_argument("--loss_type", type=str, default="mse", choices=["mse", "l1", "huber"])
    p.add_argument("--snr_gamma", type=float, default=None, help="SNR gamma for loss weighting")
    p.add_argument("--use_v_parameterization", action="store_true")

    # --- Optimization ---
    p.add_argument("--optimizer", type=str, default="adamw", choices=["adamw", "adam"])
    p.add_argument("--weight_decay", type=float, default=0.01)
    p.add_argument("--use_ema", action="store_true", help="Use EMA for adapter")
    p.add_argument("--ema_decay", type=float, default=0.9999)

    # --- Checkpointing ---
    p.add_argument("--output_dir", type=str, default="./checkpoints")
    p.add_argument("--save_steps", type=int, default=1000)
    p.add_argument("--resume_from_checkpoint", type=str, default=None)

    # --- Logging ---
    p.add_argument("--logging_steps", type=int, default=50)
    p.add_argument("--use_wandb", action="store_true")
    p.add_argument("--wandb_project", type=str, default="qwen-sdxl-training")
    p.add_argument("--wandb_run_name", type=str, default=None)

    # --- Hardware ---
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"])
    p.add_argument("--num_workers", type=int, default=4)

    return p.parse_args()
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def setup_models(args):
    """Instantiate every model component for training.

    Loads the frozen Qwen text encoder, a fresh trainable adapter, the
    UNet and VAE from safetensors checkpoints, and a DDPM noise scheduler.

    Args:
        args: Parsed CLI args (paths, device, dtype).

    Returns:
        tuple: (text_encoder, adapter, unet, vae, noise_scheduler, dtype).
    """
    print("🚀 设置模型组件...")

    # Resolve the requested precision to a torch dtype. argparse restricts
    # --dtype to {"float32", "float16", "bfloat16"}, all valid torch
    # attribute names, so getattr is safe here.
    dtype = getattr(torch, args.dtype)

    # Load text encoder (kept frozen; only the adapter trains).
    print("📝 加载 Qwen 文本编码器...")
    text_encoder = QwenTextEncoder(
        model_path=args.qwen_model_path,
        device=args.device,
        freeze_encoder=True
    )

    # Initialize adapter
    print("🔧 初始化适配器...")
    adapter = QwenEmbeddingAdapter()
    adapter.to(args.device, dtype)

    # Load UNet
    print("🏗️ 加载 UNet...")
    unet = load_unet_from_safetensors(args.unet_path, args.unet_config_path, args.device, dtype)

    # Load VAE
    print("🎨 加载 VAE...")
    vae = load_vae_from_safetensors(args.vae_path, args.vae_config_path, args.device, dtype)

    # Create scheduler
    print("⏰ 创建调度器...")
    noise_scheduler = create_scheduler("DDPM")

    return text_encoder, adapter, unet, vae, noise_scheduler, dtype
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def setup_data(args):
    """Build the training dataset and wrap it in a DataLoader.

    Args:
        args: Parsed CLI args (data paths, columns, batch/worker settings).

    Returns:
        DataLoader over either a multi-aspect or fixed-size dataset.
    """
    print("📚 设置数据加载器...")

    # Both dataset classes accept the same constructor arguments here, so
    # only the class differs between the two modes.
    dataset_cls = MultiAspectDataset if args.use_multi_aspect else ImageCaptionDataset
    dataset = dataset_cls(
        data_root=args.data_root,
        annotations_file=args.annotations_file,
        caption_column=args.caption_column,
        image_column=args.image_column
    )

    return create_dataloader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def setup_training(args, adapter, noise_scheduler):
    """Construct the loss, optimizer and (optional) EMA tracker.

    Args:
        args: Parsed CLI args (loss/optimizer/EMA settings).
        adapter: The trainable adapter whose parameters are optimized.
        noise_scheduler: Diffusion noise scheduler passed to the loss.

    Returns:
        tuple: (loss_fn, optimizer, ema) where ema is None unless --use_ema.
    """
    print("🎯 设置训练组件...")

    # Diffusion loss with optional SNR weighting / v-parameterization.
    loss_fn = DiffusionLoss(
        noise_scheduler=noise_scheduler,
        loss_type=args.loss_type,
        snr_gamma=args.snr_gamma,
        use_v_parameterization=args.use_v_parameterization
    )

    # Optimizer: both variants share lr/weight_decay; AdamW also pins betas.
    common = {"lr": args.learning_rate, "weight_decay": args.weight_decay}
    if args.optimizer == "adamw":
        optimizer = optim.AdamW(adapter.parameters(), betas=(0.9, 0.999), **common)
    else:
        optimizer = optim.Adam(adapter.parameters(), **common)

    # EMA shadow weights (optional).
    ema = EMAModel(adapter, decay=args.ema_decay) if args.use_ema else None

    return loss_fn, optimizer, ema
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def train_step(training_step_fn, batch, optimizer, args, ema=None):
    """Execute one training step.

    Handles both fixed-size batches (stacked tensor under ``batch["images"]``)
    and multi-aspect batches (a list of differently-sized image tensors,
    processed one sample at a time), backpropagates, clips gradients, steps
    the optimizer, and updates EMA weights if provided.

    Args:
        training_step_fn: Object exposing ``training_step(images, captions)``
            returning a dict with a "loss" tensor, and an ``adapter``
            attribute whose parameters are gradient-clipped.
        batch: Collated batch dict with "images" and "captions".
        optimizer: Optimizer over the adapter parameters.
        args: Parsed CLI args (uses gradient_accumulation_steps, max_grad_norm).
        ema: Optional EMA tracker, updated after each optimizer step.

    Returns:
        float: Average scaled loss for logging.

    NOTE(review): the loss is divided by ``gradient_accumulation_steps``, but
    ``optimizer.step()`` / ``zero_grad()`` run on every call, so gradients
    are never actually accumulated across calls — with steps > 1 this just
    shrinks the effective learning rate. Confirm intended behavior.
    """
    # Handle different batch formats
    if isinstance(batch["images"], list):
        # Multi-size batch, train one by one (each sample becomes a
        # batch of 1 so tensors of different shapes never need stacking).
        total_loss = 0
        num_samples = 0

        for i in range(len(batch["images"])):
            images = batch["images"][i].unsqueeze(0)
            captions = [batch["captions"][i]]

            step_output = training_step_fn.training_step(images, captions)
            loss = step_output["loss"] / args.gradient_accumulation_steps

            loss.backward()
            total_loss += loss.item()
            num_samples += 1

        avg_loss = total_loss / num_samples if num_samples > 0 else 0
    else:
        # Regular batch
        images = batch["images"]
        captions = batch["captions"]

        step_output = training_step_fn.training_step(images, captions)
        loss = step_output["loss"] / args.gradient_accumulation_steps

        loss.backward()
        avg_loss = loss.item()

    # Gradient clipping and optimization step
    torch.nn.utils.clip_grad_norm_(training_step_fn.adapter.parameters(), args.max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()

    # Update EMA
    if ema is not None:
        ema.update()

    return avg_loss
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def save_checkpoint(adapter, optimizer, ema, epoch, step, args):
    """Save adapter weights, optional EMA weights, and optimizer state.

    Writes three files into ``args.output_dir``: the adapter weights
    (via its own save_adapter() when available, else safetensors), the
    EMA shadow weights if EMA is active, and a torch .pt file with the
    optimizer state and run metadata.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    def _dump_weights(module, path):
        # Serialize a module's state dict with safetensors; imported lazily
        # so the dependency is only required when this path is taken.
        import safetensors.torch
        safetensors.torch.save_file(module.state_dict(), path)

    tag = f"epoch_{epoch}_step_{step}"

    # Save adapter (prefer its own serializer when it provides one).
    adapter_path = os.path.join(args.output_dir, f"adapter_{tag}.safetensors")
    if hasattr(adapter, 'save_adapter'):
        adapter.save_adapter(adapter_path)
    else:
        _dump_weights(adapter, adapter_path)

    # Save EMA adapter if available: temporarily swap in the shadow
    # weights, dump them, then restore the live weights.
    if ema is not None:
        ema.apply_shadow()
        _dump_weights(adapter, os.path.join(args.output_dir, f"adapter_ema_{tag}.safetensors"))
        ema.restore()

    # Save training state (optimizer + bookkeeping) for resuming.
    torch.save({
        "epoch": epoch,
        "step": step,
        "optimizer_state_dict": optimizer.state_dict(),
        "args": args
    }, os.path.join(args.output_dir, f"training_state_{tag}.pt"))

    print(f"💾 检查点已保存: epoch {epoch}, step {step}")
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def main():
    """Entry point: wire up models, data, optimizer/EMA, then run the training loop."""
    args = parse_args()

    # Optional experiment tracking.
    if args.use_wandb:
        wandb.init(
            project=args.wandb_project,
            name=args.wandb_run_name,
            config=vars(args)
        )

    # Model components: frozen text encoder + trainable adapter + SDXL parts.
    text_encoder, adapter, unet, vae, noise_scheduler, dtype = setup_models(args)

    # Data pipeline.
    dataloader = setup_data(args)

    # Loss, optimizer, and (optional) EMA over the adapter weights.
    loss_fn, optimizer, ema = setup_training(args, adapter, noise_scheduler)

    # Bundle everything a single training step needs.
    training_step_fn = AdapterTrainingStep(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        adapter=adapter,
        noise_scheduler=noise_scheduler,
        loss_fn=loss_fn,
        device=args.device,
        dtype=dtype
    )

    # Cosine LR schedule with warmup.
    # NOTE(review): train_step appears to call optimizer.step() on every batch,
    # yet total_steps divides by gradient_accumulation_steps while
    # lr_scheduler.step() below runs once per batch — confirm the accumulation
    # logic, otherwise the schedule finishes early.
    total_steps = len(dataloader) * args.num_epochs // args.gradient_accumulation_steps
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=total_steps
    )

    print(f"🎓 开始训练: {args.num_epochs} epochs, {len(dataloader)} steps/epoch")
    print(f"📊 总训练步数: {total_steps}")

    # Main training loop.
    global_step = 0

    for epoch in range(args.num_epochs):
        adapter.train()
        epoch_loss = 0

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.num_epochs}")

        for step, batch in enumerate(progress_bar):
            step_loss = train_step(training_step_fn, batch, optimizer, args, ema)
            epoch_loss += step_loss

            # Advance the LR schedule once per batch (see NOTE above).
            lr_scheduler.step()

            global_step += 1

            # Periodic console / wandb logging.
            if global_step % args.logging_steps == 0:
                avg_loss = epoch_loss / (step + 1)
                current_lr = lr_scheduler.get_last_lr()[0]

                progress_bar.set_postfix({
                    "loss": f"{step_loss:.4f}",
                    "avg_loss": f"{avg_loss:.4f}",
                    "lr": f"{current_lr:.2e}"
                })

                if args.use_wandb:
                    wandb.log({
                        "train/loss": step_loss,
                        "train/avg_loss": avg_loss,
                        "train/learning_rate": current_lr,
                        "train/epoch": epoch,
                        "train/step": global_step
                    })

            # Periodic mid-epoch checkpoint.
            if global_step % args.save_steps == 0:
                save_checkpoint(adapter, optimizer, ema, epoch, global_step, args)

        # End-of-epoch summary and checkpoint.
        avg_epoch_loss = epoch_loss / len(dataloader)
        print(f"📈 Epoch {epoch+1} 完成,平均损失: {avg_epoch_loss:.4f}")

        save_checkpoint(adapter, optimizer, ema, epoch+1, global_step, args)

    print("🎉 训练完成!")

    if args.use_wandb:
        wandb.finish()
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
arch/model_loader.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model Loader Utilities
|
| 3 |
+
模型加载工具 - 用于加载各种模型组件
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import json
|
| 8 |
+
import safetensors.torch
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_unet_from_safetensors(unet_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Load a UNet from a safetensors file plus a JSON config.

    Args:
        unet_path: Path to UNet safetensors file
        config_path: Path to UNet config JSON file
        device: Device to load model on
        dtype: Data type for model weights

    Returns:
        UNet2DConditionModel, or None if any step fails (best-effort loader).
    """
    try:
        from diffusers import UNet2DConditionModel

        # Instantiate the architecture from its JSON config, then fill weights.
        with open(config_path, 'r') as f:
            config = json.load(f)
        model = UNet2DConditionModel.from_config(config)
        model.load_state_dict(safetensors.torch.load_file(unet_path))
        model.to(device, dtype)
        return model
    except Exception as e:
        # Deliberate best-effort: report the problem and signal failure via None.
        print(f"Error loading UNet: {e}")
        return None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_vae_from_safetensors(vae_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Load a VAE from a safetensors file plus a JSON config.

    Args:
        vae_path: Path to VAE safetensors file
        config_path: Path to VAE config JSON file
        device: Device to load model on
        dtype: Data type for model weights

    Returns:
        AutoencoderKL, or None if any step fails (best-effort loader).
    """
    try:
        from diffusers import AutoencoderKL

        # Instantiate the architecture from its JSON config, then fill weights.
        with open(config_path, 'r') as f:
            config = json.load(f)
        model = AutoencoderKL.from_config(config)
        model.load_state_dict(safetensors.torch.load_file(vae_path))
        model.to(device, dtype)
        return model
    except Exception as e:
        # Deliberate best-effort: report the problem and signal failure via None.
        print(f"Error loading VAE: {e}")
        return None
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def create_scheduler(scheduler_type: str = "EulerAncestral", model_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
    """
    Create a diffusion noise scheduler from a pretrained config.

    Args:
        scheduler_type: One of "DDPM", "DDIM", "DPMSolverMultistep",
            "EulerAncestral". Unknown values fall back to DDPM (with a warning).
        model_id: Model ID whose "scheduler" subfolder supplies the config.

    Returns:
        Scheduler object, or None if creation fails.
    """
    # Dispatch table instead of an if/elif chain over scheduler classes.
    class_names = {
        "DDPM": "DDPMScheduler",
        "DDIM": "DDIMScheduler",
        "DPMSolverMultistep": "DPMSolverMultistepScheduler",
        "EulerAncestral": "EulerAncestralDiscreteScheduler",
    }
    try:
        if scheduler_type in class_names:
            key = scheduler_type
        else:
            print(f"Unsupported scheduler type: {scheduler_type}, using DDPM")
            key = "DDPM"

        import diffusers
        scheduler_cls = getattr(diffusers, class_names[key])
        return scheduler_cls.from_pretrained(model_id, subfolder="scheduler")
    except Exception as e:
        print(f"Error creating scheduler: {e}")
        return None
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def load_qwen_model(model_path: str, device: str = "cuda"):
    """
    Load the Qwen3 embedding model via sentence-transformers.

    Args:
        model_path: Path to the Qwen model
        device: Device to load the model on

    Returns:
        SentenceTransformer model, or None when the optional dependency is
        missing or loading fails (callers then fall back to mock embeddings).
    """
    try:
        from sentence_transformers import SentenceTransformer
        encoder = SentenceTransformer(model_path)
        encoder.to(device)
        return encoder
    except ImportError:
        # sentence-transformers is an optional dependency.
        print("Warning: sentence-transformers not available. Using mock embeddings.")
        return None
    except Exception as e:
        print(f"Error loading Qwen model: {e}")
        return None
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def save_model_components(
    unet,
    vae,
    adapter,
    text_encoder,
    save_dir: str,
    save_format: str = "safetensors"
):
    """
    Save model components for training checkpoints.

    Args:
        unet: UNet model (skipped when None)
        vae: VAE model (skipped when None)
        adapter: Qwen embedding adapter (skipped when None)
        text_encoder: Qwen text encoder (accepted but not serialized here)
        save_dir: Directory to save components into (created if missing)
        save_format: "safetensors", or anything else for PyTorch ``.pt`` files
    """
    import os
    os.makedirs(save_dir, exist_ok=True)

    # One (filename-stem, model) pair per component; None entries are skipped.
    components = (("unet", unet), ("vae", vae), ("adapter", adapter))

    try:
        if save_format == "safetensors":
            for name, model in components:
                if model is not None:
                    safetensors.torch.save_file(
                        model.state_dict(),
                        os.path.join(save_dir, f"{name}.safetensors")
                    )
        else:  # PyTorch format
            for name, model in components:
                if model is not None:
                    torch.save(model.state_dict(), os.path.join(save_dir, f"{name}.pt"))

        print(f"Model components saved to {save_dir}")

    except Exception as e:
        print(f"Error saving model components: {e}")
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def load_unet_with_lora(
    unet_path: str,
    unet_config_path: str,
    lora_weights_path: Optional[str] = None,
    lora_config_path: Optional[str] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load a UNet from safetensors and optionally apply LoRA weights via PEFT.

    Args:
        unet_path: Path to the base UNet safetensors file
        unet_config_path: Path to the UNet config JSON file
        lora_weights_path: Optional path to LoRA weights (safetensors or torch file)
        lora_config_path: Optional path to the LoRA config directory
        device: Device to load the model on
        dtype: Data type for model weights

    Returns:
        UNet model (with LoRA applied when both LoRA paths are given),
        or None if loading fails.
    """
    try:
        # Both imports are hard requirements for this loader; fail early if
        # either is missing.
        from diffusers import UNet2DConditionModel  # noqa: F401
        from peft import LoraConfig, get_peft_model, set_peft_model_state_dict

        # Base UNet from extracted safetensors + JSON config.
        unet = load_unet_from_safetensors(unet_path, unet_config_path, device, dtype)

        # Optionally wrap the UNet with PEFT and load the LoRA weights.
        if lora_weights_path and lora_config_path:
            print(f"Loading LoRA weights from {lora_weights_path}")

            if lora_weights_path.endswith(".safetensors"):
                lora_state_dict = safetensors.torch.load_file(lora_weights_path)
            else:
                lora_state_dict = torch.load(lora_weights_path, map_location=device)

            lora_config = LoraConfig.from_pretrained(lora_config_path)
            unet = get_peft_model(unet, lora_config)
            set_peft_model_state_dict(unet, lora_state_dict)

            print("LoRA weights applied to UNet")

        unet.to(device, dtype)
        return unet

    except Exception as e:
        # Best-effort loader: report and signal failure via None.
        print(f"Error loading UNet with LoRA: {e}")
        return None
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def load_fused_unet(
    fused_unet_path: str,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load a UNet whose LoRA weights have already been merged ("fused") in.

    Args:
        fused_unet_path: Path to the fused UNet model directory
        device: Device to load the model on
        dtype: Data type for model weights

    Returns:
        UNet model, or None if loading fails.
    """
    try:
        from diffusers import UNet2DConditionModel

        model = UNet2DConditionModel.from_pretrained(
            fused_unet_path,
            torch_dtype=dtype
        )
        model.to(device, dtype)
        print(f"Fused UNet loaded from {fused_unet_path}")
        return model

    except Exception as e:
        print(f"Error loading fused UNet: {e}")
        return None
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def load_checkpoint(checkpoint_path: str, device: str = "cuda"):
    """
    Load a training checkpoint from disk.

    Args:
        checkpoint_path: Path to the checkpoint file; ``.safetensors`` files go
            through safetensors, everything else through ``torch.load``.
        device: Device to map loaded tensors onto.

    Returns:
        Dictionary containing the checkpoint data, or None on failure.
    """
    try:
        if checkpoint_path.endswith(".safetensors"):
            return safetensors.torch.load_file(checkpoint_path, device=device)
        return torch.load(checkpoint_path, map_location=device)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None
|
arch/pipeline.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Qwen-SDXL Inference Pipeline
|
| 3 |
+
Qwen-SDXL 推理管道 - 使用 Qwen3 嵌入模型替代 CLIP 文本编码器的 SDXL 推理管道
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from typing import List, Optional, Union, Tuple
|
| 11 |
+
|
| 12 |
+
from .adapter import QwenEmbeddingAdapter
|
| 13 |
+
from .text_encoder import QwenTextEncoder
|
| 14 |
+
from .model_loader import load_qwen_model, load_unet_from_safetensors, load_vae_from_safetensors, create_scheduler
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class QwenIllustriousInference:
|
| 18 |
+
"""
|
| 19 |
+
Qwen-SDXL 推理管道
|
| 20 |
+
使用 Qwen3 嵌入模型替代 CLIP 文本编码器的 SDXL 推理管道
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
qwen_model_path: str = "models/Qwen3-Embedding-0.6B",
|
| 26 |
+
unet_path: str = "models/extracted_components/waiNSFWIllustrious_v140_unet.safetensors",
|
| 27 |
+
unet_config_path: str = "models/extracted_components/waiNSFWIllustrious_v140_unet_config.json",
|
| 28 |
+
vae_path: str = "models/extracted_components/waiNSFWIllustrious_v140_vae.safetensors",
|
| 29 |
+
vae_config_path: str = "models/extracted_components/waiNSFWIllustrious_v140_vae_config.json",
|
| 30 |
+
adapter_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/adapter/adapter.safetensors",
|
| 31 |
+
lora_weights_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/lora_weights/lora_weights.safetensors",
|
| 32 |
+
lora_config_path: Optional[str] = "/home/ubuntu/lyl/QwenIllustrious/qwen_illustrious_output/lora_weights/adapter_config.json",
|
| 33 |
+
use_fused_unet: bool = False,
|
| 34 |
+
fused_unet_path: Optional[str] = None,
|
| 35 |
+
device: str = "cuda",
|
| 36 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 37 |
+
scheduler_type: str = "DDPM"
|
| 38 |
+
):
|
| 39 |
+
self.device = device
|
| 40 |
+
self.dtype = dtype
|
| 41 |
+
self.vae_scale_factor = 8 # SDXL default
|
| 42 |
+
|
| 43 |
+
print("🚀 初始化 Qwen-SDXL 推理管道...")
|
| 44 |
+
|
| 45 |
+
# Initialize text encoder
|
| 46 |
+
print("📝 初始化 Qwen 文本编码器...")
|
| 47 |
+
self.text_encoder = QwenTextEncoder(
|
| 48 |
+
model_path=qwen_model_path,
|
| 49 |
+
device=device,
|
| 50 |
+
freeze_encoder=True
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Initialize adapter layer
|
| 54 |
+
print("🔧 初始化适配器层...")
|
| 55 |
+
self.adapter = QwenEmbeddingAdapter()
|
| 56 |
+
self.adapter.to(device, dtype)
|
| 57 |
+
|
| 58 |
+
# Load adapter weights if provided
|
| 59 |
+
if adapter_path is not None:
|
| 60 |
+
print(f"📥 加载适配器权重: {adapter_path}")
|
| 61 |
+
try:
|
| 62 |
+
if adapter_path.endswith(".safetensors"):
|
| 63 |
+
import safetensors.torch
|
| 64 |
+
adapter_state = safetensors.torch.load_file(adapter_path)
|
| 65 |
+
else:
|
| 66 |
+
adapter_state = torch.load(adapter_path, map_location=device)
|
| 67 |
+
self.adapter.load_state_dict(adapter_state)
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print(f"⚠️ 加载适配器权重失败: {e}")
|
| 70 |
+
|
| 71 |
+
# Load UNet (with LoRA support)
|
| 72 |
+
print("🏗️ 加载 UNet 模型...")
|
| 73 |
+
from .model_loader import load_unet_with_lora, load_fused_unet
|
| 74 |
+
|
| 75 |
+
if use_fused_unet and fused_unet_path:
|
| 76 |
+
# Load fused UNet with merged LoRA weights
|
| 77 |
+
print("📦 使用融合LoRA权重的UNet...")
|
| 78 |
+
self.unet = load_fused_unet(fused_unet_path, device, dtype)
|
| 79 |
+
elif lora_weights_path and lora_config_path:
|
| 80 |
+
# Load UNet with separate LoRA weights
|
| 81 |
+
print("🔧 加载UNet并应用LoRA权重...")
|
| 82 |
+
# For this case, use the base SDXL model path instead of safetensors
|
| 83 |
+
# base_model_path = unet_path.replace("/unet.safetensors", "").replace("/extracted_components/waiNSFWIllustrious_v140_unet.safetensors", "")
|
| 84 |
+
self.unet = load_unet_with_lora(
|
| 85 |
+
unet_path=unet_path,
|
| 86 |
+
unet_config_path=unet_config_path,
|
| 87 |
+
lora_weights_path=lora_weights_path,
|
| 88 |
+
lora_config_path=lora_config_path,
|
| 89 |
+
device=device,
|
| 90 |
+
dtype=dtype
|
| 91 |
+
)
|
| 92 |
+
else:
|
| 93 |
+
# Load standard UNet from safetensors
|
| 94 |
+
self.unet = load_unet_from_safetensors(unet_path, unet_config_path, device, dtype)
|
| 95 |
+
|
| 96 |
+
# Load VAE
|
| 97 |
+
print("🎨 加载 VAE 模型...")
|
| 98 |
+
self.vae = load_vae_from_safetensors(vae_path, vae_config_path, device, dtype)
|
| 99 |
+
|
| 100 |
+
# Initialize scheduler
|
| 101 |
+
print(f"⏰ 创建调度器 ({scheduler_type})...")
|
| 102 |
+
self.scheduler = create_scheduler(scheduler_type)
|
| 103 |
+
|
| 104 |
+
# Check if all components loaded successfully
|
| 105 |
+
self.is_ready = all([
|
| 106 |
+
self.text_encoder is not None,
|
| 107 |
+
self.adapter is not None,
|
| 108 |
+
self.unet is not None,
|
| 109 |
+
self.vae is not None,
|
| 110 |
+
self.scheduler is not None
|
| 111 |
+
])
|
| 112 |
+
|
| 113 |
+
if self.is_ready:
|
| 114 |
+
print("✅ 管道初始化完成!")
|
| 115 |
+
else:
|
| 116 |
+
print("❌ 管道初始化失败,某些组件加��失败")
|
| 117 |
+
|
| 118 |
+
def encode_prompts(
|
| 119 |
+
self,
|
| 120 |
+
prompts: List[str],
|
| 121 |
+
negative_prompts: Optional[List[str]] = None,
|
| 122 |
+
do_classifier_free_guidance: bool = True
|
| 123 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 124 |
+
"""
|
| 125 |
+
Encode prompts using Qwen3 + adapter
|
| 126 |
+
使用 Qwen3 + 适配器编码提示词
|
| 127 |
+
"""
|
| 128 |
+
# Get raw embeddings from Qwen
|
| 129 |
+
text_embeddings, pooled_embeddings = self.text_encoder.encode_prompts(
|
| 130 |
+
prompts, negative_prompts, do_classifier_free_guidance
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
batch_size = len(prompts)
|
| 134 |
+
if do_classifier_free_guidance:
|
| 135 |
+
batch_size *= 2
|
| 136 |
+
|
| 137 |
+
# Add sequence dimension for text embeddings (We uses 512 tokens for SDXL)
|
| 138 |
+
seq_len = 512
|
| 139 |
+
text_embeddings_seq = text_embeddings.unsqueeze(1).expand(-1, seq_len, -1) # [B, 512, 1024]
|
| 140 |
+
|
| 141 |
+
# Project to SDXL dimensions using adapter
|
| 142 |
+
prompt_embeds = self.adapter.forward_text_embeddings(text_embeddings_seq.to(self.dtype)) # [B, 512, 2048]
|
| 143 |
+
pooled_prompt_embeds = self.adapter.forward_pooled_embeddings(pooled_embeddings.to(self.dtype)) # [B, 1280]
|
| 144 |
+
|
| 145 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 146 |
+
|
| 147 |
+
def prepare_latents(
|
| 148 |
+
self,
|
| 149 |
+
batch_size: int,
|
| 150 |
+
height: int,
|
| 151 |
+
width: int,
|
| 152 |
+
generator: Optional[torch.Generator] = None
|
| 153 |
+
) -> torch.Tensor:
|
| 154 |
+
"""
|
| 155 |
+
Prepare initial latents
|
| 156 |
+
准备初始潜在变量
|
| 157 |
+
"""
|
| 158 |
+
if self.unet is None:
|
| 159 |
+
# Mock latents for testing
|
| 160 |
+
shape = (batch_size, 4, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
| 161 |
+
return torch.randn(shape, device=self.device, dtype=self.dtype)
|
| 162 |
+
|
| 163 |
+
shape = (
|
| 164 |
+
batch_size,
|
| 165 |
+
self.unet.config.in_channels,
|
| 166 |
+
height // self.vae_scale_factor,
|
| 167 |
+
width // self.vae_scale_factor,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
try:
|
| 171 |
+
from diffusers.utils import randn_tensor
|
| 172 |
+
latents = randn_tensor(shape, generator=generator, device=self.device, dtype=self.dtype)
|
| 173 |
+
except ImportError:
|
| 174 |
+
latents = torch.randn(shape, device=self.device, dtype=self.dtype, generator=generator)
|
| 175 |
+
|
| 176 |
+
# Scale initial noise
|
| 177 |
+
if self.scheduler is not None:
|
| 178 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 179 |
+
|
| 180 |
+
return latents
|
| 181 |
+
|
| 182 |
+
def get_time_ids(
|
| 183 |
+
self,
|
| 184 |
+
height: int,
|
| 185 |
+
width: int,
|
| 186 |
+
original_size: Tuple[int, int],
|
| 187 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 188 |
+
target_size: Optional[Tuple[int, int]] = None
|
| 189 |
+
) -> torch.Tensor:
|
| 190 |
+
"""
|
| 191 |
+
Get SDXL time IDs for micro-conditioning
|
| 192 |
+
获取 SDXL 时间 ID 用于微调节
|
| 193 |
+
"""
|
| 194 |
+
if target_size is None:
|
| 195 |
+
target_size = (height, width)
|
| 196 |
+
|
| 197 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
| 198 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=self.dtype, device=self.device)
|
| 199 |
+
|
| 200 |
+
return add_time_ids
|
| 201 |
+
|
| 202 |
+
@torch.no_grad()
|
| 203 |
+
def generate(
|
| 204 |
+
self,
|
| 205 |
+
prompt: Union[str, List[str]],
|
| 206 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 207 |
+
height: int = 1024,
|
| 208 |
+
width: int = 1024,
|
| 209 |
+
num_inference_steps: int = 50,
|
| 210 |
+
guidance_scale: float = 7.5,
|
| 211 |
+
generator: Optional[torch.Generator] = None,
|
| 212 |
+
return_type: str = "pil"
|
| 213 |
+
) -> List[Image.Image]:
|
| 214 |
+
"""
|
| 215 |
+
Generate images using Qwen-SDXL pipeline
|
| 216 |
+
使用 Qwen-SDXL 管道生成图像
|
| 217 |
+
"""
|
| 218 |
+
if not self.is_ready:
|
| 219 |
+
print("❌ 管道未准备就绪,无法生成图像")
|
| 220 |
+
return []
|
| 221 |
+
|
| 222 |
+
# Prepare prompts
|
| 223 |
+
if isinstance(prompt, str):
|
| 224 |
+
prompt = [prompt]
|
| 225 |
+
if isinstance(negative_prompt, str):
|
| 226 |
+
negative_prompt = [negative_prompt]
|
| 227 |
+
|
| 228 |
+
batch_size = len(prompt)
|
| 229 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 230 |
+
|
| 231 |
+
print(f"🎯 开始生成 {batch_size} 张图像...")
|
| 232 |
+
print(f"📏 尺寸: {width}x{height}")
|
| 233 |
+
print(f"🔄 推理步数: {num_inference_steps}")
|
| 234 |
+
print(f"🎚️ 引导强度: {guidance_scale}")
|
| 235 |
+
|
| 236 |
+
# 1. Encode prompts
|
| 237 |
+
print("📝 编码提示词...")
|
| 238 |
+
prompt_embeds, pooled_prompt_embeds = self.encode_prompts(
|
| 239 |
+
prompt, negative_prompt, do_classifier_free_guidance
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# 2. Prepare timesteps
|
| 243 |
+
print("⏰ 准备时间步...")
|
| 244 |
+
if self.scheduler is not None:
|
| 245 |
+
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
|
| 246 |
+
timesteps = self.scheduler.timesteps
|
| 247 |
+
else:
|
| 248 |
+
timesteps = torch.linspace(1000, 0, num_inference_steps, device=self.device)
|
| 249 |
+
|
| 250 |
+
# 3. Prepare latents
|
| 251 |
+
print("🌀 准备潜在变量...")
|
| 252 |
+
latents = self.prepare_latents(batch_size, height, width, generator)
|
| 253 |
+
|
| 254 |
+
# 4. Prepare time IDs
|
| 255 |
+
original_size = (height, width)
|
| 256 |
+
target_size = (height, width)
|
| 257 |
+
add_time_ids = self.get_time_ids(height, width, original_size, target_size=target_size)
|
| 258 |
+
|
| 259 |
+
if do_classifier_free_guidance:
|
| 260 |
+
add_time_ids = add_time_ids.repeat(2, 1)
|
| 261 |
+
add_time_ids = add_time_ids.repeat(batch_size, 1)
|
| 262 |
+
|
| 263 |
+
# 5. Denoising loop
|
| 264 |
+
print("🔄 开始去噪过程...")
|
| 265 |
+
for i, t in enumerate(timesteps):
|
| 266 |
+
# Expand latents for classifier-free guidance
|
| 267 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 268 |
+
|
| 269 |
+
if self.scheduler is not None:
|
| 270 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 271 |
+
|
| 272 |
+
# Predict noise
|
| 273 |
+
if self.unet is not None:
|
| 274 |
+
added_cond_kwargs = {
|
| 275 |
+
"text_embeds": pooled_prompt_embeds,
|
| 276 |
+
"time_ids": add_time_ids
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
noise_pred = self.unet(
|
| 280 |
+
latent_model_input,
|
| 281 |
+
t,
|
| 282 |
+
encoder_hidden_states=prompt_embeds,
|
| 283 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 284 |
+
return_dict=False,
|
| 285 |
+
)[0]
|
| 286 |
+
|
| 287 |
+
# Classifier-free guidance
|
| 288 |
+
if do_classifier_free_guidance:
|
| 289 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 290 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 291 |
+
|
| 292 |
+
# Scheduler step
|
| 293 |
+
if self.scheduler is not None:
|
| 294 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 295 |
+
|
| 296 |
+
if (i + 1) % 5 == 0:
|
| 297 |
+
print(f" 步骤 {i+1}/{len(timesteps)} 完成")
|
| 298 |
+
|
| 299 |
+
# 6. Decode latents
|
| 300 |
+
print("🎨 解码生成图像...")
|
| 301 |
+
if self.vae is not None:
|
| 302 |
+
latents = latents / self.vae.config.scaling_factor
|
| 303 |
+
images = self.vae.decode(latents, return_dict=False)[0]
|
| 304 |
+
else:
|
| 305 |
+
# Mock image generation for testing
|
| 306 |
+
images = torch.randn(batch_size, 3, height, width, device=self.device)
|
| 307 |
+
|
| 308 |
+
# 7. Convert to PIL images
|
| 309 |
+
images = (images / 2 + 0.5).clamp(0, 1)
|
| 310 |
+
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 311 |
+
|
| 312 |
+
if return_type == "pil":
|
| 313 |
+
images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images]
|
| 314 |
+
|
| 315 |
+
print("✅ 图像生成完成!")
|
| 316 |
+
return images
|
| 317 |
+
|
| 318 |
+
def save_adapter(self, save_path: str):
    """
    Save the adapter's state dict to *save_path*.

    The format is chosen from the file extension: ``.safetensors`` uses
    the safetensors writer, anything else falls back to ``torch.save``.
    Failures are reported (printed) rather than raised.
    """
    try:
        state_dict = self.adapter.state_dict()
        if save_path.endswith(".safetensors"):
            import safetensors.torch
            safetensors.torch.save_file(state_dict, save_path)
        else:
            torch.save(state_dict, save_path)
        print(f"✅ 适配器权重已保存到: {save_path}")
    except Exception as e:
        # Best-effort save: report and continue instead of crashing training.
        print(f"❌ 保存适配器权重失败: {e}")
|
| 332 |
+
|
| 333 |
+
def load_adapter(self, load_path: str):
    """
    Load adapter weights from *load_path* into ``self.adapter``.

    ``.safetensors`` files go through the safetensors reader; anything
    else is read with ``torch.load`` mapped onto ``self.device``.
    Failures are reported (printed) rather than raised.
    """
    try:
        if load_path.endswith(".safetensors"):
            import safetensors.torch
            loaded = safetensors.torch.load_file(load_path)
        else:
            loaded = torch.load(load_path, map_location=self.device)

        self.adapter.load_state_dict(loaded)
        print(f"✅ 适配器权重已从 {load_path} 加载")
    except Exception as e:
        # Best-effort load: report and continue instead of crashing.
        print(f"❌ 加载适配器权重失败: {e}")
|
arch/text_encoder.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Qwen Text Encoder
|
| 3 |
+
Qwen 文本编码器 - 使用 Qwen3 模型进行文本编码
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from typing import List, Optional, Union, Tuple
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_qwen_model(model_path: str, device: str = "cuda"):
    """
    Load a Qwen3 embedding model via sentence-transformers.

    Returns the loaded ``SentenceTransformer`` moved to *device*, or
    ``None`` when the ``sentence-transformers`` package is unavailable
    (callers then fall back to mock embeddings).
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        print("Warning: sentence-transformers not available. Using mock embeddings.")
        return None

    encoder = SentenceTransformer(model_path)
    encoder.to(device)
    return encoder
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def encode_text_with_qwen(
    qwen_model,
    texts: List[str],
    device: str = "cuda",
    max_length: int = 512,
    use_query_mode: bool = False
) -> torch.Tensor:
    """
    Encode text using a Qwen3 embedding model.

    Args:
        qwen_model: Qwen3 embedding model (a ``SentenceTransformer``), or
            ``None`` to produce mock embeddings for testing.
        texts: List of text strings to encode.
        device: Device to run on.
        max_length: Maximum sequence length applied to the model.
        use_query_mode: If True, use the "query" prompt and return pooled
            sentence embeddings; otherwise return stacked token embeddings.

    Returns:
        A tensor of embeddings (``[batch, dim]`` in query mode,
        ``[batch, seq, dim]`` otherwise).
    """
    if qwen_model is None:
        # Mock embeddings for testing when sentence-transformers is unavailable.
        return torch.randn(len(texts), 1024, device=device, dtype=torch.float32)

    # Bug fix: SentenceTransformer.encode() has no `max_seq_length` keyword
    # (passing it raised TypeError). The limit is a model attribute, so set
    # it before encoding instead.
    qwen_model.max_seq_length = max_length

    with torch.no_grad():
        # Use the query prompt for better text understanding when specified.
        embeddings = qwen_model.encode(
            texts,
            prompt_name="query" if use_query_mode else None,
            convert_to_tensor=True,
            device=device,
            output_value="sentence_embedding" if use_query_mode else "token_embeddings",
        )

    # token_embeddings mode returns one [seq, dim] tensor per text; stacking
    # assumes equal sequence lengths — TODO confirm inputs are padded or
    # truncated to a common length upstream.
    return embeddings if use_query_mode else torch.stack(embeddings, dim=0)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class QwenTextEncoder(nn.Module):
    """
    Qwen text encoder wrapper for training and inference.

    Wraps a Qwen3 embedding model (loaded via ``load_qwen_model``) and
    produces sequence-style and pooled prompt embeddings, optionally with
    negative-prompt embeddings prepended for classifier-free guidance.
    """

    def __init__(
        self,
        model_path: str = "models/Qwen3-Embedding-0.6B",
        device: str = "cuda",
        max_length: int = 512,
        freeze_encoder: bool = True
    ):
        super().__init__()
        self.device = device
        self.max_length = max_length
        self.freeze_encoder = freeze_encoder

        # May be None when sentence-transformers is unavailable; the
        # encoding helper then falls back to mock embeddings.
        self.qwen_model = load_qwen_model(model_path, device)

        # Freeze the backing model so only external adapters train.
        if self.freeze_encoder and self.qwen_model is not None:
            for weight in self.qwen_model.parameters():
                weight.requires_grad = False

    def _encode(self, texts: List[str], query_mode: bool) -> torch.Tensor:
        # Single call site shared by the positive/negative branches below.
        return encode_text_with_qwen(
            self.qwen_model,
            texts,
            self.device,
            max_length=self.max_length,
            use_query_mode=query_mode,
        )

    def encode_prompts(
        self,
        prompts: List[str],
        negative_prompts: Optional[List[str]] = None,
        do_classifier_free_guidance: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode prompts with the Qwen3 model.

        Returns:
            tuple: ``(text_embeddings, pooled_embeddings)``. Text
            embeddings use the model's normal mode, pooled embeddings use
            query mode. With classifier-free guidance enabled, negative
            embeddings are concatenated in front along the batch dim.
        """
        text_embeddings = self._encode(prompts, query_mode=False)
        pooled_embeddings = self._encode(prompts, query_mode=True)

        if do_classifier_free_guidance:
            negatives = negative_prompts
            if negatives is None:
                # Default negatives: empty strings, one per prompt.
                negatives = [""] * len(prompts)

            text_embeddings = torch.cat(
                [self._encode(negatives, query_mode=False), text_embeddings], dim=0
            )
            pooled_embeddings = torch.cat(
                [self._encode(negatives, query_mode=True), pooled_embeddings], dim=0
            )

        return text_embeddings, pooled_embeddings

    def forward(self, prompts: List[str], negative_prompts: Optional[List[str]] = None):
        """
        Encode *prompts*; guidance is enabled iff negative prompts are given.

        Returns:
            tuple: ``(text_embeddings, pooled_embeddings)``.
        """
        return self.encode_prompts(
            prompts,
            negative_prompts,
            do_classifier_free_guidance=negative_prompts is not None,
        )

    def train(self, mode: bool = True):
        """Keep the frozen backing model in eval mode regardless of *mode*."""
        super().train(mode)
        if self.freeze_encoder and self.qwen_model is not None:
            self.qwen_model.eval()
        return self
|
arch/training.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training Utilities for Qwen-SDXL
|
| 3 |
+
Qwen-SDXL 训练工具 - 包含损失函数、训练步骤等
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from typing import Dict, Any, Optional, Tuple
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DiffusionLoss(nn.Module):
    """
    Diffusion training loss for SDXL with Qwen embeddings.

    Supports MSE / L1 / Huber element-wise losses, optional masking, and
    optional SNR-based per-sample loss weighting (min-SNR clamping when
    ``snr_gamma >= 1``, plain ``snr ** gamma`` otherwise).
    """

    def __init__(
        self,
        noise_scheduler,
        loss_type: str = "mse",
        snr_gamma: Optional[float] = None,
        use_v_parameterization: bool = False
    ):
        super().__init__()
        self.noise_scheduler = noise_scheduler
        self.loss_type = loss_type
        self.snr_gamma = snr_gamma
        self.use_v_parameterization = use_v_parameterization

        # Element-wise (unreduced) criterion, selected by name.
        criteria = {
            "mse": lambda: nn.MSELoss(reduction="none"),
            "l1": lambda: nn.L1Loss(reduction="none"),
            "huber": lambda: nn.HuberLoss(reduction="none", delta=0.1),
        }
        if loss_type not in criteria:
            raise ValueError(f"Unsupported loss type: {loss_type}")
        self.loss_fn = criteria[loss_type]()

    def compute_snr(self, timesteps):
        """
        Compute the signal-to-noise ratio at the given timesteps, used for
        per-sample loss weighting.
        """
        abar = self.noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps]
        # SNR = signal^2 / noise^2 = abar / (1 - abar).
        return (abar ** 0.5 / (1.0 - abar) ** 0.5) ** 2

    def forward(
        self,
        model_pred: torch.Tensor,
        target: torch.Tensor,
        timesteps: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute the (scalar) diffusion loss.

        Args:
            model_pred: Model prediction (noise or v-parameterization).
            target: Target (noise or v-parameterization).
            timesteps: Diffusion timesteps, one per batch element.
            mask: Optional element-wise mask for selective loss computation.

        Returns:
            Scalar loss tensor.
        """
        per_element = self.loss_fn(model_pred, target)

        if mask is not None:
            per_element = per_element * mask

        # Collapse everything except the batch dimension.
        per_sample = per_element.mean(dim=tuple(range(1, per_element.dim())))

        if self.snr_gamma is not None:
            snr = self.compute_snr(timesteps)
            if self.snr_gamma >= 1.0:
                # Min-SNR weighting: clamp the SNR at gamma.
                weight = torch.minimum(snr, torch.full_like(snr, self.snr_gamma))
            else:
                # Standard SNR weighting.
                weight = snr ** self.snr_gamma
            per_sample = per_sample * weight

        return per_sample.mean()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class AdapterTrainingStep:
    """
    Training step for adapter-only training.

    Freezes the UNet, VAE and (already frozen) text encoder so that only
    the adapter receives gradients, then runs one SDXL diffusion training
    step: encode images to latents, add noise, predict with the UNet and
    compute the diffusion loss.
    """

    def __init__(
        self,
        unet,
        vae,
        text_encoder,
        adapter,
        noise_scheduler,
        loss_fn: DiffusionLoss,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16
    ):
        self.unet = unet
        self.vae = vae
        self.text_encoder = text_encoder
        self.adapter = adapter
        self.noise_scheduler = noise_scheduler
        self.loss_fn = loss_fn
        self.device = device
        self.dtype = dtype

        # Only the adapter should be trainable.
        self._freeze_components()

    def _freeze_components(self):
        """Freeze all components except the adapter."""
        if self.unet is not None:
            for param in self.unet.parameters():
                param.requires_grad = False

        if self.vae is not None:
            for param in self.vae.parameters():
                param.requires_grad = False

        # Text encoder is already frozen in QwenTextEncoder.
        # Only adapter parameters should be trainable.
        for param in self.adapter.parameters():
            param.requires_grad = True

    def prepare_inputs(
        self,
        images: torch.Tensor,
        prompts: list,
        negative_prompts: Optional[list] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Prepare one batch of inputs for a training step.

        Returns a dict with noisy latents, timesteps, adapter-projected
        text/pooled embeddings, SDXL time ids, the sampled noise, and the
        clean latents (needed for the v-parameterization target).
        """
        batch_size = images.shape[0]

        # Encode images to latents (VAE is frozen; no grad needed).
        with torch.no_grad():
            latents = self.vae.encode(images.to(self.dtype)).latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor

        # Sample noise and one random timestep per sample, then add noise.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (batch_size,), device=self.device
        ).long()

        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # Encode text (no classifier-free guidance during training).
        text_embeddings, pooled_embeddings = self.text_encoder.encode_prompts(
            prompts, negative_prompts, do_classifier_free_guidance=False
        )

        # Expand the per-prompt embedding to a 77-token pseudo-sequence and
        # project both embeddings through the adapter into UNet space.
        seq_len = 77
        text_embeddings_seq = text_embeddings.unsqueeze(1).expand(-1, seq_len, -1)

        encoder_hidden_states = self.adapter.forward_text_embeddings(text_embeddings_seq.to(self.dtype))
        pooled_prompt_embeds = self.adapter.forward_pooled_embeddings(pooled_embeddings.to(self.dtype))

        # SDXL micro-conditioning: (original_size, crop_top_left, target_size),
        # simplified for training to the raw image size with no crop.
        height, width = images.shape[-2:]
        original_size = (height, width)
        target_size = (height, width)
        crops_coords_top_left = (0, 0)

        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids] * batch_size, dtype=self.dtype, device=self.device)

        return {
            "noisy_latents": noisy_latents,
            "timesteps": timesteps,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_prompt_embeds": pooled_prompt_embeds,
            "add_time_ids": add_time_ids,
            "noise": noise,
            # Clean latents, kept for the v-parameterization target.
            "latents": latents
        }

    def training_step(
        self,
        images: torch.Tensor,
        prompts: list,
        negative_prompts: Optional[list] = None
    ) -> Dict[str, Any]:
        """
        Execute one training step and return the loss plus diagnostics.
        """
        inputs = self.prepare_inputs(images, prompts, negative_prompts)

        # SDXL extra conditioning kwargs for the UNet.
        added_cond_kwargs = {
            "text_embeds": inputs["pooled_prompt_embeds"],
            "time_ids": inputs["add_time_ids"]
        }

        model_pred = self.unet(
            inputs["noisy_latents"],
            inputs["timesteps"],
            encoder_hidden_states=inputs["encoder_hidden_states"],
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

        if self.loss_fn.use_v_parameterization:
            # v-prediction target: v = sqrt(a_bar)*eps - sqrt(1-a_bar)*x0.
            # Bug fix: the target must be built from the *clean* latents x0,
            # not the noisy latents as the previous code did. Prefer the
            # scheduler's own implementation when available.
            if hasattr(self.noise_scheduler, "get_velocity"):
                target = self.noise_scheduler.get_velocity(
                    inputs["latents"], inputs["noise"], inputs["timesteps"]
                )
            else:
                # Move alphas to the timestep device before indexing to
                # avoid CPU/GPU mismatch.
                alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(
                    inputs["timesteps"].device
                )[inputs["timesteps"]]
                sqrt_alphas_cumprod = (alphas_cumprod ** 0.5).view(-1, 1, 1, 1)
                sqrt_one_minus_alphas_cumprod = ((1.0 - alphas_cumprod) ** 0.5).view(-1, 1, 1, 1)

                target = sqrt_alphas_cumprod * inputs["noise"] - \
                         sqrt_one_minus_alphas_cumprod * inputs["latents"]
        else:
            # epsilon-prediction: the target is simply the sampled noise.
            target = inputs["noise"]

        loss = self.loss_fn(model_pred, target, inputs["timesteps"])

        return {
            "loss": loss,
            "model_pred": model_pred.detach(),
            "target": target.detach(),
            "timesteps": inputs["timesteps"]
        }
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
):
    """
    Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    During warmup the LR multiplier ramps linearly from 0 to 1; afterwards
    it follows ``0.5 * (1 + cos(pi * num_cycles * 2 * progress))``, clipped
    below at 0.
    """
    from torch.optim.lr_scheduler import LambdaLR

    def multiplier(step: int) -> float:
        # Linear ramp-up phase.
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)

        # Cosine decay phase, driven by progress in [0, 1].
        span = max(1, num_training_steps - num_warmup_steps)
        progress = (step - num_warmup_steps) / span
        cosine = math.cos(math.pi * num_cycles * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine))

    return LambdaLR(optimizer, multiplier, last_epoch)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class EMAModel:
    """
    Exponential moving average (EMA) of a model's trainable parameters.

    Maintains a ``shadow`` copy updated as
    ``shadow = decay * shadow + (1 - decay) * param`` and can temporarily
    swap the shadow weights into the model (``apply_shadow``) and back
    (``restore``), e.g. around evaluation.
    """

    def __init__(self, model, decay: float = 0.9999):
        self.model = model
        self.decay = decay
        # Shadow weights start as a snapshot of the current trainable params.
        self.shadow = {
            name: param.data.clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        self.backup = {}

    def update(self):
        """Blend current parameter values into the shadow copy."""
        d = self.decay
        for name, param in self.model.named_parameters():
            if not param.requires_grad or name not in self.shadow:
                continue
            self.shadow[name] = d * self.shadow[name] + (1 - d) * param.data

    def apply_shadow(self):
        """Swap shadow weights into the model, backing up the originals."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.data.clone()
                param.data = self.shadow[name]

    def restore(self):
        """Undo ``apply_shadow`` by restoring the backed-up weights."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
|
diffusers/.github/PULL_REQUEST_TEMPLATE.md
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# What does this PR do?
|
| 2 |
+
|
| 3 |
+
<!--
|
| 4 |
+
Congratulations! You've made it this far! You're not quite done yet though.
|
| 5 |
+
|
| 6 |
+
Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.
|
| 7 |
+
|
| 8 |
+
Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.
|
| 9 |
+
|
| 10 |
+
Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost.
|
| 11 |
+
-->
|
| 12 |
+
|
| 13 |
+
<!-- Remove if not applicable -->
|
| 14 |
+
|
| 15 |
+
Fixes # (issue)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
## Before submitting
|
| 19 |
+
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
|
| 20 |
+
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md)?
|
| 21 |
+
- [ ] Did you read our [philosophy doc](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md) (important for complex PRs)?
|
| 22 |
+
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)? Please add a link to it if that's the case.
|
| 23 |
+
- [ ] Did you make sure to update the documentation with your changes? Here are the
|
| 24 |
+
[documentation guidelines](https://github.com/huggingface/diffusers/tree/main/docs), and
|
| 25 |
+
[here are tips on formatting docstrings](https://github.com/huggingface/diffusers/tree/main/docs#writing-source-documentation).
|
| 26 |
+
- [ ] Did you write any new necessary tests?
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
## Who can review?
|
| 30 |
+
|
| 31 |
+
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
|
| 32 |
+
members/contributors who may be interested in your PR.
|
| 33 |
+
|
| 34 |
+
<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @.
|
| 35 |
+
|
| 36 |
+
If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
|
| 37 |
+
Please tag fewer than 3 people.
|
| 38 |
+
|
| 39 |
+
Core library:
|
| 40 |
+
|
| 41 |
+
- Schedulers: @yiyixuxu
|
| 42 |
+
- Pipelines and pipeline callbacks: @yiyixuxu and @asomoza
|
| 43 |
+
- Training examples: @sayakpaul
|
| 44 |
+
- Docs: @stevhliu and @sayakpaul
|
| 45 |
+
- JAX and MPS: @pcuenca
|
| 46 |
+
- Audio: @sanchit-gandhi
|
| 47 |
+
- General functionalities: @sayakpaul @yiyixuxu @DN6
|
| 48 |
+
|
| 49 |
+
Integrations:
|
| 50 |
+
|
| 51 |
+
- deepspeed: HF Trainer/Accelerate: @SunMarc
|
| 52 |
+
- PEFT: @sayakpaul @BenjaminBossan
|
| 53 |
+
|
| 54 |
+
HF projects:
|
| 55 |
+
|
| 56 |
+
- accelerate: [different repo](https://github.com/huggingface/accelerate)
|
| 57 |
+
- datasets: [different repo](https://github.com/huggingface/datasets)
|
| 58 |
+
- transformers: [different repo](https://github.com/huggingface/transformers)
|
| 59 |
+
- safetensors: [different repo](https://github.com/huggingface/safetensors)
|
| 60 |
+
|
| 61 |
+
-->
|
diffusers/docs/README.md
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!---
|
| 2 |
+
Copyright 2024- The HuggingFace Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
you may not use this file except in compliance with the License.
|
| 6 |
+
You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
See the License for the specific language governing permissions and
|
| 14 |
+
limitations under the License.
|
| 15 |
+
-->
|
| 16 |
+
|
| 17 |
+
# Generating the documentation
|
| 18 |
+
|
| 19 |
+
To generate the documentation, you first have to build it. Several packages are necessary to build the doc,
|
| 20 |
+
you can install them with the following command, at the root of the code repository:
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
pip install -e ".[docs]"
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
Then you need to install our open source documentation builder tool:
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
pip install git+https://github.com/huggingface/doc-builder
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
---
|
| 33 |
+
**NOTE**
|
| 34 |
+
|
| 35 |
+
You only need to generate the documentation to inspect it locally (if you're planning changes and want to
|
| 36 |
+
check how they look before committing for instance). You don't have to commit the built documentation.
|
| 37 |
+
|
| 38 |
+
---
|
| 39 |
+
|
| 40 |
+
## Previewing the documentation
|
| 41 |
+
|
| 42 |
+
To preview the docs, first install the `watchdog` module with:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
pip install watchdog
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
Then run the following command:
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
doc-builder preview {package_name} {path_to_docs}
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
For example:
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
doc-builder preview diffusers docs/source/en
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
The docs will be viewable at [http://localhost:3000](http://localhost:3000). You can also preview the docs once you have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives.
|
| 61 |
+
|
| 62 |
+
---
|
| 63 |
+
**NOTE**
|
| 64 |
+
|
| 65 |
+
The `preview` command only works with existing doc files. When you add a completely new file, you need to update `_toctree.yml` & restart `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again).
|
| 66 |
+
|
| 67 |
+
---
|
| 68 |
+
|
| 69 |
+
## Adding a new element to the navigation bar
|
| 70 |
+
|
| 71 |
+
Accepted files are Markdown (.md).
|
| 72 |
+
|
| 73 |
+
Create a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting
|
| 74 |
+
the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml) file.
|
| 75 |
+
|
| 76 |
+
## Renaming section headers and moving sections
|
| 77 |
+
|
| 78 |
+
It helps to keep the old links working when renaming the section header and/or moving sections from one document to another. This is because the old links are likely to be used in Issues, Forums, and Social media and it'd make for a much more superior user experience if users reading those months later could still easily navigate to the originally intended information.
|
| 79 |
+
|
| 80 |
+
Therefore, we simply keep a little map of moved sections at the end of the document where the original section was. The key is to preserve the original anchor.
|
| 81 |
+
|
| 82 |
+
So if you renamed a section from: "Section A" to "Section B", then you can add at the end of the file:
|
| 83 |
+
|
| 84 |
+
```md
|
| 85 |
+
Sections that were moved:
|
| 86 |
+
|
| 87 |
+
[ <a href="#section-b">Section A</a><a id="section-a"></a> ]
|
| 88 |
+
```
|
| 89 |
+
and of course, if you moved it to another file, then:
|
| 90 |
+
|
| 91 |
+
```md
|
| 92 |
+
Sections that were moved:
|
| 93 |
+
|
| 94 |
+
[ <a href="../new-file#section-b">Section A</a><a id="section-a"></a> ]
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
Use the relative style to link to the new file so that the versioned docs continue to work.
|
| 98 |
+
|
| 99 |
+
For an example of a rich moved section set please see the very end of [the transformers Trainer doc](https://github.com/huggingface/transformers/blob/main/docs/source/en/main_classes/trainer.md).
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
## Writing Documentation - Specification
|
| 103 |
+
|
| 104 |
+
The `huggingface/diffusers` documentation follows the
|
| 105 |
+
[Google documentation](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) style for docstrings,
|
| 106 |
+
although we can write them directly in Markdown.
|
| 107 |
+
|
| 108 |
+
### Adding a new tutorial
|
| 109 |
+
|
| 110 |
+
Adding a new tutorial or section is done in two steps:
|
| 111 |
+
|
| 112 |
+
- Add a new Markdown (.md) file under `docs/source/<languageCode>`.
|
| 113 |
+
- Link that file in `docs/source/<languageCode>/_toctree.yml` on the correct toc-tree.
|
| 114 |
+
|
| 115 |
+
Make sure to put your new file under the proper section. It's unlikely to go in the first section (*Get Started*), so
|
| 116 |
+
depending on the intended targets (beginners, more advanced users, or researchers) it should go in sections two, three, or four.
|
| 117 |
+
|
| 118 |
+
### Adding a new pipeline/scheduler
|
| 119 |
+
|
| 120 |
+
When adding a new pipeline:
|
| 121 |
+
|
| 122 |
+
- Create a file `xxx.md` under `docs/source/<languageCode>/api/pipelines` (don't hesitate to copy an existing file as template).
|
| 123 |
+
- Link that file in (*Diffusers Summary*) section in `docs/source/api/pipelines/overview.md`, along with the link to the paper, and a colab notebook (if available).
|
| 124 |
+
- Write a short overview of the diffusion model:
|
| 125 |
+
- Overview with paper & authors
|
| 126 |
+
- Paper abstract
|
| 127 |
+
- Tips and tricks and how to use it best
|
| 128 |
+
- Possibly an end-to-end example of how to use it
|
| 129 |
+
- Add all the pipeline classes that should be linked in the diffusion model. These classes should be added using our Markdown syntax. By default as follows:
|
| 130 |
+
|
| 131 |
+
```
|
| 132 |
+
[[autodoc]] XXXPipeline
|
| 133 |
+
- all
|
| 134 |
+
- __call__
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
This will include every public method of the pipeline that is documented, as well as the `__call__` method that is not documented by default. If you just want to add additional methods that are not documented, you can put the list of all methods to add in a list that contains `all`.
|
| 138 |
+
|
| 139 |
+
```
|
| 140 |
+
[[autodoc]] XXXPipeline
|
| 141 |
+
- all
|
| 142 |
+
- __call__
|
| 143 |
+
- enable_attention_slicing
|
| 144 |
+
- disable_attention_slicing
|
| 145 |
+
- enable_xformers_memory_efficient_attention
|
| 146 |
+
- disable_xformers_memory_efficient_attention
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
You can follow the same process to create a new scheduler under the `docs/source/<languageCode>/api/schedulers` folder.
|
| 150 |
+
|
| 151 |
+
### Writing source documentation
|
| 152 |
+
|
| 153 |
+
Values that should be put in `code` should either be surrounded by backticks: \`like so\`. Note that argument names
|
| 154 |
+
and objects like True, None, or any strings should usually be put in `code`.
|
| 155 |
+
|
| 156 |
+
When mentioning a class, function, or method, it is recommended to use our syntax for internal links so that our tool
|
| 157 |
+
adds a link to its documentation with this syntax: \[\`XXXClass\`\] or \[\`function\`\]. This requires the class or
|
| 158 |
+
function to be in the main package.
|
| 159 |
+
|
| 160 |
+
If you want to create a link to some internal class or function, you need to
|
| 161 |
+
provide its path. For instance: \[\`pipelines.ImagePipelineOutput\`\]. This will be converted into a link with
|
| 162 |
+
`pipelines.ImagePipelineOutput` in the description. To get rid of the path and only keep the name of the object you are
|
| 163 |
+
linking to in the description, add a ~: \[\`~pipelines.ImagePipelineOutput\`\] will generate a link with `ImagePipelineOutput` in the description.
|
| 164 |
+
|
| 165 |
+
The same works for methods so you can either use \[\`XXXClass.method\`\] or \[\`~XXXClass.method\`\].
|
| 166 |
+
|
| 167 |
+
#### Defining arguments in a method
|
| 168 |
+
|
| 169 |
+
Arguments should be defined with the `Args:` (or `Arguments:` or `Parameters:`) prefix, followed by a line return and
|
| 170 |
+
an indentation. The argument should be followed by its type, with its shape if it is a tensor, a colon, and its
|
| 171 |
+
description:
|
| 172 |
+
|
| 173 |
+
```
|
| 174 |
+
Args:
|
| 175 |
+
n_layers (`int`): The number of layers of the model.
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
If the description is too long to fit in one line, another indentation is necessary before writing the description
|
| 179 |
+
after the argument.
|
| 180 |
+
|
| 181 |
+
Here's an example showcasing everything so far:
|
| 182 |
+
|
| 183 |
+
```
|
| 184 |
+
Args:
|
| 185 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 186 |
+
Indices of input sequence tokens in the vocabulary.
|
| 187 |
+
|
| 188 |
+
Indices can be obtained using [`AlbertTokenizer`]. See [`~PreTrainedTokenizer.encode`] and
|
| 189 |
+
[`~PreTrainedTokenizer.__call__`] for details.
|
| 190 |
+
|
| 191 |
+
[What are input IDs?](../glossary#input-ids)
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
For optional arguments or arguments with defaults we follow the following syntax: imagine we have a function with the
|
| 195 |
+
following signature:
|
| 196 |
+
|
| 197 |
+
```py
|
| 198 |
+
def my_function(x: str=None, a: float=3.14):
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
then its documentation should look like this:
|
| 202 |
+
|
| 203 |
+
```
|
| 204 |
+
Args:
|
| 205 |
+
x (`str`, *optional*):
|
| 206 |
+
This argument controls ...
|
| 207 |
+
a (`float`, *optional*, defaults to `3.14`):
|
| 208 |
+
This argument is used to ...
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
Note that we always omit the "defaults to \`None\`" when None is the default for any argument. Also note that even
|
| 212 |
+
if the first line describing your argument type and its default gets long, you can't break it on several lines. You can
|
| 213 |
+
however write as many lines as you want in the indented description (see the example above with `input_ids`).
|
| 214 |
+
|
| 215 |
+
#### Writing a multi-line code block
|
| 216 |
+
|
| 217 |
+
Multi-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown:
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
````
|
| 221 |
+
```
|
| 222 |
+
# first line of code
|
| 223 |
+
# second line
|
| 224 |
+
# etc
|
| 225 |
+
```
|
| 226 |
+
````
|
| 227 |
+
|
| 228 |
+
#### Writing a return block
|
| 229 |
+
|
| 230 |
+
The return block should be introduced with the `Returns:` prefix, followed by a line return and an indentation.
|
| 231 |
+
The first line should be the type of the return, followed by a line return. No need to indent further for the elements
|
| 232 |
+
building the return.
|
| 233 |
+
|
| 234 |
+
Here's an example of a single value return:
|
| 235 |
+
|
| 236 |
+
```
|
| 237 |
+
Returns:
|
| 238 |
+
`List[int]`: A list of integers in the range [0, 1] --- 1 for a special token, 0 for a sequence token.
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
Here's an example of a tuple return, comprising several objects:
|
| 242 |
+
|
| 243 |
+
```
|
| 244 |
+
Returns:
|
| 245 |
+
`tuple(torch.Tensor)` comprising various elements depending on the configuration ([`BertConfig`]) and inputs:
|
| 246 |
+
- **loss** (*optional*, returned when `masked_lm_labels` is provided) `torch.Tensor` of shape `(1,)` --
|
| 247 |
+
Total loss is the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
|
| 248 |
+
- **prediction_scores** (`torch.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`) --
|
| 249 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 250 |
+
```
|
| 251 |
+
|
| 252 |
+
#### Adding an image
|
| 253 |
+
|
| 254 |
+
Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
|
| 255 |
+
the ones hosted on [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) in which to place these files and reference
|
| 256 |
+
them by URL. We recommend putting them in the following dataset: [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images).
|
| 257 |
+
If you are an external contributor, feel free to add the images to your PR and ask a Hugging Face member to migrate your images
|
| 258 |
+
to this dataset.
|
| 259 |
+
|
| 260 |
+
## Styling the docstring
|
| 261 |
+
|
| 262 |
+
We have an automatic script running with the `make style` command that will make sure that:
|
| 263 |
+
- the docstrings fully take advantage of the line width
|
| 264 |
+
- all code examples are formatted using black, like the code of the Transformers library
|
| 265 |
+
|
| 266 |
+
This script may have some weird failures if you made a syntax mistake or if you uncover a bug. Therefore, it's
|
| 267 |
+
recommended to commit your changes before running `make style`, so you can revert the changes done by that script
|
| 268 |
+
easily.
|
diffusers/docs/TRANSLATING.md
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
|
| 3 |
+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
| 4 |
+
the License. You may obtain a copy of the License at
|
| 5 |
+
|
| 6 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
|
| 8 |
+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
| 9 |
+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
| 10 |
+
specific language governing permissions and limitations under the License.
|
| 11 |
+
-->
|
| 12 |
+
|
| 13 |
+
### Translating the Diffusers documentation into your language
|
| 14 |
+
|
| 15 |
+
As part of our mission to democratize machine learning, we'd love to make the Diffusers library available in many more languages! Follow the steps below if you want to help translate the documentation into your language 🙏.
|
| 16 |
+
|
| 17 |
+
**🗞️ Open an issue**
|
| 18 |
+
|
| 19 |
+
To get started, navigate to the [Issues](https://github.com/huggingface/diffusers/issues) page of this repo and check if anyone else has opened an issue for your language. If not, open a new issue by selecting the "🌐 Translating a New Language?" from the "New issue" button.
|
| 20 |
+
|
| 21 |
+
Once an issue exists, post a comment to indicate which chapters you'd like to work on, and we'll add your name to the list.
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
**🍴 Fork the repository**
|
| 25 |
+
|
| 26 |
+
First, you'll need to [fork the Diffusers repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo). You can do this by clicking on the **Fork** button on the top-right corner of this repo's page.
|
| 27 |
+
|
| 28 |
+
Once you've forked the repo, you'll want to get the files on your local machine for editing. You can do that by cloning the fork with Git as follows:
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
git clone https://github.com/<YOUR-USERNAME>/diffusers.git
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
**📋 Copy-paste the English version with a new language code**
|
| 35 |
+
|
| 36 |
+
The documentation files are in one leading directory:
|
| 37 |
+
|
| 38 |
+
- [`docs/source`](https://github.com/huggingface/diffusers/tree/main/docs/source): All the documentation materials are organized here by language.
|
| 39 |
+
|
| 40 |
+
You'll only need to copy the files in the [`docs/source/en`](https://github.com/huggingface/diffusers/tree/main/docs/source/en) directory, so first navigate to your fork of the repo and run the following:
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
cd ~/path/to/diffusers/docs
|
| 44 |
+
cp -r source/en source/<LANG-ID>
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
Here, `<LANG-ID>` should be one of the ISO 639-1 or ISO 639-2 language codes -- see [here](https://www.loc.gov/standards/iso639-2/php/code_list.php) for a handy table.
|
| 48 |
+
|
| 49 |
+
**✍️ Start translating**
|
| 50 |
+
|
| 51 |
+
The fun part comes - translating the text!
|
| 52 |
+
|
| 53 |
+
The first thing we recommend is translating the part of the `_toctree.yml` file that corresponds to your doc chapter. This file is used to render the table of contents on the website.
|
| 54 |
+
|
| 55 |
+
> 🙋 If the `_toctree.yml` file doesn't yet exist for your language, you can create one by copy-pasting from the English version and deleting the sections unrelated to your chapter. Just make sure it exists in the `docs/source/<LANG-ID>/` directory!
|
| 56 |
+
|
| 57 |
+
The fields you should add are `local` (with the name of the file containing the translation; e.g. `autoclass_tutorial`), and `title` (with the title of the doc in your language; e.g. `Load pretrained instances with an AutoClass`) -- as a reference, here is the `_toctree.yml` for [English](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml):
|
| 58 |
+
|
| 59 |
+
```yaml
|
| 60 |
+
- sections:
|
| 61 |
+
- local: pipeline_tutorial # Do not change this! Use the same name for your .md file
|
| 62 |
+
title: Pipelines for inference # Translate this!
|
| 63 |
+
...
|
| 64 |
+
title: Tutorials # Translate this!
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
Once you have translated the `_toctree.yml` file, you can start translating the [MDX](https://mdxjs.com/) files associated with your docs chapter.
|
| 68 |
+
|
| 69 |
+
> 🙋 If you'd like others to help you with the translation, you should [open an issue](https://github.com/huggingface/diffusers/issues) and tag @patrickvonplaten.
|
diffusers/scripts/conversion_ldm_uncond.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import yaml
|
| 5 |
+
|
| 6 |
+
from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def convert_ldm_original(checkpoint_path, config_path, output_path):
|
| 10 |
+
config = yaml.safe_load(config_path)
|
| 11 |
+
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
|
| 12 |
+
keys = list(state_dict.keys())
|
| 13 |
+
|
| 14 |
+
# extract state_dict for VQVAE
|
| 15 |
+
first_stage_dict = {}
|
| 16 |
+
first_stage_key = "first_stage_model."
|
| 17 |
+
for key in keys:
|
| 18 |
+
if key.startswith(first_stage_key):
|
| 19 |
+
first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]
|
| 20 |
+
|
| 21 |
+
# extract state_dict for UNetLDM
|
| 22 |
+
unet_state_dict = {}
|
| 23 |
+
unet_key = "model.diffusion_model."
|
| 24 |
+
for key in keys:
|
| 25 |
+
if key.startswith(unet_key):
|
| 26 |
+
unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
|
| 27 |
+
|
| 28 |
+
vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"]
|
| 29 |
+
unet_init_args = config["model"]["params"]["unet_config"]["params"]
|
| 30 |
+
|
| 31 |
+
vqvae = VQModel(**vqvae_init_args).eval()
|
| 32 |
+
vqvae.load_state_dict(first_stage_dict)
|
| 33 |
+
|
| 34 |
+
unet = UNetLDMModel(**unet_init_args).eval()
|
| 35 |
+
unet.load_state_dict(unet_state_dict)
|
| 36 |
+
|
| 37 |
+
noise_scheduler = DDIMScheduler(
|
| 38 |
+
timesteps=config["model"]["params"]["timesteps"],
|
| 39 |
+
beta_schedule="scaled_linear",
|
| 40 |
+
beta_start=config["model"]["params"]["linear_start"],
|
| 41 |
+
beta_end=config["model"]["params"]["linear_end"],
|
| 42 |
+
clip_sample=False,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
|
| 46 |
+
pipeline.save_pretrained(output_path)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
|
| 50 |
+
parser = argparse.ArgumentParser()
|
| 51 |
+
parser.add_argument("--checkpoint_path", type=str, required=True)
|
| 52 |
+
parser.add_argument("--config_path", type=str, required=True)
|
| 53 |
+
parser.add_argument("--output_path", type=str, required=True)
|
| 54 |
+
args = parser.parse_args()
|
| 55 |
+
|
| 56 |
+
convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
|
diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from huggingface_hub import create_repo, upload_folder
|
| 6 |
+
from safetensors.torch import load_file, save_file
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def convert_motion_module(original_state_dict):
|
| 10 |
+
converted_state_dict = {}
|
| 11 |
+
for k, v in original_state_dict.items():
|
| 12 |
+
if "pos_encoder" in k:
|
| 13 |
+
continue
|
| 14 |
+
|
| 15 |
+
else:
|
| 16 |
+
converted_state_dict[
|
| 17 |
+
k.replace(".norms.0", ".norm1")
|
| 18 |
+
.replace(".norms.1", ".norm2")
|
| 19 |
+
.replace(".ff_norm", ".norm3")
|
| 20 |
+
.replace(".attention_blocks.0", ".attn1")
|
| 21 |
+
.replace(".attention_blocks.1", ".attn2")
|
| 22 |
+
.replace(".temporal_transformer", "")
|
| 23 |
+
] = v
|
| 24 |
+
|
| 25 |
+
return converted_state_dict
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_args():
|
| 29 |
+
parser = argparse.ArgumentParser()
|
| 30 |
+
parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
|
| 31 |
+
parser.add_argument("--output_path", type=str, required=True, help="Path to output directory")
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--push_to_hub",
|
| 34 |
+
action="store_true",
|
| 35 |
+
default=False,
|
| 36 |
+
help="Whether to push the converted model to the HF or not",
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
return parser.parse_args()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
args = get_args()
|
| 44 |
+
|
| 45 |
+
if args.ckpt_path.endswith(".safetensors"):
|
| 46 |
+
state_dict = load_file(args.ckpt_path)
|
| 47 |
+
else:
|
| 48 |
+
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
| 49 |
+
|
| 50 |
+
if "state_dict" in state_dict.keys():
|
| 51 |
+
state_dict = state_dict["state_dict"]
|
| 52 |
+
|
| 53 |
+
conv_state_dict = convert_motion_module(state_dict)
|
| 54 |
+
|
| 55 |
+
# convert to new format
|
| 56 |
+
output_dict = {}
|
| 57 |
+
for module_name, params in conv_state_dict.items():
|
| 58 |
+
if type(params) is not torch.Tensor:
|
| 59 |
+
continue
|
| 60 |
+
output_dict.update({f"unet.{module_name}": params})
|
| 61 |
+
|
| 62 |
+
os.makedirs(args.output_path, exist_ok=True)
|
| 63 |
+
|
| 64 |
+
filepath = os.path.join(args.output_path, "diffusion_pytorch_model.safetensors")
|
| 65 |
+
save_file(output_dict, filepath)
|
| 66 |
+
|
| 67 |
+
if args.push_to_hub:
|
| 68 |
+
repo_id = create_repo(args.output_path, exist_ok=True).repo_id
|
| 69 |
+
upload_folder(repo_id=repo_id, folder_path=args.output_path, repo_type="model")
|
diffusers/scripts/convert_cogvideox_to_diffusers.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 6 |
+
|
| 7 |
+
from diffusers import (
|
| 8 |
+
AutoencoderKLCogVideoX,
|
| 9 |
+
CogVideoXDDIMScheduler,
|
| 10 |
+
CogVideoXImageToVideoPipeline,
|
| 11 |
+
CogVideoXPipeline,
|
| 12 |
+
CogVideoXTransformer3DModel,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
|
| 17 |
+
to_q_key = key.replace("query_key_value", "to_q")
|
| 18 |
+
to_k_key = key.replace("query_key_value", "to_k")
|
| 19 |
+
to_v_key = key.replace("query_key_value", "to_v")
|
| 20 |
+
to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
|
| 21 |
+
state_dict[to_q_key] = to_q
|
| 22 |
+
state_dict[to_k_key] = to_k
|
| 23 |
+
state_dict[to_v_key] = to_v
|
| 24 |
+
state_dict.pop(key)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
|
| 28 |
+
layer_id, weight_or_bias = key.split(".")[-2:]
|
| 29 |
+
|
| 30 |
+
if "query" in key:
|
| 31 |
+
new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
|
| 32 |
+
elif "key" in key:
|
| 33 |
+
new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"
|
| 34 |
+
|
| 35 |
+
state_dict[new_key] = state_dict.pop(key)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
|
| 39 |
+
layer_id, _, weight_or_bias = key.split(".")[-3:]
|
| 40 |
+
|
| 41 |
+
weights_or_biases = state_dict[key].chunk(12, dim=0)
|
| 42 |
+
norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
|
| 43 |
+
norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
|
| 44 |
+
|
| 45 |
+
norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
|
| 46 |
+
state_dict[norm1_key] = norm1_weights_or_biases
|
| 47 |
+
|
| 48 |
+
norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
|
| 49 |
+
state_dict[norm2_key] = norm2_weights_or_biases
|
| 50 |
+
|
| 51 |
+
state_dict.pop(key)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
|
| 55 |
+
state_dict.pop(key)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
|
| 59 |
+
key_split = key.split(".")
|
| 60 |
+
layer_index = int(key_split[2])
|
| 61 |
+
replace_layer_index = 4 - 1 - layer_index
|
| 62 |
+
|
| 63 |
+
key_split[1] = "up_blocks"
|
| 64 |
+
key_split[2] = str(replace_layer_index)
|
| 65 |
+
new_key = ".".join(key_split)
|
| 66 |
+
|
| 67 |
+
state_dict[new_key] = state_dict.pop(key)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
TRANSFORMER_KEYS_RENAME_DICT = {
|
| 71 |
+
"transformer.final_layernorm": "norm_final",
|
| 72 |
+
"transformer": "transformer_blocks",
|
| 73 |
+
"attention": "attn1",
|
| 74 |
+
"mlp": "ff.net",
|
| 75 |
+
"dense_h_to_4h": "0.proj",
|
| 76 |
+
"dense_4h_to_h": "2",
|
| 77 |
+
".layers": "",
|
| 78 |
+
"dense": "to_out.0",
|
| 79 |
+
"input_layernorm": "norm1.norm",
|
| 80 |
+
"post_attn1_layernorm": "norm2.norm",
|
| 81 |
+
"time_embed.0": "time_embedding.linear_1",
|
| 82 |
+
"time_embed.2": "time_embedding.linear_2",
|
| 83 |
+
"ofs_embed.0": "ofs_embedding.linear_1",
|
| 84 |
+
"ofs_embed.2": "ofs_embedding.linear_2",
|
| 85 |
+
"mixins.patch_embed": "patch_embed",
|
| 86 |
+
"mixins.final_layer.norm_final": "norm_out.norm",
|
| 87 |
+
"mixins.final_layer.linear": "proj_out",
|
| 88 |
+
"mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
|
| 89 |
+
"mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding", # Specific to CogVideoX-5b-I2V
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
| 93 |
+
"query_key_value": reassign_query_key_value_inplace,
|
| 94 |
+
"query_layernorm_list": reassign_query_key_layernorm_inplace,
|
| 95 |
+
"key_layernorm_list": reassign_query_key_layernorm_inplace,
|
| 96 |
+
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
|
| 97 |
+
"embed_tokens": remove_keys_inplace,
|
| 98 |
+
"freqs_sin": remove_keys_inplace,
|
| 99 |
+
"freqs_cos": remove_keys_inplace,
|
| 100 |
+
"position_embedding": remove_keys_inplace,
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
VAE_KEYS_RENAME_DICT = {
|
| 104 |
+
"block.": "resnets.",
|
| 105 |
+
"down.": "down_blocks.",
|
| 106 |
+
"downsample": "downsamplers.0",
|
| 107 |
+
"upsample": "upsamplers.0",
|
| 108 |
+
"nin_shortcut": "conv_shortcut",
|
| 109 |
+
"encoder.mid.block_1": "encoder.mid_block.resnets.0",
|
| 110 |
+
"encoder.mid.block_2": "encoder.mid_block.resnets.1",
|
| 111 |
+
"decoder.mid.block_1": "decoder.mid_block.resnets.0",
|
| 112 |
+
"decoder.mid.block_2": "decoder.mid_block.resnets.1",
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
VAE_SPECIAL_KEYS_REMAP = {
|
| 116 |
+
"loss": remove_keys_inplace,
|
| 117 |
+
"up.": replace_up_keys_inplace,
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
TOKENIZER_MAX_LENGTH = 226
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
| 124 |
+
state_dict = saved_dict
|
| 125 |
+
if "model" in saved_dict.keys():
|
| 126 |
+
state_dict = state_dict["model"]
|
| 127 |
+
if "module" in saved_dict.keys():
|
| 128 |
+
state_dict = state_dict["module"]
|
| 129 |
+
if "state_dict" in saved_dict.keys():
|
| 130 |
+
state_dict = state_dict["state_dict"]
|
| 131 |
+
return state_dict
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
| 135 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def convert_transformer(
|
| 139 |
+
ckpt_path: str,
|
| 140 |
+
num_layers: int,
|
| 141 |
+
num_attention_heads: int,
|
| 142 |
+
use_rotary_positional_embeddings: bool,
|
| 143 |
+
i2v: bool,
|
| 144 |
+
dtype: torch.dtype,
|
| 145 |
+
init_kwargs: Dict[str, Any],
|
| 146 |
+
):
|
| 147 |
+
PREFIX_KEY = "model.diffusion_model."
|
| 148 |
+
|
| 149 |
+
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
| 150 |
+
transformer = CogVideoXTransformer3DModel(
|
| 151 |
+
in_channels=32 if i2v else 16,
|
| 152 |
+
num_layers=num_layers,
|
| 153 |
+
num_attention_heads=num_attention_heads,
|
| 154 |
+
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
|
| 155 |
+
ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V
|
| 156 |
+
use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V
|
| 157 |
+
**init_kwargs,
|
| 158 |
+
).to(dtype=dtype)
|
| 159 |
+
|
| 160 |
+
for key in list(original_state_dict.keys()):
|
| 161 |
+
new_key = key[len(PREFIX_KEY) :]
|
| 162 |
+
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
| 163 |
+
new_key = new_key.replace(replace_key, rename_key)
|
| 164 |
+
update_state_dict_inplace(original_state_dict, key, new_key)
|
| 165 |
+
|
| 166 |
+
for key in list(original_state_dict.keys()):
|
| 167 |
+
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
| 168 |
+
if special_key not in key:
|
| 169 |
+
continue
|
| 170 |
+
handler_fn_inplace(key, original_state_dict)
|
| 171 |
+
|
| 172 |
+
transformer.load_state_dict(original_state_dict, strict=True)
|
| 173 |
+
return transformer
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
|
| 177 |
+
init_kwargs = {"scaling_factor": scaling_factor}
|
| 178 |
+
if version == "1.5":
|
| 179 |
+
init_kwargs.update({"invert_scale_latents": True})
|
| 180 |
+
|
| 181 |
+
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
| 182 |
+
vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)
|
| 183 |
+
|
| 184 |
+
for key in list(original_state_dict.keys()):
|
| 185 |
+
new_key = key[:]
|
| 186 |
+
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
| 187 |
+
new_key = new_key.replace(replace_key, rename_key)
|
| 188 |
+
update_state_dict_inplace(original_state_dict, key, new_key)
|
| 189 |
+
|
| 190 |
+
for key in list(original_state_dict.keys()):
|
| 191 |
+
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
|
| 192 |
+
if special_key not in key:
|
| 193 |
+
continue
|
| 194 |
+
handler_fn_inplace(key, original_state_dict)
|
| 195 |
+
|
| 196 |
+
vae.load_state_dict(original_state_dict, strict=True)
|
| 197 |
+
return vae
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_transformer_init_kwargs(version: str):
|
| 201 |
+
if version == "1.0":
|
| 202 |
+
vae_scale_factor_spatial = 8
|
| 203 |
+
init_kwargs = {
|
| 204 |
+
"patch_size": 2,
|
| 205 |
+
"patch_size_t": None,
|
| 206 |
+
"patch_bias": True,
|
| 207 |
+
"sample_height": 480 // vae_scale_factor_spatial,
|
| 208 |
+
"sample_width": 720 // vae_scale_factor_spatial,
|
| 209 |
+
"sample_frames": 49,
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
elif version == "1.5":
|
| 213 |
+
vae_scale_factor_spatial = 8
|
| 214 |
+
init_kwargs = {
|
| 215 |
+
"patch_size": 2,
|
| 216 |
+
"patch_size_t": 2,
|
| 217 |
+
"patch_bias": False,
|
| 218 |
+
"sample_height": 300,
|
| 219 |
+
"sample_width": 300,
|
| 220 |
+
"sample_frames": 81,
|
| 221 |
+
}
|
| 222 |
+
else:
|
| 223 |
+
raise ValueError("Unsupported version of CogVideoX.")
|
| 224 |
+
|
| 225 |
+
return init_kwargs
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def get_args():
|
| 229 |
+
parser = argparse.ArgumentParser()
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
| 232 |
+
)
|
| 233 |
+
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
| 234 |
+
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
| 235 |
+
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
|
| 236 |
+
parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
|
| 237 |
+
parser.add_argument(
|
| 238 |
+
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
|
| 239 |
+
)
|
| 240 |
+
parser.add_argument(
|
| 241 |
+
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
|
| 242 |
+
)
|
| 243 |
+
parser.add_argument(
|
| 244 |
+
"--typecast_text_encoder",
|
| 245 |
+
action="store_true",
|
| 246 |
+
default=False,
|
| 247 |
+
help="Whether or not to apply fp16/bf16 precision to text_encoder",
|
| 248 |
+
)
|
| 249 |
+
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
|
| 250 |
+
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
|
| 251 |
+
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
|
| 252 |
+
parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
|
| 253 |
+
# For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
|
| 254 |
+
parser.add_argument(
|
| 255 |
+
"--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
|
| 256 |
+
)
|
| 257 |
+
# For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
|
| 258 |
+
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
|
| 259 |
+
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
|
| 260 |
+
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
|
| 261 |
+
parser.add_argument(
|
| 262 |
+
"--i2v",
|
| 263 |
+
action="store_true",
|
| 264 |
+
default=False,
|
| 265 |
+
help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
|
| 266 |
+
)
|
| 267 |
+
parser.add_argument(
|
| 268 |
+
"--version",
|
| 269 |
+
choices=["1.0", "1.5"],
|
| 270 |
+
default="1.0",
|
| 271 |
+
help="Which version of CogVideoX to use for initializing default modeling parameters.",
|
| 272 |
+
)
|
| 273 |
+
return parser.parse_args()
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
if __name__ == "__main__":
|
| 277 |
+
args = get_args()
|
| 278 |
+
|
| 279 |
+
transformer = None
|
| 280 |
+
vae = None
|
| 281 |
+
|
| 282 |
+
if args.fp16 and args.bf16:
|
| 283 |
+
raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
|
| 284 |
+
|
| 285 |
+
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
|
| 286 |
+
|
| 287 |
+
if args.transformer_ckpt_path is not None:
|
| 288 |
+
init_kwargs = get_transformer_init_kwargs(args.version)
|
| 289 |
+
transformer = convert_transformer(
|
| 290 |
+
args.transformer_ckpt_path,
|
| 291 |
+
args.num_layers,
|
| 292 |
+
args.num_attention_heads,
|
| 293 |
+
args.use_rotary_positional_embeddings,
|
| 294 |
+
args.i2v,
|
| 295 |
+
dtype,
|
| 296 |
+
init_kwargs,
|
| 297 |
+
)
|
| 298 |
+
if args.vae_ckpt_path is not None:
|
| 299 |
+
# Keep VAE in float32 for better quality
|
| 300 |
+
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)
|
| 301 |
+
|
| 302 |
+
text_encoder_id = "google/t5-v1_1-xxl"
|
| 303 |
+
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
| 304 |
+
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
| 305 |
+
|
| 306 |
+
if args.typecast_text_encoder:
|
| 307 |
+
text_encoder = text_encoder.to(dtype=dtype)
|
| 308 |
+
|
| 309 |
+
# Apparently, the conversion does not work anymore without this :shrug:
|
| 310 |
+
for param in text_encoder.parameters():
|
| 311 |
+
param.data = param.data.contiguous()
|
| 312 |
+
|
| 313 |
+
scheduler = CogVideoXDDIMScheduler.from_config(
|
| 314 |
+
{
|
| 315 |
+
"snr_shift_scale": args.snr_shift_scale,
|
| 316 |
+
"beta_end": 0.012,
|
| 317 |
+
"beta_schedule": "scaled_linear",
|
| 318 |
+
"beta_start": 0.00085,
|
| 319 |
+
"clip_sample": False,
|
| 320 |
+
"num_train_timesteps": 1000,
|
| 321 |
+
"prediction_type": "v_prediction",
|
| 322 |
+
"rescale_betas_zero_snr": True,
|
| 323 |
+
"set_alpha_to_one": True,
|
| 324 |
+
"timestep_spacing": "trailing",
|
| 325 |
+
}
|
| 326 |
+
)
|
| 327 |
+
if args.i2v:
|
| 328 |
+
pipeline_cls = CogVideoXImageToVideoPipeline
|
| 329 |
+
else:
|
| 330 |
+
pipeline_cls = CogVideoXPipeline
|
| 331 |
+
|
| 332 |
+
pipe = pipeline_cls(
|
| 333 |
+
tokenizer=tokenizer,
|
| 334 |
+
text_encoder=text_encoder,
|
| 335 |
+
vae=vae,
|
| 336 |
+
transformer=transformer,
|
| 337 |
+
scheduler=scheduler,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
|
| 341 |
+
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
|
| 342 |
+
# is either fp16/bf16 here).
|
| 343 |
+
|
| 344 |
+
# This is necessary This is necessary for users with insufficient memory,
|
| 345 |
+
# such as those using Colab and notebooks, as it can save some memory used for model loading.
|
| 346 |
+
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
|
diffusers/scripts/convert_consistency_decoder.py
ADDED
|
@@ -0,0 +1,1128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import urllib
|
| 4 |
+
import warnings
|
| 5 |
+
from argparse import ArgumentParser
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from huggingface_hub.utils import insecure_hashlib
|
| 11 |
+
from safetensors.torch import load_file as stl
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
|
| 15 |
+
from diffusers.models.autoencoders.vae import Encoder
|
| 16 |
+
from diffusers.models.embeddings import TimestepEmbedding
|
| 17 |
+
from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# CLI: --test_image (required) is encoded/decoded to visually compare the
# decoders; --save_pretrained optionally writes the converted pipeline out.
_parser = ArgumentParser()
_parser.add_argument("--save_pretrained", required=False, default=None, type=str)
_parser.add_argument("--test_image", required=True, type=str)
args = _parser.parse_args()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
| 27 |
+
# from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895 """
|
| 28 |
+
res = arr[timesteps].float()
|
| 29 |
+
dims_to_append = len(broadcast_shape) - len(res.shape)
|
| 30 |
+
return res[(...,) + (None,) * dims_to_append]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """Build a beta schedule from a cumulative-alpha function `alpha_bar(t)`,
    capping each beta at `max_beta`.

    # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L45
    """
    steps = num_diffusion_timesteps
    return torch.tensor(
        [min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta) for i in range(steps)]
    )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _sha256_of(path):
    """Hex SHA-256 of a file; the handle is closed deterministically."""
    with open(path, "rb") as f:
        return insecure_hashlib.sha256(f.read()).hexdigest()


def _download(url: str, root: str):
    """Download `url` into directory `root`, verifying the SHA-256 that OpenAI
    embeds as the second-to-last path component of the URL.

    Returns the local file path. An existing file with a matching checksum is
    reused; one with a mismatching checksum is re-downloaded. Raises
    RuntimeError if the target exists but is not a regular file, or if the
    freshly downloaded file fails the checksum.
    """
    # `import urllib` at the top of the file does not import the submodule;
    # make `urllib.request` available explicitly.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    # blobs are published under .../<sha256>/<filename>
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _sha256_of(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _sha256_of(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class ConsistencyDecoder:
    """Reference consistency decoder driven by OpenAI's TorchScript checkpoint.

    Downloads the jitted decoder, precomputes the consistency-model
    coefficients (c_skip / c_out / c_in) from a cosine alpha-bar schedule over
    1024 steps, and decodes Stable Diffusion latents to images using a short
    distilled sampling schedule. `self.ckpt` is later overwritten with the
    eager / converted model by the surrounding script so outputs can be
    compared against this reference.
    """

    def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")):
        # number of distilled bins the 1024 raw timesteps are rounded onto
        self.n_distilled_steps = 64
        download_target = _download(
            "https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt",
            download_root,
        )
        self.ckpt = torch.jit.load(download_target).to(device)
        self.device = device
        # data std used by the EDM-style consistency parameterization
        sigma_data = 0.5
        # cosine alpha-bar schedule (same shape as the OpenAI reference)
        betas = betas_for_alpha_bar(1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).to(device)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
        sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
        # per-timestep skip/out/in coefficients of the consistency model
        self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
        self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
        self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5

    @staticmethod
    def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
        """Snap raw timesteps up onto the grid of `n_distilled_steps` bins.

        NOTE(review): both branches clamp values that reach `total_timesteps`
        down by one bin; the else-branch additionally bumps zeros up one bin.
        The duplicated first statement looks intentional but is worth
        confirming against the original OpenAI code.
        """
        with torch.no_grad():
            space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
            rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space
            if truncate_start:
                rounded_timesteps[rounded_timesteps == total_timesteps] -= space
            else:
                rounded_timesteps[rounded_timesteps == total_timesteps] -= space
                rounded_timesteps[rounded_timesteps == 0] += space
            return rounded_timesteps

    @staticmethod
    def ldm_transform_latent(z, extra_scale_factor=1):
        """Scale SD latents by 0.18215 and whiten each of the 4 channels with
        fixed per-channel statistics; expects a rank-4 (N, 4, H, W) tensor."""
        channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294]
        channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034]

        if len(z.shape) != 4:
            raise ValueError()

        z = z * 0.18215
        channels = [z[:, i] for i in range(z.shape[1])]

        channels = [extra_scale_factor * (c - channel_means[i]) / channel_stds[i] for i, c in enumerate(channels)]
        return torch.stack(channels, dim=1)

    @torch.no_grad()
    def __call__(
        self,
        features: torch.Tensor,
        # NOTE(review): mutable default; never mutated here, so benign
        schedule=[1.0, 0.5],
        generator=None,
    ):
        """Decode latent `features` (N, 4, h, w) to images (N, 3, 8h, 8w) by
        iterating the consistency model over `schedule` (fractions of 1024)."""
        features = self.ldm_transform_latent(features)
        ts = self.round_timesteps(
            torch.arange(0, 1024),
            1024,
            self.n_distilled_steps,
            truncate_start=False,
        )
        # output is 8x the latent resolution, RGB
        shape = (
            features.size(0),
            3,
            8 * features.size(2),
            8 * features.size(3),
        )
        x_start = torch.zeros(shape, device=features.device, dtype=features.dtype)
        schedule_timesteps = [int((1024 - 1) * s) for s in schedule]
        for i in schedule_timesteps:
            t = ts[i].item()
            t_ = torch.tensor([t] * features.shape[0]).to(self.device)
            # noise = torch.randn_like(x_start)
            # noise drawn on CPU with `generator` for reproducibility, then moved
            noise = torch.randn(x_start.shape, dtype=x_start.dtype, generator=generator).to(device=x_start.device)
            # re-noise the current estimate to timestep t
            x_start = (
                _extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape) * x_start
                + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape) * noise
            )
            c_in = _extract_into_tensor(self.c_in, t_, x_start.shape)

            # local imports kept as-is (redundant with the top of the file)
            import torch.nn.functional as F

            from diffusers import UNet2DModel

            # after conversion, `self.ckpt` is a diffusers UNet2DModel that takes
            # the upsampled features concatenated on the channel dim
            if isinstance(self.ckpt, UNet2DModel):
                input = torch.concat([c_in * x_start, F.upsample_nearest(features, scale_factor=8)], dim=1)
                model_output = self.ckpt(input, t_).sample
            else:
                model_output = self.ckpt(c_in * x_start, t_, features=features)

            # the model predicts 2*C channels; only the first C are the sample
            B, C = x_start.shape[:2]
            model_output, _ = torch.split(model_output, C, dim=1)
            # consistency-model output parameterization, clamped to image range
            pred_xstart = (
                _extract_into_tensor(self.c_out, t_, x_start.shape) * model_output
                + _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start
            ).clamp(-1, 1)
            x_start = pred_xstart
        return x_start
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def save_image(image, name):
    """Write the first image of a [-1, 1]-ranged NCHW batch to `name`."""
    import numpy as np
    from PIL import Image

    arr = image[0].cpu().numpy()
    # map [-1, 1] -> [0, 255] uint8
    arr = ((arr + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    Image.fromarray(arr.transpose(1, 2, 0)).save(name)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def load_image(uri, size=None, center_crop=False):
    """Load an image as a (1, C, H, W) float tensor scaled to [-1, 1].

    Optionally center-crops to a square first, then resizes to `size`.
    """
    import numpy as np
    from PIL import Image

    img = Image.open(uri)
    if center_crop:
        # crop the largest centered square
        side = min(img.width, img.height)
        left = (img.width - side) // 2
        top = (img.height - side) // 2
        img = img.crop((left, top, left + side, top + side))
    if size is not None:
        img = img.resize(size)
    tensor = torch.tensor(np.array(img).transpose(2, 0, 1)).unsqueeze(0).float()
    return tensor / 127.5 - 1.0
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class TimestepEmbedding_(nn.Module):
    """Timestep encoder: embedding-table lookup followed by a two-layer MLP
    (linear -> SiLU -> linear)."""

    def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
        super().__init__()
        self.emb = nn.Embedding(n_time, n_emb)
        self.f_1 = nn.Linear(n_emb, n_out)
        self.f_2 = nn.Linear(n_out, n_out)

    def forward(self, x) -> torch.Tensor:
        hidden = self.f_1(self.emb(x))
        return self.f_2(F.silu(hidden))
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class ImageEmbedding(nn.Module):
    """Input stem: a single 3x3 convolution mapping the 7-channel input
    (3 noisy-image + 4 upsampled-latent channels) to the base width."""

    def __init__(self, in_channels=7, out_channels=320) -> None:
        super().__init__()
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        embedded = self.f(x)
        return embedded
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class ImageUnembedding(nn.Module):
    """Output head: GroupNorm -> SiLU -> 3x3 conv down to the output channels
    (6 = sample + variance halves)."""

    def __init__(self, in_channels=320, out_channels=6) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(32, in_channels)
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        normed = F.silu(self.gn(x))
        return self.f(normed)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class ConvResblock(nn.Module):
    """Residual conv block with AdaGN-style timestep conditioning: the timestep
    embedding is projected to a per-channel (scale, shift) applied after the
    second GroupNorm. A 1x1 conv aligns the skip when channel counts differ."""

    def __init__(self, in_features=320, out_features=320) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, out_features * 2)

        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)

        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)

        if in_features != out_features:
            self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0)
        else:
            self.f_s = nn.Identity()

    def forward(self, x, t):
        scale, shift = self.f_t(F.silu(t)).chunk(2, dim=1)
        scale = scale[:, :, None, None] + 1
        shift = shift[:, :, None, None]

        h = self.f_1(F.silu(self.gn_1(x)))
        h = F.silu(self.gn_2(h) * scale + shift)
        return self.f_s(x) + self.f_2(h)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# Also ConvResblock
class Downsample(nn.Module):
    """ConvResblock variant that halves the spatial resolution: a 2x2 average
    pool is applied on both the residual path and the skip path."""

    def __init__(self, in_channels=320) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)

        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)

        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        scale, shift = self.f_t(F.silu(t)).chunk(2, dim=1)
        scale = scale[:, :, None, None] + 1
        shift = shift[:, :, None, None]

        pooled = F.avg_pool2d(F.silu(self.gn_1(x)), kernel_size=(2, 2), stride=None)
        h = self.gn_2(self.f_1(pooled))

        out = self.f_2(F.silu(shift + scale * h))
        return out + F.avg_pool2d(x, kernel_size=(2, 2), stride=None)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
# Also ConvResblock
class Upsample(nn.Module):
    """ConvResblock variant that doubles the spatial resolution: nearest-
    neighbor upsampling is applied on both the residual path and the skip
    path, with AdaGN-style timestep (scale, shift) conditioning."""

    def __init__(self, in_channels=1024) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)

        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)

        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        x_skip = x

        # timestep conditioning: per-channel scale (t_1) and shift (t_2)
        t = self.f_t(F.silu(t))
        t_1, t_2 = t.chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
        t_2 = t_2.unsqueeze(2).unsqueeze(3)

        gn_1 = F.silu(self.gn_1(x))
        # F.upsample_nearest is deprecated; F.interpolate(mode="nearest") is
        # the exact modern equivalent
        upsample = F.interpolate(gn_1, scale_factor=2, mode="nearest")
        f_1 = self.f_1(upsample)
        gn_2 = self.gn_2(f_1)

        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))

        return f_2 + F.interpolate(x_skip, scale_factor=2, mode="nearest")
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class ConvUNetVAE(nn.Module):
    """Eager re-implementation of the consistency-decoder UNet.

    Mirrors the TorchScript checkpoint's architecture so that its weights can
    be loaded (after `rename_state_dict`) and the activations compared with
    the diffusers port. `forward` has two paths: the eager one defined here,
    and a "converted" one that runs diffusers-style blocks which the
    conversion code monkey-patches onto the instance (`self.converted`,
    `self.time_proj`, `self.time_embedding`, replaced `self.down`/`mid`/`up`).
    """

    def __init__(self) -> None:
        super().__init__()
        self.embed_image = ImageEmbedding()
        self.embed_time = TimestepEmbedding_()

        # encoder: widths 320 -> 640 -> 1024 -> 1024; each stage ends in a
        # Downsample except the last
        down_0 = nn.ModuleList(
            [
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                Downsample(320),
            ]
        )
        down_1 = nn.ModuleList(
            [
                ConvResblock(320, 640),
                ConvResblock(640, 640),
                ConvResblock(640, 640),
                Downsample(640),
            ]
        )
        down_2 = nn.ModuleList(
            [
                ConvResblock(640, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                Downsample(1024),
            ]
        )
        down_3 = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )
        self.down = nn.ModuleList(
            [
                down_0,
                down_1,
                down_2,
                down_3,
            ]
        )

        self.mid = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )

        # decoder: input widths account for the concatenated encoder skips;
        # each stage ends in an Upsample except the last
        up_3 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                Upsample(1024),
            ]
        )
        up_2 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 + 640, 1024),
                Upsample(1024),
            ]
        )
        up_1 = nn.ModuleList(
            [
                ConvResblock(1024 + 640, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(320 + 640, 640),
                Upsample(640),
            ]
        )
        up_0 = nn.ModuleList(
            [
                ConvResblock(320 + 640, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
            ]
        )
        # stored low-to-high; `forward` iterates this reversed
        self.up = nn.ModuleList(
            [
                up_0,
                up_1,
                up_2,
                up_3,
            ]
        )

        self.output = ImageUnembedding()

    def forward(self, x, t, features) -> torch.Tensor:
        # `converted` is set externally once the diffusers blocks are in place
        converted = hasattr(self, "converted") and self.converted

        # concatenate the noisy image with the 8x-upsampled latent features
        x = torch.cat([x, F.upsample_nearest(features, scale_factor=8)], dim=1)

        if converted:
            # attributes patched on during conversion — not defined in __init__
            t = self.time_embedding(self.time_proj(t))
        else:
            t = self.embed_time(t)

        x = self.embed_image(x)

        # collect skip activations (including the stem output) for the decoder
        skips = [x]
        for i, down in enumerate(self.down):
            if converted and i in [0, 1, 2, 3]:
                # diffusers down blocks return (hidden, per-layer skips)
                x, skips_ = down(x, t)
                for skip in skips_:
                    skips.append(skip)
            else:
                for block in down:
                    x = block(x, t)
                    skips.append(x)
            # debug: activation checksum for eager-vs-converted parity checks
            print(x.float().abs().sum())

        if converted:
            x = self.mid(x, t)
        else:
            for i in range(2):
                x = self.mid[i](x, t)
        print(x.float().abs().sum())

        # decode high-to-low; each stage consumes the 4 most recent skips
        for i, up in enumerate(self.up[::-1]):
            if converted and i in [0, 1, 2, 3]:
                skip_4 = skips.pop()
                skip_3 = skips.pop()
                skip_2 = skips.pop()
                skip_1 = skips.pop()
                # diffusers up blocks expect skips oldest-first
                skips_ = (skip_1, skip_2, skip_3, skip_4)
                x = up(x, skips_, t)
            else:
                for block in up:
                    # only resblocks take a skip; the trailing Upsample does not
                    if isinstance(block, ConvResblock):
                        x = torch.concat([x, skips.pop()], dim=1)
                    x = block(x, t)

        return self.output(x)
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def rename_state_dict_key(k):
    """Translate an original OpenAI checkpoint key into this module's naming
    scheme (dotted indices into ModuleLists, `weight`/`bias` suffixes).

    Replacement order matches the original conversion table and matters:
    e.g. `f_t.w` must be expanded before the bare `f.w` pattern.
    """
    k = k.replace("blocks.", "")
    for i in range(5):
        for old, new in (
            (f"down_{i}_", f"down.{i}."),
            (f"conv_{i}.", f"{i}."),
            (f"up_{i}_", f"up.{i}."),
            (f"mid_{i}", f"mid.{i}"),
        ):
            k = k.replace(old, new)
    # the trailing up/down-samplers sit at fixed positions in each stage
    k = k.replace("upsamp.", "4.").replace("downsamp.", "3.")
    # conv/linear parameter suffixes (most specific stems first)
    for stem in ("f_t", "f_1", "f_2", "f_s", "f"):
        k = k.replace(f"{stem}.w", f"{stem}.weight").replace(f"{stem}.b", f"{stem}.bias")
    # GroupNorm parameter suffixes
    for stem in ("gn_1", "gn_2", "gn"):
        k = k.replace(f"{stem}.g", f"{stem}.weight").replace(f"{stem}.b", f"{stem}.bias")
    return k
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
def rename_state_dict(sd, embedding):
    """Rename every checkpoint key and splice in the learned timestep embedding.

    The embedding table ships in its own file, so it is added here under its
    final name rather than being run through the renaming rules.
    """
    renamed = {}
    for key, value in sd.items():
        renamed[rename_state_dict_key(key)] = value
    renamed["embed_time.emb.weight"] = embedding["weight"]
    return renamed
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
# encode with stable diffusion vae
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.vae.cuda()

# construct original decoder with jitted model
decoder_consistency = ConsistencyDecoder(device="cuda:0")

# construct UNet code, overwrite the decoder with conv_unet_vae
model = ConvUNetVAE()
model.load_state_dict(
    rename_state_dict(
        stl("consistency_decoder.safetensors"),
        stl("embedding.safetensors"),
    )
)
model = model.cuda()

# Swap the eager-mode re-implementation in for the jitted checkpoint so the
# rest of the script exercises the code we can convert.
decoder_consistency.ckpt = model

image = load_image(args.test_image, size=(256, 256), center_crop=True)
latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample()

# decode with gan (the stock SD VAE decoder) — reference image only
sample_gan = pipe.vae.decode(latent).sample.detach()
save_image(sample_gan, "gan.png")

# decode with conv_unet_vae — this is the ground truth every converted
# variant below is compared against (fixed CPU generator for determinism)
sample_consistency_orig = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
save_image(sample_consistency_orig, "con_orig.png")
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
########### conversion

print("CONVERSION")

# Mapping from diffusers resnet parameter names to the names used by the
# original checkpoint (relative to a single resnet/resampler submodule).
_RESNET_KEY_MAP = (
    ("norm1.weight", "gn_1.weight"),
    ("norm1.bias", "gn_1.bias"),
    ("conv1.weight", "f_1.weight"),
    ("conv1.bias", "f_1.bias"),
    ("time_emb_proj.weight", "f_t.weight"),
    ("time_emb_proj.bias", "f_t.bias"),
    ("norm2.weight", "gn_2.weight"),
    ("norm2.bias", "gn_2.bias"),
    ("conv2.weight", "f_2.weight"),
    ("conv2.bias", "f_2.bias"),
)


def _move_resnet_weights(src, dst, src_prefix, dst_prefix, shortcut=False):
    """Pop one resnet's parameters out of *src* and store them in *dst*.

    ``src_prefix``/``dst_prefix`` select the submodule on either side;
    ``shortcut`` additionally moves the conv shortcut (``f_s``) used where the
    channel count changes.  Popping (rather than reading) lets callers assert
    afterwards that every original parameter was consumed.
    """
    for dst_field, src_field in _RESNET_KEY_MAP:
        dst[f"{dst_prefix}.{dst_field}"] = src.pop(f"{src_prefix}.{src_field}")
    if shortcut:
        dst[f"{dst_prefix}.conv_shortcut.weight"] = src.pop(f"{src_prefix}.f_s.weight")
        dst[f"{dst_prefix}.conv_shortcut.bias"] = src.pop(f"{src_prefix}.f_s.bias")


def _convert_down_block_sd(module, shortcut_resnets=(), has_downsampler=True):
    """Return the diffusers-style state dict for one original down block.

    Each down block has three resnets (children "0".."2"); the downsampling
    resblock, when present, is child "3".
    """
    sd_orig = module.state_dict()
    sd_new = {}
    for i in range(3):
        _move_resnet_weights(sd_orig, sd_new, str(i), f"resnets.{i}", shortcut=i in shortcut_resnets)
    if has_downsampler:
        _move_resnet_weights(sd_orig, sd_new, "3", "downsamplers.0")
    assert len(sd_orig) == 0  # every original parameter must be accounted for
    return sd_new


print("DOWN BLOCK ONE")

block_one = ResnetDownsampleBlock2D(
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    num_layers=3,
    add_downsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)
block_one.load_state_dict(_convert_down_block_sd(model.down[0]))

print("DOWN BLOCK TWO")

# Channel count changes (320 -> 640), so the first resnet carries a shortcut.
block_two = ResnetDownsampleBlock2D(
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    num_layers=3,
    add_downsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)
block_two.load_state_dict(_convert_down_block_sd(model.down[1], shortcut_resnets=(0,)))

print("DOWN BLOCK THREE")

# Channel count changes (640 -> 1024), first resnet carries a shortcut.
block_three = ResnetDownsampleBlock2D(
    in_channels=640,
    out_channels=1024,
    temb_channels=1280,
    num_layers=3,
    add_downsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)
block_three.load_state_dict(_convert_down_block_sd(model.down[2], shortcut_resnets=(0,)))

print("DOWN BLOCK FOUR")

# Innermost block: same channel count, and no downsampler child.
block_four = ResnetDownsampleBlock2D(
    in_channels=1024,
    out_channels=1024,
    temb_channels=1280,
    num_layers=3,
    add_downsample=False,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)
block_four.load_state_dict(_convert_down_block_sd(model.down[3], has_downsampler=False))
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
print("MID BLOCK 1")

# The original mid block is two plain resblocks (children "0" and "1").
mid_block_one_sd_orig = model.mid.state_dict()
mid_block_one_sd_new = {}

for i in range(2):
    # gn_* -> norm*, f_* -> conv*, f_t -> time_emb_proj (same scheme as the
    # down/up block conversions). pop() so leftovers can be asserted below.
    mid_block_one_sd_new[f"resnets.{i}.norm1.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.weight")
    mid_block_one_sd_new[f"resnets.{i}.norm1.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.bias")
    mid_block_one_sd_new[f"resnets.{i}.conv1.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_1.weight")
    mid_block_one_sd_new[f"resnets.{i}.conv1.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_1.bias")
    mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_t.weight")
    mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_t.bias")
    mid_block_one_sd_new[f"resnets.{i}.norm2.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.weight")
    mid_block_one_sd_new[f"resnets.{i}.norm2.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.bias")
    mid_block_one_sd_new[f"resnets.{i}.conv2.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_2.weight")
    mid_block_one_sd_new[f"resnets.{i}.conv2.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_2.bias")

# every original parameter must have been consumed
assert len(mid_block_one_sd_orig) == 0

mid_block_one = UNetMidBlock2D(
    in_channels=1024,
    temb_channels=1280,
    num_layers=1,  # one "layer" yields the two resnets of this mid block
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
    add_attention=False,  # the original mid block has no attention
)

mid_block_one.load_state_dict(mid_block_one_sd_new)
|
| 739 |
+
|
| 740 |
+
# Mapping from diffusers resnet parameter names to the names used by the
# original checkpoint (relative to a single resnet/resampler submodule).
_RESNET_KEY_MAP = (
    ("norm1.weight", "gn_1.weight"),
    ("norm1.bias", "gn_1.bias"),
    ("conv1.weight", "f_1.weight"),
    ("conv1.bias", "f_1.bias"),
    ("time_emb_proj.weight", "f_t.weight"),
    ("time_emb_proj.bias", "f_t.bias"),
    ("norm2.weight", "gn_2.weight"),
    ("norm2.bias", "gn_2.bias"),
    ("conv2.weight", "f_2.weight"),
    ("conv2.bias", "f_2.bias"),
)


def _move_resnet_weights(src, dst, src_prefix, dst_prefix, shortcut=False):
    """Pop one resnet's parameters out of *src* and store them in *dst*.

    ``src_prefix``/``dst_prefix`` select the submodule on either side;
    ``shortcut`` additionally moves the conv shortcut (``f_s``) used where the
    channel count changes.  Popping (rather than reading) lets callers assert
    afterwards that every original parameter was consumed.
    """
    for dst_field, src_field in _RESNET_KEY_MAP:
        dst[f"{dst_prefix}.{dst_field}"] = src.pop(f"{src_prefix}.{src_field}")
    if shortcut:
        dst[f"{dst_prefix}.conv_shortcut.weight"] = src.pop(f"{src_prefix}.f_s.weight")
        dst[f"{dst_prefix}.conv_shortcut.bias"] = src.pop(f"{src_prefix}.f_s.bias")


def _convert_up_block_sd(module, has_upsampler=True):
    """Return the diffusers-style state dict for one original up block.

    Each up block has four resnets (children "0".."3"), all with conv
    shortcuts because of the concatenated skip connections; the upsampling
    resblock, when present, is child "4".
    """
    sd_orig = module.state_dict()
    sd_new = {}
    for i in range(4):
        _move_resnet_weights(sd_orig, sd_new, str(i), f"resnets.{i}", shortcut=True)
    if has_upsampler:
        _move_resnet_weights(sd_orig, sd_new, "4", "upsamplers.0")
    assert len(sd_orig) == 0  # every original parameter must be accounted for
    return sd_new


# NOTE: model.up is stored innermost-last, hence the negative indexing to
# match diffusers' up_blocks ordering (innermost first).

print("UP BLOCK ONE")

up_block_one = ResnetUpsampleBlock2D(
    in_channels=1024,
    prev_output_channel=1024,
    out_channels=1024,
    temb_channels=1280,
    num_layers=4,
    add_upsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)
up_block_one.load_state_dict(_convert_up_block_sd(model.up[-1]))

print("UP BLOCK TWO")

up_block_two = ResnetUpsampleBlock2D(
    in_channels=640,
    prev_output_channel=1024,
    out_channels=1024,
    temb_channels=1280,
    num_layers=4,
    add_upsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)
up_block_two.load_state_dict(_convert_up_block_sd(model.up[-2]))

print("UP BLOCK THREE")

up_block_three = ResnetUpsampleBlock2D(
    in_channels=320,
    prev_output_channel=1024,
    out_channels=640,
    temb_channels=1280,
    num_layers=4,
    add_upsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)
up_block_three.load_state_dict(_convert_up_block_sd(model.up[-3]))

print("UP BLOCK FOUR")

# Outermost block: no upsampler child.
up_block_four = ResnetUpsampleBlock2D(
    in_channels=320,
    prev_output_channel=640,
    out_channels=320,
    temb_channels=1280,
    num_layers=4,
    add_upsample=False,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)
up_block_four.load_state_dict(_convert_up_block_sd(model.up[-4], has_upsampler=False))
|
| 908 |
+
|
| 909 |
+
print("initial projection (conv_in)")

# The original input projection is a single conv wrapped in a module ("f").
conv_in_sd_orig = model.embed_image.state_dict()
conv_in_sd_new = {}

conv_in_sd_new["weight"] = conv_in_sd_orig.pop("f.weight")
conv_in_sd_new["bias"] = conv_in_sd_orig.pop("f.bias")

assert len(conv_in_sd_orig) == 0

block_out_channels = [320, 640, 1024, 1024]

# NOTE(review): 7 input channels — presumably image + latent conditioning
# channels as fed by ConsistencyDecoder; confirm against its forward().
in_channels = 7
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)

conv_in.load_state_dict(conv_in_sd_new)

print("out projection (conv_out) (conv_norm_out)")
# 6 output channels (twice the 3 image channels — presumably mean/variance
# style pairing; TODO confirm against the decoder's output handling).
out_channels = 6
norm_num_groups = 32
norm_eps = 1e-5
act_fn = "silu"
conv_out_kernel = 3
conv_out_padding = (conv_out_kernel - 1) // 2
conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
# uses torch.functional in orig
# conv_act = get_activation(act_fn)
conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding)

# The output head of the original model maps directly: gn -> norm, f -> conv.
conv_norm_out.load_state_dict(model.output.gn.state_dict())
conv_out.load_state_dict(model.output.f.state_dict())
|
| 942 |
+
|
| 943 |
+
print("timestep projection (time_proj) (time_embedding)")

# The original time MLP has two linear layers (f_1, f_2) that map 1:1 onto
# diffusers' TimestepEmbedding (linear_1, linear_2).
f1_sd = model.embed_time.f_1.state_dict()
f2_sd = model.embed_time.f_2.state_dict()

time_embedding_sd = {
    "linear_1.weight": f1_sd.pop("weight"),
    "linear_1.bias": f1_sd.pop("bias"),
    "linear_2.weight": f2_sd.pop("weight"),
    "linear_2.bias": f2_sd.pop("bias"),
}

# every original parameter must have been consumed
assert len(f1_sd) == 0
assert len(f2_sd) == 0

# The model uses a learned (nn.Embedding) timestep encoding, not sinusoidal.
time_embedding_type = "learned"
num_train_timesteps = 1024
time_embedding_dim = 1280

time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
timestep_input_dim = block_out_channels[0]

time_embedding = TimestepEmbedding(timestep_input_dim, time_embedding_dim)

time_proj.load_state_dict(model.embed_time.emb.state_dict())
time_embedding.load_state_dict(time_embedding_sd)
|
| 969 |
+
|
| 970 |
+
print("CONVERT")

# Move every freshly built diffusers module onto the same device as `model`
# before splicing them in.
time_embedding.to("cuda")
time_proj.to("cuda")
conv_in.to("cuda")

block_one.to("cuda")
block_two.to("cuda")
block_three.to("cuda")
block_four.to("cuda")

mid_block_one.to("cuda")

up_block_one.to("cuda")
up_block_two.to("cuda")
up_block_three.to("cuda")
up_block_four.to("cuda")

conv_norm_out.to("cuda")
conv_out.to("cuda")

# Swap the converted modules into the original model in place.
model.time_proj = time_proj
model.time_embedding = time_embedding
model.embed_image = conv_in

model.down[0] = block_one
model.down[1] = block_two
model.down[2] = block_three
model.down[3] = block_four

model.mid = mid_block_one

model.up[-1] = up_block_one
model.up[-2] = up_block_two
model.up[-3] = up_block_three
model.up[-4] = up_block_four

model.output.gn = conv_norm_out
model.output.f = conv_out

# Flip the model's forward() onto its converted-module code path.
model.converted = True

# Re-run with the identical seed: the converted modules must reproduce the
# original decoder output bit-for-bit.
sample_consistency_new = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
save_image(sample_consistency_new, "con_new.png")

assert (sample_consistency_orig == sample_consistency_new).all()
|
| 1016 |
+
|
| 1017 |
+
print("making unet")

# Stand-alone diffusers UNet2DModel with the same topology as the converted
# blocks above (4 resnet down blocks, attention-free mid, 4 resnet up blocks,
# learned timestep embedding).
unet = UNet2DModel(
    in_channels=in_channels,
    out_channels=out_channels,
    down_block_types=(
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
    ),
    up_block_types=(
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
    ),
    block_out_channels=block_out_channels,
    layers_per_block=3,
    norm_num_groups=norm_num_groups,
    norm_eps=norm_eps,
    resnet_time_scale_shift="scale_shift",
    time_embedding_type="learned",
    num_train_timesteps=num_train_timesteps,
    add_attention=False,
)

# Flat state dict, assembled piecewise from the already-converted modules.
unet_state_dict = {}
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
def add_state_dict(prefix, mod):
    """Copy every entry of *mod*'s state dict into ``unet_state_dict``,
    namespaced under ``"<prefix>."``."""
    state = mod.state_dict()
    unet_state_dict.update({f"{prefix}.{key}": tensor for key, tensor in state.items()})
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
# Assemble the full UNet2DModel state dict from the converted modules, using
# diffusers' canonical submodule prefixes.
add_state_dict("conv_in", conv_in)
add_state_dict("time_proj", time_proj)
add_state_dict("time_embedding", time_embedding)
add_state_dict("down_blocks.0", block_one)
add_state_dict("down_blocks.1", block_two)
add_state_dict("down_blocks.2", block_three)
add_state_dict("down_blocks.3", block_four)
add_state_dict("mid_block", mid_block_one)
add_state_dict("up_blocks.0", up_block_one)
add_state_dict("up_blocks.1", up_block_two)
add_state_dict("up_blocks.2", up_block_three)
add_state_dict("up_blocks.3", up_block_four)
add_state_dict("conv_norm_out", conv_norm_out)
add_state_dict("conv_out", conv_out)

unet.load_state_dict(unet_state_dict)

print("running with diffusers unet")

unet.to("cuda")

# Run the consistency decoder with the stand-alone diffusers unet; it must
# also reproduce the original output bit-for-bit.
decoder_consistency.ckpt = unet

sample_consistency_new_2 = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
save_image(sample_consistency_new_2, "con_new_2.png")

assert (sample_consistency_orig == sample_consistency_new_2).all()
|
| 1079 |
+
|
| 1080 |
+
print("running with diffusers model")

# Monkeypatch Encoder so each instance records the kwargs it was constructed
# with; ConsistencyDecoderVAE below uses them to rebuild an identical encoder.
Encoder.old_constructor = Encoder.__init__


def new_constructor(self, **kwargs):
    """Wrapper around the original Encoder.__init__ that also stashes the
    constructor kwargs on the instance as ``constructor_arguments``."""
    self.old_constructor(**kwargs)
    self.constructor_arguments = kwargs


Encoder.__init__ = new_constructor
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
|
| 1094 |
+
consistency_vae = ConsistencyDecoderVAE(
|
| 1095 |
+
encoder_args=vae.encoder.constructor_arguments,
|
| 1096 |
+
decoder_args=unet.config,
|
| 1097 |
+
scaling_factor=vae.config.scaling_factor,
|
| 1098 |
+
block_out_channels=vae.config.block_out_channels,
|
| 1099 |
+
latent_channels=vae.config.latent_channels,
|
| 1100 |
+
)
|
| 1101 |
+
consistency_vae.encoder.load_state_dict(vae.encoder.state_dict())
|
| 1102 |
+
consistency_vae.quant_conv.load_state_dict(vae.quant_conv.state_dict())
|
| 1103 |
+
consistency_vae.decoder_unet.load_state_dict(unet.state_dict())
|
| 1104 |
+
|
| 1105 |
+
consistency_vae.to(dtype=torch.float16, device="cuda")
|
| 1106 |
+
|
| 1107 |
+
sample_consistency_new_3 = consistency_vae.decode(
|
| 1108 |
+
0.18215 * latent, generator=torch.Generator("cpu").manual_seed(0)
|
| 1109 |
+
).sample
|
| 1110 |
+
|
| 1111 |
+
print("max difference")
|
| 1112 |
+
print((sample_consistency_orig - sample_consistency_new_3).abs().max())
|
| 1113 |
+
print("total difference")
|
| 1114 |
+
print((sample_consistency_orig - sample_consistency_new_3).abs().sum())
|
| 1115 |
+
# assert (sample_consistency_orig == sample_consistency_new_3).all()
|
| 1116 |
+
|
| 1117 |
+
print("running with diffusers pipeline")
|
| 1118 |
+
|
| 1119 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 1120 |
+
"runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
|
| 1121 |
+
)
|
| 1122 |
+
pipe.to("cuda")
|
| 1123 |
+
|
| 1124 |
+
pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("horse.png")
|
| 1125 |
+
|
| 1126 |
+
|
| 1127 |
+
if args.save_pretrained is not None:
|
| 1128 |
+
consistency_vae.save_pretrained(args.save_pretrained)
|
diffusers/scripts/convert_dance_diffusion_to_diffusers.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
import torch
|
| 9 |
+
from audio_diffusion.models import DiffusionAttnUnet1D
|
| 10 |
+
from diffusion import sampling
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
|
| 14 |
+
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
MODELS_MAP = {
|
| 18 |
+
"gwf-440k": {
|
| 19 |
+
"url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt",
|
| 20 |
+
"sample_rate": 48000,
|
| 21 |
+
"sample_size": 65536,
|
| 22 |
+
},
|
| 23 |
+
"jmann-small-190k": {
|
| 24 |
+
"url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
|
| 25 |
+
"sample_rate": 48000,
|
| 26 |
+
"sample_size": 65536,
|
| 27 |
+
},
|
| 28 |
+
"jmann-large-580k": {
|
| 29 |
+
"url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt",
|
| 30 |
+
"sample_rate": 48000,
|
| 31 |
+
"sample_size": 131072,
|
| 32 |
+
},
|
| 33 |
+
"maestro-uncond-150k": {
|
| 34 |
+
"url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
|
| 35 |
+
"sample_rate": 16000,
|
| 36 |
+
"sample_size": 65536,
|
| 37 |
+
},
|
| 38 |
+
"unlocked-uncond-250k": {
|
| 39 |
+
"url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
|
| 40 |
+
"sample_rate": 16000,
|
| 41 |
+
"sample_size": 65536,
|
| 42 |
+
},
|
| 43 |
+
"honk-140k": {
|
| 44 |
+
"url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt",
|
| 45 |
+
"sample_rate": 16000,
|
| 46 |
+
"sample_size": 65536,
|
| 47 |
+
},
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def alpha_sigma_to_t(alpha, sigma):
    """Return the continuous timestep for the given clean-signal (*alpha*)
    and noise (*sigma*) scaling factors.

    The angle atan2(sigma, alpha) runs from 0 (all signal) to pi/2
    (all noise); dividing by pi and doubling rescales it onto [0, 1].
    """
    angle = torch.atan2(sigma, alpha)
    return angle / math.pi * 2
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_crash_schedule(t):
    """Map uniform timesteps *t* in [0, 1] onto the "crash" noise schedule
    used by the original audio-diffusion checkpoints."""
    noise_scale = torch.sin(t * math.pi / 2) ** 2
    signal_scale = (1 - noise_scale**2) ** 0.5
    return alpha_sigma_to_t(signal_scale, noise_scale)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class Object(object):
    # Minimal attribute bag standing in for the original training-config
    # namespace; attributes are assigned ad hoc in main().
    pass
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class DiffusionUncond(nn.Module):
    """Wrapper mirroring the original audio-diffusion training module.

    Recreated here only so the original checkpoint's state_dict keys
    (``diffusion.*`` / ``diffusion_ema.*``) line up when loading in main().
    """

    def __init__(self, global_args):
        super().__init__()

        # n_attn_layers=4 must match the architecture of the released
        # checkpoints listed in MODELS_MAP — TODO confirm for new models.
        self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
        self.diffusion_ema = deepcopy(self.diffusion)
        # Sobol engine exists in the original trainer; it is never used by
        # this conversion script.
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def download(model_name):
    """Download the checkpoint for *model_name* into the current directory.

    Streams the file in 8 KiB chunks and returns the local path
    ``./{model_name}.ckpt``.
    """
    url = MODELS_MAP[model_name]["url"]
    response = requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)

    destination = f"./{model_name}.ckpt"
    with open(destination, "wb") as checkpoint_file:
        for chunk in response.iter_content(chunk_size=8192):
            checkpoint_file.write(chunk)

    return destination
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Positional index (after rename() strips the nested "main.7." wrappers)
# -> diffusers sub-module name for a down block: indices 1-6 alternate
# resnet / attention.
DOWN_NUM_TO_LAYER = {
    "1": "resnets.0",
    "2": "attentions.0",
    "3": "resnets.1",
    "4": "attentions.1",
    "5": "resnets.2",
    "6": "attentions.2",
}
# Same alternation for up blocks; indices 8-13 (7 is the nested sub-module
# consumed by rename()'s "main.7." stripping).
UP_NUM_TO_LAYER = {
    "8": "resnets.0",
    "9": "attentions.0",
    "10": "resnets.1",
    "11": "attentions.1",
    "12": "resnets.2",
    "13": "attentions.2",
}
# Innermost level: the mid block holds six resnet/attention pairs.
MID_NUM_TO_LAYER = {
    "1": "resnets.0",
    "2": "attentions.0",
    "3": "resnets.1",
    "4": "attentions.1",
    "5": "resnets.2",
    "6": "attentions.2",
    "8": "resnets.3",
    "9": "attentions.3",
    "10": "resnets.4",
    "11": "attentions.4",
    "12": "resnets.5",
    "13": "attentions.5",
}
# Outermost level (depth 0): indices 0-2 map into the first down block,
# 4-6 into the last up block (see the depth == 0 branch of rename()).
DEPTH_0_TO_LAYER = {
    "0": "resnets.0",
    "1": "resnets.1",
    "2": "resnets.2",
    "4": "resnets.0",
    "5": "resnets.1",
    "6": "resnets.2",
}

# ResConvBlock parameter-name fragments -> diffusers layer names.
RES_CONV_MAP = {
    "skip": "conv_skip",
    "main.0": "conv_1",
    "main.1": "group_norm_1",
    "main.3": "conv_2",
    "main.4": "group_norm_2",
}

# Attention parameter names -> diffusers names; a list value marks a fused
# parameter that fans out into several keys (handled by convert_attn_naming
# and transform_conv_attns).
ATTN_MAP = {
    "norm": "group_norm",
    "qkv_proj": ["query", "key", "value"],
    "out_proj": ["proj_attn"],
}
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def convert_resconv_naming(name):
    """Translate one ResConvBlock parameter name to its diffusers equivalent.

    Raises ValueError for names that are neither ``skip...`` nor
    ``main.{digit}...``.
    """
    if name.startswith("skip"):
        return name.replace("skip", RES_CONV_MAP["skip"])

    # Everything else must look like "main.{digit}...".
    if not name.startswith("main."):
        raise ValueError(f"ResConvBlock error with {name}")

    fragment = name[:6]
    return name.replace(fragment, RES_CONV_MAP[fragment])
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def convert_attn_naming(name):
    """Translate an attention-block parameter name to diffusers naming.

    Returns a single name, or a list of names when the original parameter
    is fused (e.g. qkv) and must fan out into several diffusers keys.
    Raises ValueError when no mapping applies.
    """
    for original, replacement in ATTN_MAP.items():
        if not name.startswith(original):
            continue
        if isinstance(replacement, list):
            return [name.replace(original, target) for target in replacement]
        return name.replace(original, replacement)
    raise ValueError(f"Attn error with {name}")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def rename(input_string, max_depth=13):
    """Map one original checkpoint key to diffusers naming.

    Returns a string, or a list of strings when a fused attention parameter
    must fan out into query/key/value keys (see convert_attn_naming).
    """
    string = input_string

    # Timestep embedding is a straight rename.
    if string.split(".")[0] == "timestep_embed":
        return string.replace("timestep_embed", "time_proj")

    # Count how deeply the key is nested inside the recursive "main.7."
    # U-Net structure; depth selects down/mid/up block placement below.
    depth = 0
    if string.startswith("net.3."):
        depth += 1
        string = string[6:]
    elif string.startswith("net."):
        string = string[4:]

    while string.startswith("main.7."):
        depth += 1
        string = string[7:]

    if string.startswith("main."):
        string = string[5:]

    # mid block
    if string[:2].isdigit():
        layer_num = string[:2]
        string_left = string[2:]
    else:
        layer_num = string[0]
        string_left = string[1:]

    # NOTE(review): if none of the four branches below matches (e.g.
    # depth > 0 with layer_num == "7"), `new_layer`/`prefix` stay unbound
    # and an UnboundLocalError is raised — presumably unreachable for the
    # supported checkpoints; confirm before reusing on new architectures.
    if depth == max_depth:
        new_layer = MID_NUM_TO_LAYER[layer_num]
        prefix = "mid_block"
    elif depth > 0 and int(layer_num) < 7:
        new_layer = DOWN_NUM_TO_LAYER[layer_num]
        prefix = f"down_blocks.{depth}"
    elif depth > 0 and int(layer_num) > 7:
        new_layer = UP_NUM_TO_LAYER[layer_num]
        prefix = f"up_blocks.{max_depth - depth - 1}"
    elif depth == 0:
        new_layer = DEPTH_0_TO_LAYER[layer_num]
        prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0"

    # The remainder must be a sub-parameter path like ".main.0.weight".
    if not string_left.startswith("."):
        raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.")

    string_left = string_left[1:]

    if "resnets" in new_layer:
        string_left = convert_resconv_naming(string_left)
    elif "attentions" in new_layer:
        new_string_left = convert_attn_naming(string_left)
        string_left = new_string_left

    # Fused attention parameters come back as a list of suffixes.
    if not isinstance(string_left, list):
        new_string = prefix + "." + new_layer + "." + string_left
    else:
        new_string = [prefix + "." + new_layer + "." + s for s in string_left]
    return new_string
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def rename_orig_weights(state_dict):
    """Rename every key of the original checkpoint to diffusers naming.

    Non-trainable up-/downsample "kernel" buffers are dropped; fused
    attention parameters are split (and reshaped) via transform_conv_attns.
    """
    renamed = {}
    for key, tensor in state_dict.items():
        if key.endswith("kernel"):
            # up- and downsample layers, don't have trainable weights
            continue

        mapped = rename(key)

        # A list of target keys means the fused attention parameter must be
        # split and possibly converted from Conv to Linear shape.
        if isinstance(mapped, list):
            renamed = transform_conv_attns(renamed, mapped, tensor)
        else:
            renamed[mapped] = tensor

    return renamed
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def transform_conv_attns(new_state_dict, new_k, v):
    """Store attention parameter *v* under its new key(s) in *new_state_dict*.

    Conv1d weights of shape (out, in, 1) are squeezed down to Linear weights
    of shape (out, in); biases pass through unchanged. ``new_k`` holds one
    key (out-projection) or three (query/key/value split of a fused qkv
    parameter, chunked along dim 0). Returns the mutated dict.
    """
    if len(new_k) == 1:
        # Single target: drop the trailing conv dim only for 3-D weights.
        new_state_dict[new_k[0]] = v[:, :, 0] if len(v.shape) == 3 else v
    else:
        # Fused qkv: the first dimension packs three equal chunks.
        chunk = v.shape[0] // 3
        for i, target in enumerate(new_k):
            part = v[i * chunk : (i + 1) * chunk]
            new_state_dict[target] = part[:, :, 0] if len(part.shape) == 3 else part
    return new_state_dict
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def main(args):
    """Convert an original dance-diffusion checkpoint to a diffusers
    DanceDiffusionPipeline and verify the two produce matching audio.

    Expects ``args.model_path`` (path or official model name),
    ``args.save`` and ``args.checkpoint_path`` from the CLI parser below.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Derive the official model name from the path (e.g. "gwf-440k.ckpt").
    model_name = args.model_path.split("/")[-1].split(".")[0]
    if not os.path.isfile(args.model_path):
        assert model_name == args.model_path, (
            f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
        )
        args.model_path = download(model_name)

    sample_rate = MODELS_MAP[model_name]["sample_rate"]
    sample_size = MODELS_MAP[model_name]["sample_size"]

    # Minimal stand-in for the original trainer's config namespace.
    config = Object()
    config.sample_size = sample_size
    config.sample_rate = sample_rate
    config.latent_dim = 0

    diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)
    diffusers_state_dict = diffusers_model.state_dict()

    # Load the original model and take its EMA weights for conversion.
    orig_model = DiffusionUncond(config)
    orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
    orig_model = orig_model.diffusion_ema.eval()
    orig_model_state_dict = orig_model.state_dict()
    renamed_state_dict = rename_orig_weights(orig_model_state_dict)

    # Sanity-check key coverage: every renamed key must exist in the
    # diffusers model, and anything left over must be a non-trainable kernel.
    renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())
    diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())

    assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}"
    assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"

    for key, value in renamed_state_dict.items():
        # Shapes must match up to singleton dims (Conv vs Linear layouts).
        assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
            f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
        )
        if key == "time_proj.weight":
            value = value.squeeze()

        diffusers_state_dict[key] = value

    diffusers_model.load_state_dict(diffusers_state_dict)

    steps = 100
    seed = 33

    diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)

    # Same seed drives both the reference sampler and the pipeline below.
    generator = torch.manual_seed(seed)
    noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)

    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    step_list = get_crash_schedule(t)

    pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)

    generator = torch.manual_seed(33)
    audio = pipe(num_inference_steps=steps, generator=generator).audios

    # Reference output from the original iPLMS sampler, clamped like audio.
    generated = sampling.iplms_sample(orig_model, noise, step_list, {})
    generated = generated.clamp(-1, 1)

    diff_sum = (generated - audio).abs().sum()
    diff_max = (generated - audio).abs().max()

    if args.save:
        pipe.save_pretrained(args.checkpoint_path)

    print("Diff sum", diff_sum)
    print("Diff max", diff_max)

    # Conversion counts as successful only if outputs agree closely.
    assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/"

    print(f"Conversion for {model_name} successful!")
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    # FIX: the original used `type=bool`, which is broken in argparse —
    # bool("False") is True, so `--save False` still saved. Parse common
    # falsy spellings explicitly; the default remains True.
    parser.add_argument(
        "--save",
        default=True,
        type=lambda s: str(s).lower() not in {"0", "false", "no", "off"},
        required=False,
        help="Whether to save the converted model or not.",
    )
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    main(args)
|
diffusers/scripts/convert_dcae_to_diffusers.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from huggingface_hub import hf_hub_download
|
| 6 |
+
from safetensors.torch import load_file
|
| 7 |
+
|
| 8 |
+
from diffusers import AutoencoderDC
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def remap_qkv_(key: str, state_dict: Dict[str, Any]):
    """Split a fused qkv conv weight into separate to_q/to_k/to_v weights.

    Mutates *state_dict* in place: pops *key* (``...qkv.conv.weight``) and
    inserts the three squeezed chunks under the parent module's diffusers
    attention names.
    """
    fused = state_dict.pop(key)
    query, key_weight, value = torch.chunk(fused, 3, dim=0)
    parent_module = key.rpartition(".qkv.conv.weight")[0]
    state_dict[f"{parent_module}.to_q.weight"] = query.squeeze()
    state_dict[f"{parent_module}.to_k.weight"] = key_weight.squeeze()
    state_dict[f"{parent_module}.to_v.weight"] = value.squeeze()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
    """Rename a ``...proj.conv.weight`` entry to the diffusers
    ``...to_out.weight`` name, squeezing out singleton conv dims in place."""
    parent_module = key.rpartition(".proj.conv.weight")[0]
    state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Substring renames applied (in insertion order) to every original DCAE key
# by convert_ae(); config-specific entries are merged in by get_ae_config().
AE_KEYS_RENAME_DICT = {
    # common
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "local_module": "conv_out",
    # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
    # If there were more scales, there would be more layers, so a loop would be better to handle this
    "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
    "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
    "depth_conv.conv": "conv_depth",
    "inverted_conv.conv": "conv_inverted",
    "point_conv.conv": "conv_point",
    "point_conv.norm": "norm",
    "conv.conv.": "conv.",
    "conv1.conv": "conv1",
    "conv2.conv": "conv2",
    "conv2.norm": "norm",
    "proj.norm": "norm_out",
    # encoder
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.project_out.0.conv": "encoder.conv_out",
    "encoder.stages": "encoder.down_blocks",
    # decoder
    "decoder.project_in.conv": "decoder.conv_in",
    "decoder.project_out.0": "decoder.norm_out",
    "decoder.project_out.2.conv": "decoder.conv_out",
    "decoder.stages": "decoder.up_blocks",
}

# Per-config overrides for the conv-in/conv-out naming (merged into
# AE_KEYS_RENAME_DICT by get_ae_config for the matching config).
AE_F32C32_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F64C128_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

AE_F128C512_KEYS = {
    # encoder
    "encoder.project_in.conv": "encoder.conv_in.conv",
    # decoder
    "decoder.project_out.2.conv": "decoder.conv_out.conv",
}

# Key substring -> in-place handler applied in convert_ae()'s second pass.
AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
    "proj.conv.weight": remap_proj_conv_,
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap common checkpoint containers and return the raw state dict.

    Checkpoints are often nested under "model", "module" and/or
    "state_dict" wrappers. FIX over the original: membership is checked on
    the *current* level of unwrapping rather than always on the top-level
    dict, so nested wrappers (e.g. ``{"model": {"state_dict": ...}}``)
    unwrap fully and a wrapper key present only at the top level can no
    longer cause a KeyError after an earlier unwrap.
    """
    state_dict = saved_dict
    for wrapper_key in ("model", "module", "state_dict"):
        if isinstance(state_dict, dict) and wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
    return state_dict
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    """Move ``state_dict[old_key]`` to ``new_key`` in place.

    FIX: the original annotated a ``Dict[str, Any]`` return but returned
    None; the trailing underscore marks this as an in-place mutation, so
    the annotation is corrected to ``None``.
    """
    state_dict[new_key] = state_dict.pop(old_key)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def convert_ae(config_name: str, dtype: torch.dtype):
    """Build an AutoencoderDC for *config_name* and load the converted
    original weights into it.

    Downloads ``model.safetensors`` from the ``mit-han-lab/{config_name}``
    Hub repo, renames every key via AE_KEYS_RENAME_DICT (which
    get_ae_config may have extended for this config), applies the special
    qkv/proj handlers, and loads strictly. Returns the loaded model.
    """
    config = get_ae_config(config_name)
    hub_id = f"mit-han-lab/{config_name}"
    ckpt_path = hf_hub_download(hub_id, "model.safetensors")
    original_state_dict = get_state_dict(load_file(ckpt_path))

    ae = AutoencoderDC(**config).to(dtype=dtype)

    # First pass: plain substring renames, applied in dict insertion order.
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Second pass: structural remaps (qkv split, proj squeeze) in place.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    ae.load_state_dict(original_state_dict, strict=True)
    return ae
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_ae_config(name: str):
    """Return the AutoencoderDC constructor kwargs for checkpoint *name*.

    Side effect: for non-sana configs, merges the matching per-config key
    overrides into the module-level AE_KEYS_RENAME_DICT used by convert_ae.
    Raises ValueError for unknown names.
    """
    if name in ["dc-ae-f32c32-sana-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }
    elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
        config = {
            "latent_channels": 32,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
        }
        # Scaling factor depends on the training data (ImageNet vs. mixed).
        if name == "dc-ae-f32c32-in-1.0":
            config["scaling_factor"] = 0.3189
        elif name == "dc-ae-f32c32-mix-1.0":
            config["scaling_factor"] = 0.4552
    elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
        config = {
            "latent_channels": 128,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f64c128-in-1.0":
            config["scaling_factor"] = 0.2889
        elif name == "dc-ae-f64c128-mix-1.0":
            config["scaling_factor"] = 0.4538
    elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
        config = {
            "latent_channels": 512,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
        }
        if name == "dc-ae-f128c512-in-1.0":
            config["scaling_factor"] = 0.4883
        elif name == "dc-ae-f128c512-mix-1.0":
            config["scaling_factor"] = 0.3620
    else:
        raise ValueError("Invalid config name provided.")

    return config
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def get_args():
    """Parse the command-line arguments for the DCAE conversion script."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--config_name",
        type=str,
        default="dc-ae-f32c32-sana-1.0",
        choices=[
            "dc-ae-f32c32-sana-1.0",
            "dc-ae-f32c32-in-1.0",
            "dc-ae-f32c32-mix-1.0",
            "dc-ae-f64c128-in-1.0",
            "dc-ae-f64c128-mix-1.0",
            "dc-ae-f128c512-in-1.0",
            "dc-ae-f128c512-mix-1.0",
        ],
        help="The DCAE checkpoint to convert",
    )
    arg_parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    arg_parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    return arg_parser.parse_args()
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# CLI --dtype flag -> torch dtype used for the converted weights.
DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

# CLI --dtype flag -> diffusers weight-file variant suffix (None = default).
VARIANT_MAPPING = {
    "fp32": None,
    "fp16": "fp16",
    "bf16": "bf16",
}


if __name__ == "__main__":
    args = get_args()

    dtype = DTYPE_MAPPING[args.dtype]
    variant = VARIANT_MAPPING[args.dtype]

    ae = convert_ae(args.config_name, dtype)
    ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Script for converting a Hugging Face Diffusers trained SDXL LoRAs to Kohya format
|
| 2 |
+
# This means that you can input your diffusers-trained LoRAs and
|
| 3 |
+
# Get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
|
| 4 |
+
|
| 5 |
+
# To get started you can find some cool `diffusers` trained LoRAs such as this cute Corgy
|
| 6 |
+
# https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
|
| 7 |
+
# and run the script:
|
| 8 |
+
# python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
|
| 9 |
+
# now you can use corgy.safetensors in your WebUI of choice!
|
| 10 |
+
|
| 11 |
+
# To train your own, here are some diffusers training scripts and utils that you can use and then convert:
|
| 12 |
+
# LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
|
| 13 |
+
# Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
|
| 14 |
+
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
|
| 15 |
+
# - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
|
| 16 |
+
# Canonical diffusers training scripts:
|
| 17 |
+
# - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
|
| 18 |
+
# - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
|
| 19 |
+
|
| 20 |
+
import argparse
|
| 21 |
+
import os
|
| 22 |
+
|
| 23 |
+
from safetensors.torch import load_file, save_file
|
| 24 |
+
|
| 25 |
+
from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def convert_and_save(input_lora, output_lora=None):
    """Convert a diffusers-format SDXL LoRA file to the Kohya/WebUI format.

    Args:
        input_lora: Path to the diffusers-format ``.safetensors`` LoRA file.
        output_lora: Destination path; when None, the input file name with a
            ``_webui`` suffix is used.
    """
    if output_lora is None:
        stem = os.path.splitext(input_lora)[0]
        output_lora = stem + "_webui.safetensors"

    # Pipeline: diffusers -> PEFT -> Kohya, then write back as safetensors.
    state_dict = load_file(input_lora)
    state_dict = convert_all_state_dict_to_peft(state_dict)
    state_dict = convert_state_dict_to_kohya(state_dict)
    save_file(state_dict, output_lora)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.")
    cli.add_argument(
        "--input_lora",
        type=str,
        required=True,
        help="Path to the input LoRA model file in the diffusers format.",
    )
    cli.add_argument(
        "--output_lora",
        type=str,
        required=False,
        help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.",
    )

    cli_args = cli.parse_args()

    convert_and_save(cli_args.input_lora, cli_args.output_lora)
|
diffusers/scripts/convert_flux_xlabs_ipadapter_to_diffusers.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from contextlib import nullcontext
|
| 3 |
+
|
| 4 |
+
import safetensors.torch
|
| 5 |
+
from accelerate import init_empty_weights
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
+
|
| 8 |
+
from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Import the CLIP image encoder only when `transformers` is installed; the
# `vision` flag gates the optional image-encoder export in `main`.
if is_transformers_available():
    from transformers import CLIPVisionModelWithProjection

    vision = True
else:
    vision = False
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
|
| 20 |
+
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
|
| 21 |
+
--filename "flux-ip-adapter.safetensors"
|
| 22 |
+
--output_path "flux-ip-adapter-hf/"
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Fix: `is_accelerate_available` is a function and must be *called* — the bare
# function object is always truthy, so the fallback branch was unreachable.
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
| 27 |
+
|
| 28 |
+
# CLI is parsed at module import time; the resulting `args` namespace is
# consumed by `main(args)` at the bottom of the script.
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)

args = parser.parse_args()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def load_original_checkpoint(args):
    """Locate and load the original XLabs IP-Adapter state dict.

    A Hub repo id takes precedence; otherwise a local checkpoint path is
    used. Raises ValueError when neither is provided.
    """
    if args.original_state_dict_repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
        return safetensors.torch.load_file(ckpt_path)
    if args.checkpoint_path is not None:
        return safetensors.torch.load_file(args.checkpoint_path)
    raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
    """Rename XLabs Flux IP-Adapter checkpoint keys to the diffusers layout.

    Args:
        original_state_dict: Mutable mapping of original checkpoint keys;
            every converted entry is popped out of it.
        num_layers: Number of double-stream transformer blocks to convert.

    Returns:
        A new dict containing the converted state dict.
    """
    converted_state_dict = {}

    # image_proj
    ## norm
    converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
    converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
    ## proj
    # Fix: these previously re-popped the (already removed) `norm` keys, which
    # raised a KeyError and would have stored the wrong tensors anyway.
    converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.proj.weight")
    converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.proj.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"ip_adapter.{i}."
        # to_k_ip
        converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
        )
        converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
        )
        # to_v_ip
        converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
        )
        # Fix: the v-projection weight was previously written to
        # `to_k_ip.weight`, clobbering the k weight and leaving
        # `to_v_ip.weight` unset.
        converted_state_dict[f"{block_prefix}to_v_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
        )

    return converted_state_dict
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def main(args):
    # Load the original XLabs checkpoint (from the Hub or a local path).
    original_ckpt = load_original_checkpoint(args)

    # The XLabs adapter targets 19 double-stream transformer blocks.
    num_layers = 19
    converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)

    print("Saving Flux IP-Adapter in Diffusers format.")
    safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")

    # Optionally export the CLIP image encoder when transformers is installed
    # (the `vision` flag is set at import time).
    if vision:
        model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
        model.save_pretrained(f"{args.output_path}/image_encoder")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
    # `args` was parsed at module import time above.
    main(args)
|
diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from diffusers import HunyuanDiT2DControlNetModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _rename(state_dict, old_key, new_key):
    """Move ``state_dict[old_key]`` to ``new_key``, removing the old entry."""
    state_dict[new_key] = state_dict.pop(old_key)


def main(args):
    """Convert an original HunyuanDiT ControlNet ``.pt`` checkpoint to diffusers.

    Loads the checkpoint, renames keys in place to the diffusers layout,
    loads the result into ``HunyuanDiT2DControlNetModel`` and optionally
    saves it with ``save_pretrained``.
    """
    state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")

    if args.load_key != "none":
        try:
            state_dict = state_dict[args.load_key]
        except KeyError:
            # Fix: the second half of this message was missing its f-prefix,
            # so "{state_dict.keys()}" was printed literally.
            raise KeyError(
                f"{args.load_key} not found in the checkpoint."
                f"Please load from the following keys:{state_dict.keys()}"
            )
    # Fall back to CPU when CUDA is unavailable so the conversion still runs
    # on CPU-only machines (previously hard-coded to "cuda").
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_config = HunyuanDiT2DControlNetModel.load_config(
        "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
    )
    model_config["use_style_cond_and_image_meta_size"] = (
        args.use_style_cond_and_image_meta_size
    )  ### version <= v1.1: True; version >= v1.2: False
    print(model_config)

    for key in state_dict:
        print("local:", key)

    model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device)

    for key in model.state_dict():
        print("diffusers:", key)

    num_layers = 19
    for i in range(num_layers):
        # attn1: fused Wqkv -> separate to_q / to_k / to_v
        q, k, v = torch.chunk(state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight"), 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias"), 3, dim=0)
        state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
        state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
        state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
        state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
        state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
        state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias

        # attn1: q_norm / k_norm -> norm_q / norm_k
        _rename(state_dict, f"blocks.{i}.attn1.q_norm.weight", f"blocks.{i}.attn1.norm_q.weight")
        _rename(state_dict, f"blocks.{i}.attn1.q_norm.bias", f"blocks.{i}.attn1.norm_q.bias")
        _rename(state_dict, f"blocks.{i}.attn1.k_norm.weight", f"blocks.{i}.attn1.norm_k.weight")
        _rename(state_dict, f"blocks.{i}.attn1.k_norm.bias", f"blocks.{i}.attn1.norm_k.bias")

        # attn1: out_proj -> to_out.0
        _rename(state_dict, f"blocks.{i}.attn1.out_proj.weight", f"blocks.{i}.attn1.to_out.0.weight")
        _rename(state_dict, f"blocks.{i}.attn1.out_proj.bias", f"blocks.{i}.attn1.to_out.0.bias")

        # attn2: fused kv_proj -> to_k / to_v
        k, v = torch.chunk(state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias"), 2, dim=0)
        state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
        state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
        state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
        state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias

        # attn2: q_proj -> to_q
        _rename(state_dict, f"blocks.{i}.attn2.q_proj.weight", f"blocks.{i}.attn2.to_q.weight")
        _rename(state_dict, f"blocks.{i}.attn2.q_proj.bias", f"blocks.{i}.attn2.to_q.bias")

        # attn2: q_norm / k_norm -> norm_q / norm_k
        _rename(state_dict, f"blocks.{i}.attn2.q_norm.weight", f"blocks.{i}.attn2.norm_q.weight")
        _rename(state_dict, f"blocks.{i}.attn2.q_norm.bias", f"blocks.{i}.attn2.norm_q.bias")
        _rename(state_dict, f"blocks.{i}.attn2.k_norm.weight", f"blocks.{i}.attn2.norm_k.weight")
        _rename(state_dict, f"blocks.{i}.attn2.k_norm.bias", f"blocks.{i}.attn2.norm_k.bias")

        # attn2: out_proj -> to_out.0
        _rename(state_dict, f"blocks.{i}.attn2.out_proj.weight", f"blocks.{i}.attn2.to_out.0.weight")
        _rename(state_dict, f"blocks.{i}.attn2.out_proj.bias", f"blocks.{i}.attn2.to_out.0.bias")

        # The diffusers block swaps the roles of norm2 and norm3.
        norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
        norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
        state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
        state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
        state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
        state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias

        # norm1 -> norm1.norm; default_modulation.1 -> norm1.linear
        _rename(state_dict, f"blocks.{i}.norm1.weight", f"blocks.{i}.norm1.norm.weight")
        _rename(state_dict, f"blocks.{i}.norm1.bias", f"blocks.{i}.norm1.norm.bias")
        _rename(state_dict, f"blocks.{i}.default_modulation.1.weight", f"blocks.{i}.norm1.linear.weight")
        _rename(state_dict, f"blocks.{i}.default_modulation.1.bias", f"blocks.{i}.norm1.linear.bias")

        # mlp.fc1 -> ff.net.0.proj, mlp.fc2 -> ff.net.2
        _rename(state_dict, f"blocks.{i}.mlp.fc1.weight", f"blocks.{i}.ff.net.0.proj.weight")
        _rename(state_dict, f"blocks.{i}.mlp.fc1.bias", f"blocks.{i}.ff.net.0.proj.bias")
        _rename(state_dict, f"blocks.{i}.mlp.fc2.weight", f"blocks.{i}.ff.net.2.weight")
        _rename(state_dict, f"blocks.{i}.mlp.fc2.bias", f"blocks.{i}.ff.net.2.bias")

        # after_proj_list -> controlnet_blocks
        _rename(state_dict, f"after_proj_list.{i}.weight", f"controlnet_blocks.{i}.weight")
        _rename(state_dict, f"after_proj_list.{i}.bias", f"controlnet_blocks.{i}.bias")

    # before_proj -> input_block
    _rename(state_dict, "before_proj.weight", "input_block.weight")
    _rename(state_dict, "before_proj.bias", "input_block.bias")

    # pooler -> time_extra_emb.pooler
    _rename(state_dict, "pooler.positional_embedding", "time_extra_emb.pooler.positional_embedding")
    for proj in ("k_proj", "q_proj", "v_proj", "c_proj"):
        _rename(state_dict, f"pooler.{proj}.weight", f"time_extra_emb.pooler.{proj}.weight")
        _rename(state_dict, f"pooler.{proj}.bias", f"time_extra_emb.pooler.{proj}.bias")

    # t_embedder -> time_extra_emb.timestep_embedder (`TimestepEmbedding`)
    _rename(state_dict, "t_embedder.mlp.0.bias", "time_extra_emb.timestep_embedder.linear_1.bias")
    _rename(state_dict, "t_embedder.mlp.0.weight", "time_extra_emb.timestep_embedder.linear_1.weight")
    _rename(state_dict, "t_embedder.mlp.2.bias", "time_extra_emb.timestep_embedder.linear_2.bias")
    _rename(state_dict, "t_embedder.mlp.2.weight", "time_extra_emb.timestep_embedder.linear_2.weight")

    # x_embedder -> pos_embed (`PatchEmbed`)
    _rename(state_dict, "x_embedder.proj.weight", "pos_embed.proj.weight")
    _rename(state_dict, "x_embedder.proj.bias", "pos_embed.proj.bias")

    # mlp_t5 -> text_embedder
    _rename(state_dict, "mlp_t5.0.bias", "text_embedder.linear_1.bias")
    _rename(state_dict, "mlp_t5.0.weight", "text_embedder.linear_1.weight")
    _rename(state_dict, "mlp_t5.2.bias", "text_embedder.linear_2.bias")
    _rename(state_dict, "mlp_t5.2.weight", "text_embedder.linear_2.weight")

    # extra_embedder -> time_extra_emb.extra_embedder
    _rename(state_dict, "extra_embedder.0.bias", "time_extra_emb.extra_embedder.linear_1.bias")
    _rename(state_dict, "extra_embedder.0.weight", "time_extra_emb.extra_embedder.linear_1.weight")
    _rename(state_dict, "extra_embedder.2.bias", "time_extra_emb.extra_embedder.linear_2.bias")
    _rename(state_dict, "extra_embedder.2.weight", "time_extra_emb.extra_embedder.linear_2.weight")

    # style_embedder (only present for model versions <= v1.1)
    if model_config["use_style_cond_and_image_meta_size"]:
        print(state_dict["style_embedder.weight"])
        print(state_dict["style_embedder.weight"].shape)
        state_dict["time_extra_emb.style_embedder.weight"] = state_dict.pop("style_embedder.weight")[0:1]

    model.load_state_dict(state_dict)

    if args.save:
        model.save_pretrained(args.output_checkpoint_path)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
if __name__ == "__main__":
|
| 215 |
+
parser = argparse.ArgumentParser()
|
| 216 |
+
|
| 217 |
+
parser.add_argument(
|
| 218 |
+
"--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
|
| 219 |
+
)
|
| 220 |
+
parser.add_argument(
|
| 221 |
+
"--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
|
| 222 |
+
)
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--output_checkpoint_path",
|
| 225 |
+
default=None,
|
| 226 |
+
type=str,
|
| 227 |
+
required=False,
|
| 228 |
+
help="Path to the output converted diffusers pipeline.",
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
|
| 232 |
+
)
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--use_style_cond_and_image_meta_size",
|
| 235 |
+
type=bool,
|
| 236 |
+
default=False,
|
| 237 |
+
help="version <= v1.1: True; version >= v1.2: False",
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
args = parser.parse_args()
|
| 241 |
+
main(args)
|
diffusers/scripts/convert_i2vgen_to_diffusers.py
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Conversion script for the LDM checkpoints."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 21 |
+
|
| 22 |
+
from diffusers import DDIMScheduler, I2VGenXLPipeline, I2VGenXLUNet, StableDiffusionPipeline
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
CLIP_ID = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            # Fused QKV tensor: first dim holds q, k and v stacked.
            channels = old_tensor.shape[0] // 3

            # 3-D tensors (conv kernels) keep a channel axis; others flatten.
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            # assumes config["num_head_channels"] divides the fused dim — TODO confirm
            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            # Regroup per head so q/k/v can be separated with a single split.
            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Apply global old->new substring replacements to the target key.
        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        weight = old_checkpoint[path["old"]]
        names = ["proj_attn.weight"]
        names_2 = ["proj_out.weight", "proj_in.weight"]
        if any(k in new_path for k in names):
            # Drop the trailing kernel dim: conv1d weight -> linear weight.
            checkpoint[new_path] = weight[:, :, 0]
        elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path:
            checkpoint[new_path] = weight[:, :, 0]
        else:
            checkpoint[new_path] = weight
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # Attention keys keep their original names; emit identity old->new pairs.
    return [{"old": item, "new": item} for item in old_list]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    # Temporal-conv keys are unchanged; map each key onto itself.
    return [{"old": item, "new": item} for item in old_list]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming).

    Keys containing "temopral_conv" (spelling matches the original checkpoint
    keys used elsewhere in this script) are excluded from the mapping; they
    are handled by ``renew_temp_conv_paths``.
    """
    # Old LDM sub-module name -> diffusers sub-module name.
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for old_item in old_list:
        if "temopral_conv" in old_item:
            continue
        new_item = old_item
        for src, dst in renames:
            new_item = new_item.replace(src, dst)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Remaps original I2VGen-XL / LDM UNet keys (``model.diffusion_model.*``)
    to the diffusers naming scheme. ``checkpoint`` is consumed destructively
    (keys are popped). ``config`` must expose ``layers_per_block`` and may
    expose ``class_embed_type``. ``path`` is only used in log messages;
    ``extract_ema`` selects the EMA weights when present.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."

    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                # EMA keys are stored flat: "model_ema." followed by the
                # remaining segments joined WITHOUT dots.
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # Time embedding MLP (time_embed.0 / time_embed.2 -> linear_1 / linear_2).
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    # I2VGen-XL-specific conditioning embeddings copied over with renamed prefixes.
    additional_embedding_substrings = [
        "local_image_concat",
        "context_embedding",
        "local_image_embedding",
        "fps_embedding",
    ]
    for k in unet_state_dict:
        if any(substring in k for substring in additional_embedding_substrings):
            # "context_embedding" and "fps_embedding" keys pass through unchanged.
            diffusers_key = k.replace("local_image_concat", "image_latents_proj_in").replace(
                "local_image_embedding", "image_latents_context_embedding"
            )
            new_checkpoint[diffusers_key] = unet_state_dict[k]

    # temporal encoder.
    new_checkpoint["image_latents_temporal_encoder.norm1.weight"] = unet_state_dict[
        "local_temporal_encoder.layers.0.0.norm.weight"
    ]
    new_checkpoint["image_latents_temporal_encoder.norm1.bias"] = unet_state_dict[
        "local_temporal_encoder.layers.0.0.norm.bias"
    ]

    # attention
    # The original stores a fused qkv projection; split it into the three
    # separate diffusers projections along dim 0.
    qkv = unet_state_dict["local_temporal_encoder.layers.0.0.fn.to_qkv.weight"]
    q, k, v = torch.chunk(qkv, 3, dim=0)
    new_checkpoint["image_latents_temporal_encoder.attn1.to_q.weight"] = q
    new_checkpoint["image_latents_temporal_encoder.attn1.to_k.weight"] = k
    new_checkpoint["image_latents_temporal_encoder.attn1.to_v.weight"] = v
    new_checkpoint["image_latents_temporal_encoder.attn1.to_out.0.weight"] = unet_state_dict[
        "local_temporal_encoder.layers.0.0.fn.to_out.0.weight"
    ]
    new_checkpoint["image_latents_temporal_encoder.attn1.to_out.0.bias"] = unet_state_dict[
        "local_temporal_encoder.layers.0.0.fn.to_out.0.bias"
    ]

    # feedforward
    new_checkpoint["image_latents_temporal_encoder.ff.net.0.proj.weight"] = unet_state_dict[
        "local_temporal_encoder.layers.0.1.net.0.0.weight"
    ]
    new_checkpoint["image_latents_temporal_encoder.ff.net.0.proj.bias"] = unet_state_dict[
        "local_temporal_encoder.layers.0.1.net.0.0.bias"
    ]
    new_checkpoint["image_latents_temporal_encoder.ff.net.2.weight"] = unet_state_dict[
        "local_temporal_encoder.layers.0.1.net.2.weight"
    ]
    new_checkpoint["image_latents_temporal_encoder.ff.net.2.bias"] = unet_state_dict[
        "local_temporal_encoder.layers.0.1.net.2.bias"
    ]

    if "class_embed_type" in config:
        if config["class_embed_type"] is None:
            # No parameters to port
            ...
        elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
            new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
            new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
            new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
            new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
        else:
            raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    # The first temporal attention (input_blocks.0.1) becomes the standalone
    # transformer_in module.
    first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")]
    paths = renew_attention_paths(first_temp_attention)
    meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"}
    assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # Down blocks: sub-index 0 = resnet (+ temporal conv), 1 = attention,
    # 2 = temporal attention.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
        temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key]

        # NOTE(review): the resnet filter above excludes `input_blocks.{i}.0.op`
        # keys, but this lookup uses `input_blocks.{i}.op.weight` (no `.0`) —
        # confirm which path the checkpoint actually stores the downsampler at.
        if f"input_blocks.{i}.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        # "temopral_conv" spelling matches the original checkpoint keys.
        temporal_convs = [key for key in resnets if "temopral_conv" in key]
        paths = renew_temp_conv_paths(temporal_convs)
        meta_path = {
            "old": f"input_blocks.{i}.0.temopral_conv",
            "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}",
        }
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

        if len(temp_attentions):
            paths = renew_attention_paths(temp_attentions)
            meta_path = {
                "old": f"input_blocks.{i}.2",
                "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
            }
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # Mid block layout: 0 = resnet, 1 = attention, 2 = temporal attention, 3 = resnet.
    resnet_0 = middle_blocks[0]
    temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key]
    attentions = middle_blocks[1]
    temp_attentions = middle_blocks[2]
    resnet_1 = middle_blocks[3]
    temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"}
    assign_to_checkpoint(
        resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0)
    meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"}
    assign_to_checkpoint(
        temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    resnet_1_paths = renew_resnet_paths(resnet_1)
    meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"}
    assign_to_checkpoint(
        resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1)
    meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"}
    assign_to_checkpoint(
        temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    temp_attentions_paths = renew_attention_paths(temp_attentions)
    meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"}
    assign_to_checkpoint(
        temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    # Up blocks.
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group remaining segments by their sub-module index within the block.
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
            temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]

            # NOTE(review): resnet_0_paths is assigned but unused in this branch
            # (paths below is the identical value).
            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            temporal_convs = [key for key in resnets if "temopral_conv" in key]
            paths = renew_temp_conv_paths(temporal_convs)
            meta_path = {
                "old": f"output_blocks.{i}.0.temopral_conv",
                "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}",
            }
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # Sort so an upsampler sub-module is detectable as exactly
            # ["conv.bias", "conv.weight"].
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            if len(temp_attentions):
                paths = renew_attention_paths(temp_attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.2",
                    "new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Single-sub-module output block: plain resnet rename, no meta-path machinery.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
                new_checkpoint[new_path] = unet_state_dict[old_path]

            # NOTE(review): old_path below inserts str(block_id) between the
            # output-block index and "temopral_conv" — confirm this matches the
            # checkpoint layout (the analogous input-block path has no such segment).
            temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l]
            for path in temopral_conv_paths:
                pruned_path = path.split("temopral_conv.")[-1]
                old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path])
                new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path])
                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
if __name__ == "__main__":
    # CLI entry point: convert an original I2VGen-XL UNet checkpoint plus
    # off-the-shelf VAE/CLIP components into a diffusers I2VGenXLPipeline.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--unet_checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--push_to_hub", action="store_true")
    args = parser.parse_args()

    # UNet
    unet_checkpoint = torch.load(args.unet_checkpoint_path, map_location="cpu")
    unet_checkpoint = unet_checkpoint["state_dict"]
    unet = I2VGenXLUNet(sample_size=32)

    converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config)

    # Sanity check: the converted dict must cover exactly the model's keys
    # (no missing and no unexpected entries) before a strict load.
    diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys())
    diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys())

    assert len(diff_0) == len(diff_1) == 0, "Converted weights don't match"

    unet.load_state_dict(converted_ckpt, strict=True)

    # vae
    # Reuses the SD 2.1 VAE shipped alongside the original i2vgen-xl weights.
    temp_pipe = StableDiffusionPipeline.from_single_file(
        "https://huggingface.co/ali-vilab/i2vgen-xl/blob/main/models/v2-1_512-ema-pruned.ckpt"
    )
    vae = temp_pipe.vae
    del temp_pipe

    # text encoder and tokenizer
    # CLIP_ID is a module-level constant defined earlier in this file.
    text_encoder = CLIPTextModel.from_pretrained(CLIP_ID)
    tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)

    # image encoder and feature extractor
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(CLIP_ID)
    feature_extractor = CLIPImageProcessor.from_pretrained(CLIP_ID)

    # scheduler
    # https://github.com/ali-vilab/i2vgen-xl/blob/main/configs/i2vgen_xl_train.yaml
    scheduler = DDIMScheduler(
        beta_schedule="squaredcos_cap_v2",
        rescale_betas_zero_snr=True,
        set_alpha_to_one=True,
        clip_sample=False,
        steps_offset=1,
        timestep_spacing="leading",
        prediction_type="v_prediction",
    )

    # final
    pipeline = I2VGenXLPipeline(
        unet=unet,
        vae=vae,
        image_encoder=image_encoder,
        feature_extractor=feature_extractor,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
    )

    pipeline.save_pretrained(args.dump_path, push_to_hub=args.push_to_hub)
|
diffusers/scripts/convert_if.py
ADDED
|
@@ -0,0 +1,1250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import inspect
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import yaml
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer
|
| 10 |
+
|
| 11 |
+
from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet2DConditionModel
|
| 12 |
+
from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def parse_args():
    """Parse the CLI arguments for the DeepFloyd-IF conversion script.

    Only the safety-checker head paths are mandatory; each pipeline stage is
    converted later only when both its checkpoint and dump path were given.
    """
    parser = argparse.ArgumentParser()

    # Output directories, one per pipeline stage (all optional).
    parser.add_argument("--dump_path", required=False, default=None, type=str)
    parser.add_argument("--dump_path_stage_2", required=False, default=None, type=str)
    parser.add_argument("--dump_path_stage_3", required=False, default=None, type=str)

    # Stage-1 UNet description.
    parser.add_argument("--unet_config", required=False, default=None, type=str, help="Path to unet config file")
    parser.add_argument(
        "--unet_checkpoint_path", required=False, default=None, type=str, help="Path to unet checkpoint file"
    )

    # Super-resolution UNet checkpoints.
    parser.add_argument(
        "--unet_checkpoint_path_stage_2",
        required=False,
        default=None,
        type=str,
        help="Path to stage 2 unet checkpoint file",
    )
    parser.add_argument(
        "--unet_checkpoint_path_stage_3",
        required=False,
        default=None,
        type=str,
        help="Path to stage 3 unet checkpoint file",
    )

    # Safety-checker heads (required for every invocation).
    parser.add_argument("--p_head_path", type=str, required=True)
    parser.add_argument("--w_head_path", type=str, required=True)

    return parser.parse_args()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def main(args):
    """Convert the requested DeepFloyd-IF stages into diffusers pipelines.

    The T5 text stack and CLIP-based safety components are built once and
    shared across all stages; each stage is converted only when its
    checkpoint/dump-path arguments were both supplied.
    """
    # Shared text stack (T5-XXL) used by every IF stage.
    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
    text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")

    # Shared CLIP preprocessing + converted safety checker heads.
    feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
    safety_checker = convert_safety_checker(p_head_path=args.p_head_path, w_head_path=args.w_head_path)

    stage_1_requested = (
        args.unet_config is not None and args.unet_checkpoint_path is not None and args.dump_path is not None
    )
    if stage_1_requested:
        convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args)

    if args.unet_checkpoint_path_stage_2 is not None and args.dump_path_stage_2 is not None:
        convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=2)

    if args.unet_checkpoint_path_stage_3 is not None and args.dump_path_stage_3 is not None:
        convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=3)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args):
    """Assemble the stage-1 (base) IF pipeline and save it to ``args.dump_path``.

    The UNet is rebuilt from ``args.unet_config`` / ``args.unet_checkpoint_path``;
    the remaining components are passed in already constructed.
    """
    unet = get_stage_1_unet(args.unet_config, args.unet_checkpoint_path)

    # Diffusion schedule for the base stage.
    scheduler_kwargs = dict(
        variance_type="learned_range",
        beta_schedule="squaredcos_cap_v2",
        prediction_type="epsilon",
        thresholding=True,
        dynamic_thresholding_ratio=0.95,
        sample_max_value=1.5,
    )

    pipeline = IFPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=DDPMScheduler(**scheduler_kwargs),
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=True,
    )

    pipeline.save_pretrained(args.dump_path)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage):
    """Assemble an ``IFSuperResolutionPipeline`` for stage 2 or 3 and save it.

    Args:
        tokenizer / text_encoder: shared T5 components.
        feature_extractor / safety_checker: shared CLIP-based safety components.
        args: parsed CLI namespace holding per-stage checkpoint and dump paths.
        stage: which super-resolution stage to convert; must be 2 or 3.

    Raises:
        ValueError: if ``stage`` is neither 2 nor 3.
    """
    if stage == 2:
        unet_checkpoint_path = args.unet_checkpoint_path_stage_2
        sample_size = None
        dump_path = args.dump_path_stage_2
    elif stage == 3:
        unet_checkpoint_path = args.unet_checkpoint_path_stage_3
        # The stage-3 config misreports the sample size, so it is forced here
        # (see the note inside get_super_res_unet).
        sample_size = 1024
        dump_path = args.dump_path_stage_3
    else:
        # Was `assert False`, which is silently stripped under `python -O`;
        # raise an explicit, descriptive error instead.
        raise ValueError(f"Unexpected stage: {stage}. Expected 2 or 3.")

    unet = get_super_res_unet(unet_checkpoint_path, verify_param_count=False, sample_size=sample_size)

    # Scheduler used to noise the low-resolution conditioning image.
    image_noising_scheduler = DDPMScheduler(
        beta_schedule="squaredcos_cap_v2",
    )

    # Main denoising sampler configuration.
    scheduler = DDPMScheduler(
        variance_type="learned_range",
        beta_schedule="squaredcos_cap_v2",
        prediction_type="epsilon",
        thresholding=True,
        dynamic_thresholding_ratio=0.95,
        sample_max_value=1.0,
    )

    pipe = IFSuperResolutionPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        requires_safety_checker=True,
    )

    pipe.save_pretrained(dump_path)
def get_stage_1_unet(unet_config, unet_checkpoint_path):
    """Instantiate the stage-1 UNet and load the converted original weights.

    NOTE(review): ``unet_config`` is handed straight to ``yaml.safe_load``, so
    it is expected to be YAML text or an open stream, not a filesystem path —
    confirm against the argument parser.
    """
    original_unet_config = yaml.safe_load(unet_config)["params"]

    unet_diffusers_config = create_unet_diffusers_config(original_unet_config)
    unet = UNet2DConditionModel(**unet_diffusers_config)

    # Load the original checkpoint onto the GPU when one is available.
    map_location = "cuda" if torch.cuda.is_available() else "cpu"
    unet_checkpoint = torch.load(unet_checkpoint_path, map_location=map_location)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )
    unet.load_state_dict(converted_unet_checkpoint)

    return unet
def convert_safety_checker(p_head_path, w_head_path):
    """Assemble an ``IFSafetyChecker`` from numpy head weights plus a pretrained CLIP vision model.

    Args:
        p_head_path: path to a numpy archive with "weights"/"biases" for the p head.
        w_head_path: path to a numpy archive with "weights"/"biases" for the w head.
    """

    def _head_entries(npz_path, prefix):
        # Each head archive stores flat "weights"/"biases" arrays; both get a
        # leading dimension added so they fit a single-output linear layer.
        archive = np.load(npz_path)
        return {
            f"{prefix}.weight": torch.from_numpy(archive["weights"]).unsqueeze(0),
            f"{prefix}.bias": torch.from_numpy(archive["biases"]).unsqueeze(0),
        }

    state_dict = {}
    state_dict.update(_head_entries(p_head_path, "p_head"))
    state_dict.update(_head_entries(w_head_path, "w_head"))

    # Vision-model weights are copied verbatim under the "vision_model." prefix.
    vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    for key, value in vision_model.state_dict().items():
        state_dict[f"vision_model.{key}"] = value

    # Build the full safety checker and load every collected tensor into it.
    config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
    safety_checker = IFSafetyChecker(config)
    safety_checker.load_state_dict(state_dict)

    return safety_checker
def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
    """Translate an original IF stage-1 UNet config into ``UNet2DConditionModel`` kwargs.

    Args:
        original_unet_config: the ``params`` dict from the original YAML config.
        class_embed_type: optional override for the class-embedding type; when
            None it is inferred from the config's ``num_classes`` entry.

    Returns:
        A dict of keyword arguments suitable for ``UNet2DConditionModel(**config)``.

    Raises:
        NotImplementedError: for a ``num_classes`` value other than "sequential".
    """
    # Attention resolutions are given as downsample factors; convert them into
    # the same units as the running ``resolution`` counter below.
    attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
    attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]

    channel_mult = parse_list(original_unet_config["channel_mult"])
    block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]

    down_block_types = []
    resolution = 1

    for i in range(len(block_out_channels)):
        # Use cross-attention blocks at the configured attention resolutions;
        # otherwise pick resnet- or plain-style blocks per ``resblock_updown``.
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnDownBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetDownsampleBlock2D"
        else:
            block_type = "DownBlock2D"

        down_block_types.append(block_type)

        # The last down block does not downsample, so do not advance past it.
        if i != len(block_out_channels) - 1:
            resolution *= 2

    # NOTE: ``resolution`` deliberately carries over from the down path — the
    # up path starts at the deepest resolution and halves it each block.
    up_block_types = []
    for i in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            block_type = "SimpleCrossAttnUpBlock2D"
        elif original_unet_config["resblock_updown"]:
            block_type = "ResnetUpsampleBlock2D"
        else:
            block_type = "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    head_dim = original_unet_config["num_head_channels"]

    use_linear_projection = (
        original_unet_config["use_linear_in_transformer"]
        if "use_linear_in_transformer" in original_unet_config
        else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    projection_class_embeddings_input_dim = None

    # Infer the class embedding from ``num_classes`` only when no explicit
    # override was passed by the caller.
    if class_embed_type is None:
        if "num_classes" in original_unet_config:
            if original_unet_config["num_classes"] == "sequential":
                class_embed_type = "projection"
                assert "adm_in_channels" in original_unet_config
                projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
            else:
                raise NotImplementedError(
                    f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
                )

    config = {
        "sample_size": original_unet_config["image_size"],
        "in_channels": original_unet_config["in_channels"],
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": original_unet_config["num_res_blocks"],
        "cross_attention_dim": original_unet_config["encoder_channels"],
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "out_channels": original_unet_config["out_channels"],
        "up_block_types": tuple(up_block_types),
        "upcast_attention": False,  # TODO: guessing
        "cross_attention_norm": "group_norm",
        "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
        "addition_embed_type": "text",
        "act_fn": "gelu",
    }

    if original_unet_config["use_scale_shift_norm"]:
        config["resnet_time_scale_shift"] = "scale_shift"

    if "encoder_dim" in original_unet_config:
        config["encoder_hid_dim"] = original_unet_config["encoder_dim"]

    return config
def convert_ldm_unet_checkpoint(unet_state_dict, config, path=None):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Maps the original IF UNet parameter names onto the diffusers
    ``UNet2DConditionModel`` layout. Fused attention tensors are popped from
    ``unet_state_dict`` by ``assign_attention_to_checkpoint`` and re-emitted
    unfused; all other tensors are copied under renamed keys.

    Args:
        unet_state_dict: the original checkpoint (mutated in place — fused
            attention, downsampler, and pooling entries are popped).
        config: the diffusers UNet config from ``create_unet_diffusers_config``.
        path: unused here; accepted for signature compatibility with callers.
    """
    new_checkpoint = {}

    # Timestep embedding MLP.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] in [None, "identity"]:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    # Input convolution plus the final norm/conv pair.
    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }

    # Down path: input_blocks.1..N map onto down_blocks.<id>.resnets/attentions.
    for i in range(1, num_input_blocks):
        # Each down block holds layers_per_block resnets plus one downsampler.
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # Conv-style downsampler ("op") weights are popped and assigned directly.
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)

        # TODO need better check than i in [4, 8, 12, 16]
        block_type = config["down_block_types"][block_id]
        if (block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D") and i in [
            4,
            8,
            12,
            16,
        ]:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
        else:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}

        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            old_path = f"input_blocks.{i}.1"
            new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"

            # Split the fused qkv / encoder_kv tensors first (this pops them),
            # then rename and copy the remaining attention parameters.
            assign_attention_to_checkpoint(
                new_checkpoint=new_checkpoint,
                unet_state_dict=unet_state_dict,
                old_path=old_path,
                new_path=new_path,
                config=config,
            )

            paths = renew_attention_paths(attentions)
            meta_path = {"old": old_path, "new": new_path}
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )

    # Middle block layout: resnet, attention, resnet.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    old_path = "middle_block.1"
    new_path = "mid_block.attentions.0"

    assign_attention_to_checkpoint(
        new_checkpoint=new_checkpoint,
        unet_state_dict=unet_state_dict,
        old_path=old_path,
        new_path=new_path,
        config=config,
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    # Up path: output_blocks.0..M map onto up_blocks.<id>.resnets/attentions/upsamplers.
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group this output block's parameter names by their sub-layer index.
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        # len(output_block_list) == 1 -> resnet
        # len(output_block_list) == 2 -> resnet, attention
        # len(output_block_list) == 3 -> resnet, attention, upscale resnet

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}

            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            # Conv-style upsampler: its weights are assigned directly here.
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                old_path = f"output_blocks.{i}.1"
                new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"

                assign_attention_to_checkpoint(
                    new_checkpoint=new_checkpoint,
                    unet_state_dict=unet_state_dict,
                    old_path=old_path,
                    new_path=new_path,
                    config=config,
                )

                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": old_path,
                    "new": new_path,
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            # A third sub-layer is a resnet-style upsampler.
            if len(output_block_list) == 3:
                resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
                paths = renew_resnet_paths(resnets)
                meta_path = {"old": f"output_blocks.{i}.2", "new": f"up_blocks.{block_id}.upsamplers.0"}
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Resnet-only output block.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    # Optional text-encoder hidden-state projection.
    if "encoder_proj.weight" in unet_state_dict:
        new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict.pop("encoder_proj.weight")
        new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict.pop("encoder_proj.bias")

    # Optional pooled text embedding ("encoder_pooling" -> "add_embedding").
    if "encoder_pooling.0.weight" in unet_state_dict:
        new_checkpoint["add_embedding.norm1.weight"] = unet_state_dict.pop("encoder_pooling.0.weight")
        new_checkpoint["add_embedding.norm1.bias"] = unet_state_dict.pop("encoder_pooling.0.bias")

        new_checkpoint["add_embedding.pool.positional_embedding"] = unet_state_dict.pop(
            "encoder_pooling.1.positional_embedding"
        )
        new_checkpoint["add_embedding.pool.k_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.k_proj.weight")
        new_checkpoint["add_embedding.pool.k_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.k_proj.bias")
        new_checkpoint["add_embedding.pool.q_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.q_proj.weight")
        new_checkpoint["add_embedding.pool.q_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.q_proj.bias")
        new_checkpoint["add_embedding.pool.v_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.v_proj.weight")
        new_checkpoint["add_embedding.pool.v_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.v_proj.bias")

        new_checkpoint["add_embedding.proj.weight"] = unet_state_dict.pop("encoder_pooling.2.weight")
        new_checkpoint["add_embedding.proj.bias"] = unet_state_dict.pop("encoder_pooling.2.bias")

        new_checkpoint["add_embedding.norm2.weight"] = unet_state_dict.pop("encoder_pooling.3.weight")
        new_checkpoint["add_embedding.norm2.bias"] = unet_state_dict.pop("encoder_pooling.3.bias")

    return new_checkpoint
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    # Substring renames from original resnet layer names to diffusers names.
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for old_item in old_list:
        new_item = old_item
        for source, target in renames:
            new_item = new_item.replace(source, target)

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})

    return mapping
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)

    Fused ``qkv`` / ``encoder_kv`` entries are skipped entirely; those tensors
    are split and assigned separately by ``assign_attention_to_checkpoint``.
    """
    # Substring renames from original attention layer names to diffusers names.
    renames = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
        ("norm_encoder.weight", "norm_cross.weight"),
        ("norm_encoder.bias", "norm_cross.bias"),
    )

    mapping = []
    for old_item in old_list:
        if "qkv" in old_item or "encoder_kv" in old_item:
            continue

        new_item = old_item
        for source, target in renames:
            new_item = new_item.replace(source, target)

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})

    return mapping
def assign_attention_to_checkpoint(new_checkpoint, unet_state_dict, old_path, new_path, config):
    """Unfuse the original attention projections into separate diffusers tensors.

    Pops ``{old_path}.qkv.*`` and ``{old_path}.encoder_kv.*`` out of
    ``unet_state_dict`` and writes ``to_q/to_k/to_v`` and
    ``add_k_proj/add_v_proj`` entries under ``new_path`` into ``new_checkpoint``.

    Args:
        new_checkpoint: destination dict, mutated in place.
        unet_state_dict: source state dict; the fused tensors are popped from it.
        old_path: key prefix of the original attention module.
        new_path: key prefix of the diffusers attention module.
        config: diffusers UNet config; supplies ``attention_head_dim`` (the
            chunk size used to split the fused tensors) and the optional
            ``only_cross_attention`` flag.
    """
    # The fused projection is stored as a 1x1-conv-style weight (out, in, 1);
    # drop the trailing dim to get a linear-layer weight.
    qkv_weight = unet_state_dict.pop(f"{old_path}.qkv.weight")
    qkv_weight = qkv_weight[:, :, 0]

    qkv_bias = unet_state_dict.pop(f"{old_path}.qkv.bias")

    is_cross_attn_only = "only_cross_attention" in config and config["only_cross_attention"]

    # Cross-attention-only blocks fuse just the query projection; otherwise the
    # tensor interleaves q, k and v in head-sized chunks.
    split = 1 if is_cross_attn_only else 3

    weights, bias = split_attentions(
        weight=qkv_weight,
        bias=qkv_bias,
        split=split,
        chunk_size=config["attention_head_dim"],
    )

    if is_cross_attn_only:
        query_weight, q_bias = weights, bias
        new_checkpoint[f"{new_path}.to_q.weight"] = query_weight[0]
        new_checkpoint[f"{new_path}.to_q.bias"] = q_bias[0]
    else:
        [query_weight, key_weight, value_weight], [q_bias, k_bias, v_bias] = weights, bias
        new_checkpoint[f"{new_path}.to_q.weight"] = query_weight
        new_checkpoint[f"{new_path}.to_q.bias"] = q_bias
        new_checkpoint[f"{new_path}.to_k.weight"] = key_weight
        new_checkpoint[f"{new_path}.to_k.bias"] = k_bias
        new_checkpoint[f"{new_path}.to_v.weight"] = value_weight
        new_checkpoint[f"{new_path}.to_v.bias"] = v_bias

    # Encoder (cross-attention) keys/values are fused as a k/v pair.
    encoder_kv_weight = unet_state_dict.pop(f"{old_path}.encoder_kv.weight")
    encoder_kv_weight = encoder_kv_weight[:, :, 0]

    encoder_kv_bias = unet_state_dict.pop(f"{old_path}.encoder_kv.bias")

    [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
        weight=encoder_kv_weight,
        bias=encoder_kv_bias,
        split=2,
        chunk_size=config["attention_head_dim"],
    )

    new_checkpoint[f"{new_path}.add_k_proj.weight"] = encoder_k_weight
    new_checkpoint[f"{new_path}.add_k_proj.bias"] = encoder_k_bias
    new_checkpoint[f"{new_path}.add_v_proj.weight"] = encoder_v_weight
    new_checkpoint[f"{new_path}.add_v_proj.bias"] = encoder_v_bias
def assign_to_checkpoint(paths, checkpoint, old_checkpoint, additional_replacements=None, config=None):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Renames applied to every path, regardless of caller-provided replacements.
    global_renames = (
        ("middle_block.0", "mid_block.resnets.0"),
        ("middle_block.1", "mid_block.attentions.0"),
        ("middle_block.2", "mid_block.resnets.1"),
    )

    for path in paths:
        new_path = path["new"]

        for source, target in global_renames:
            new_path = new_path.replace(source, target)

        for replacement in additional_replacements or []:
            new_path = new_path.replace(replacement["old"], replacement["new"])

        weight = old_checkpoint[path["old"]]
        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path or "to_out.0.weight" in new_path:
            weight = weight[:, :, 0]
        checkpoint[new_path] = weight
# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
|
| 678 |
+
def split_attentions(*, weight, bias, split, chunk_size):
    """Split fused attention weights/biases into ``split`` interleaved groups.

    The original checkpoint interleaves the projections (q/k/v, or k/v, or just
    q) in ``chunk_size``-row blocks: row-chunk ``i`` belongs to group
    ``i % split``.

    Args:
        weight: 2-D tensor whose first dimension is iterated in ``chunk_size`` steps.
        bias: 1-D tensor aligned with ``weight``'s first dimension.
        split: number of output groups (3 for qkv, 2 for encoder kv, 1 for q-only).
        chunk_size: number of rows per interleaved block (the attention head dim).

    Returns:
        ``(weights, biases)`` — two lists of ``split`` tensors each.
    """
    # Collect each group's row chunks first and concatenate once per group,
    # instead of the previous pattern of repeated torch.concat calls inside the
    # loop (quadratic copying) with a redundant not-None assert.
    weight_chunks = [[] for _ in range(split)]
    bias_chunks = [[] for _ in range(split)]

    for chunk_index, start in enumerate(range(0, weight.shape[0], chunk_size)):
        group = chunk_index % split
        weight_chunks[group].append(weight[start : start + chunk_size, :])
        bias_chunks[group].append(bias[start : start + chunk_size])

    weights = [torch.concat(chunks) for chunks in weight_chunks]
    biases = [torch.concat(chunks) for chunks in bias_chunks]

    return weights, biases
def parse_list(value):
    """Coerce a comma-separated string into a list of ints; pass lists through.

    Raises:
        ValueError: for any input that is neither a str nor a list.
    """
    if isinstance(value, list):
        return value
    if isinstance(value, str):
        return [int(v) for v in value.split(",")]
    raise ValueError(f"Can't parse list for type: {type(value)}")
| 715 |
+
# below is copy and pasted from original convert_if_stage_2.py script
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None):
    """Build a diffusers UNet for an IF super-resolution stage from an original checkpoint directory.

    Args:
        unet_checkpoint_path: directory containing ``config.yml`` and
            ``pytorch_model.bin``.
        verify_param_count: when True, cross-check the converted architecture's
            parameter count against the original (slow).
        sample_size: optional override for the UNet sample size; the stage-3
            config misreports it, so callers pass the true value explicitly.

    Returns:
        A ``UNet2DConditionModel`` loaded with the converted weights.
    """
    orig_path = unet_checkpoint_path

    # BUGFIX: previously the *path string* itself was passed to yaml.safe_load,
    # which parses the string as a YAML scalar (yielding the path, not the
    # config), so the subsequent ["params"] lookup could never work. Read the
    # file's contents instead.
    with open(os.path.join(orig_path, "config.yml")) as f:
        original_unet_config = yaml.safe_load(f)
    original_unet_config = original_unet_config["params"]

    unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)
    unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * int(
        original_unet_config["channel_mult"].split(",")[-1]
    )
    if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]:
        unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
        unet_diffusers_config["class_embed_type"] = "timestep"
        unet_diffusers_config["addition_embed_type"] = "text"

    unet_diffusers_config["time_embedding_act_fn"] = "gelu"
    unet_diffusers_config["resnet_skip_time_act"] = True
    unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
    unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071
    unet_diffusers_config["only_cross_attention"] = (
        bool(original_unet_config["disable_self_attentions"])
        if (
            "disable_self_attentions" in original_unet_config
            and isinstance(original_unet_config["disable_self_attentions"], int)
        )
        else True
    )

    if sample_size is None:
        unet_diffusers_config["sample_size"] = original_unet_config["image_size"]
    else:
        # The second upscaler unet's sample size is incorrectly specified
        # in the config and is instead hardcoded in source
        unet_diffusers_config["sample_size"] = sample_size

    unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu")

    if verify_param_count:
        # check that architecture matches - is a bit slow
        # BUGFIX: the boolean parameter shadows the module-level helper of the
        # same name, so calling the local name raised TypeError ("'bool' object
        # is not callable"). Fetch the function from the module globals instead.
        globals()["verify_param_count"](orig_path, unet_diffusers_config)

    converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint(
        unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
    )
    converted_keys = converted_unet_checkpoint.keys()

    model = UNet2DConditionModel(**unet_diffusers_config)
    expected_weights = model.state_dict().keys()

    # Surface missing/extra keys explicitly before load_state_dict.
    diff_c_e = set(converted_keys) - set(expected_weights)
    diff_e_c = set(expected_weights) - set(converted_keys)

    assert len(diff_e_c) == 0, f"Expected, but not converted: {diff_e_c}"
    assert len(diff_c_e) == 0, f"Converted, but not expected: {diff_c_e}"

    model.load_state_dict(converted_unet_checkpoint)

    return model
| 777 |
+
|
| 778 |
+
def superres_create_unet_diffusers_config(original_unet_config):
    """Translate a DeepFloyd-IF super-resolution unet config into kwargs for
    diffusers' ``UNet2DConditionModel``.

    Block types are chosen per resolution level: cross-attention blocks where the
    level's downsampling factor appears in ``attention_resolutions``, otherwise
    resnet-style up/downsample blocks (or plain blocks when ``resblock_updown``
    is false).
    """
    image_size = original_unet_config["image_size"]
    attn_resolutions = [image_size // int(res) for res in parse_list(original_unet_config["attention_resolutions"])]

    block_out_channels = [
        original_unet_config["model_channels"] * mult for mult in parse_list(original_unet_config["channel_mult"])
    ]
    n_levels = len(block_out_channels)
    resblock_updown = original_unet_config["resblock_updown"]

    # Descend through the unet: the downsampling factor doubles after every
    # level except the innermost one.
    down_block_types = []
    resolution = 1
    for level in range(n_levels):
        if resolution in attn_resolutions:
            down_block_types.append("SimpleCrossAttnDownBlock2D")
        elif resblock_updown:
            down_block_types.append("ResnetDownsampleBlock2D")
        else:
            down_block_types.append("DownBlock2D")
        if level < n_levels - 1:
            resolution *= 2

    # Ascend back out, halving the factor after every level (starts from the
    # value left by the descent, mirroring the original layout).
    up_block_types = []
    for _ in range(n_levels):
        if resolution in attn_resolutions:
            up_block_types.append("SimpleCrossAttnUpBlock2D")
        elif resblock_updown:
            up_block_types.append("ResnetUpsampleBlock2D")
        else:
            up_block_types.append("UpBlock2D")
        resolution //= 2

    head_dim = original_unet_config["num_head_channels"]
    use_linear_projection = original_unet_config.get("use_linear_in_transformer", False)
    if use_linear_projection and head_dim is None:
        # stable diffusion 2-base-512 and 2-768
        head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    if "num_classes" in original_unet_config:
        num_classes = original_unet_config["num_classes"]
        if num_classes != "sequential":
            raise NotImplementedError(f"Unknown conditional unet num_classes config: {num_classes}")
        class_embed_type = "projection"
        assert "adm_in_channels" in original_unet_config
        projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]

    config = {
        "in_channels": original_unet_config["in_channels"],
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": tuple(original_unet_config["num_res_blocks"]),
        "cross_attention_dim": original_unet_config["encoder_channels"],
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "out_channels": original_unet_config["out_channels"],
        "up_block_types": tuple(up_block_types),
        "upcast_attention": False,  # TODO: guessing
        "cross_attention_norm": "group_norm",
        "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
        "act_fn": "gelu",
    }

    if original_unet_config["use_scale_shift_norm"]:
        config["resnet_time_scale_shift"] = "scale_shift"

    return config
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
def superres_convert_ldm_unet_checkpoint(unet_state_dict, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Remaps the original LDM-style key layout (``input_blocks`` / ``middle_block`` /
    ``output_blocks``) onto the diffusers ``UNet2DConditionModel`` layout
    (``down_blocks`` / ``mid_block`` / ``up_blocks``).
    # NOTE(review): `path` and `extract_ema` are unused in this body — presumably
    # kept for signature parity with sibling converters; confirm before removing.
    """
    new_checkpoint = {}

    # Time embedding MLP: time_embed.{0,2} -> time_embedding.linear_{1,2}.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        # Augmentation-level projection MLP becomes the class embedding.
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["aug_proj.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["aug_proj.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["aug_proj.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["aug_proj.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    # Optional text-encoder hidden-state projection.
    if "encoder_proj.weight" in unet_state_dict:
        new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict["encoder_proj.weight"]
        new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict["encoder_proj.bias"]

    # Optional pooled-text embedding stack (norm -> pool -> proj -> norm).
    if "encoder_pooling.0.weight" in unet_state_dict:
        mapping = {
            "encoder_pooling.0": "add_embedding.norm1",
            "encoder_pooling.1": "add_embedding.pool",
            "encoder_pooling.2": "add_embedding.proj",
            "encoder_pooling.3": "add_embedding.norm2",
        }
        for key in unet_state_dict.keys():
            if key.startswith("encoder_pooling"):
                # prefix is e.g. "encoder_pooling.0" (fixed width matches the mapping keys)
                prefix = key[: len("encoder_pooling.0")]
                new_key = key.replace(prefix, mapping[prefix])
                new_checkpoint[new_key] = unet_state_dict[key]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }
    # Per-level layer counts: when layers_per_block varies per level, the
    # cumulative sums mark the flat input-block index where each level ends
    # (i.e. where its downsampler sits).
    if not isinstance(config["layers_per_block"], int):
        layers_per_block_list = [e + 1 for e in config["layers_per_block"]]
        layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
        downsampler_ids = layers_per_block_cumsum
    else:
        # TODO need better check than i in [4, 8, 12, 16]
        downsampler_ids = [4, 8, 12, 16]

    # input_blocks.0 is conv_in (handled above); start at 1.
    for i in range(1, num_input_blocks):
        if isinstance(config["layers_per_block"], int):
            layers_per_block = config["layers_per_block"]
            block_id = (i - 1) // (layers_per_block + 1)
            layer_in_block_id = (i - 1) % (layers_per_block + 1)
        else:
            # First level whose cumulative count exceeds (i - 1).
            block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if (i - 1) < n)
            passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
            layer_in_block_id = (i - 1) - passed_blocks

        # ".0" submodule = resnet (excluding the ".0.op" conv downsampler), ".1" = attention.
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            # pop() so the op conv isn't re-matched by later path renaming.
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)

        block_type = config["down_block_types"][block_id]
        # Resnet-style downsample blocks keep a resnet *as* the downsampler at
        # the level boundary; map those to downsamplers.0, not resnets.{n}.
        if (
            block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D"
        ) and i in downsampler_ids:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
        else:
            meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}

        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            old_path = f"input_blocks.{i}.1"
            new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"

            assign_attention_to_checkpoint(
                new_checkpoint=new_checkpoint,
                unet_state_dict=unet_state_dict,
                old_path=old_path,
                new_path=new_path,
                config=config,
            )

            paths = renew_attention_paths(attentions)
            meta_path = {"old": old_path, "new": new_path}
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )

    # Middle block is always resnet / attention / resnet.
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    old_path = "middle_block.1"
    new_path = "mid_block.attentions.0"

    assign_attention_to_checkpoint(
        new_checkpoint=new_checkpoint,
        unet_state_dict=unet_state_dict,
        old_path=old_path,
        new_path=new_path,
        config=config,
    )

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )
    # Up path mirrors the down path, so the per-level counts are reversed.
    if not isinstance(config["layers_per_block"], int):
        layers_per_block_list = list(reversed([e + 1 for e in config["layers_per_block"]]))
        layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))

    for i in range(num_output_blocks):
        if isinstance(config["layers_per_block"], int):
            layers_per_block = config["layers_per_block"]
            block_id = i // (layers_per_block + 1)
            layer_in_block_id = i % (layers_per_block + 1)
        else:
            block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if i < n)
            passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
            layer_in_block_id = i - passed_blocks

        # Group this output block's keys by submodule index ("0", "1", "2").
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        # len(output_block_list) == 1 -> resnet
        # len(output_block_list) == 2 -> resnet, attention or resnet, upscale resnet
        # len(output_block_list) == 3 -> resnet, attention, upscale resnet

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]

            # Submodule "1" is an attention unless its keys look like a resnet
            # ("in_layers" only exists in resnets).
            has_attention = True
            if len(output_block_list) == 2 and any("in_layers" in k for k in output_block_list["1"]):
                has_attention = False

            maybe_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}

            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # A bare conv submodule is a plain (non-resnet) upsampler.
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # this layer was no attention
                has_attention = False
                maybe_attentions = []

            if has_attention:
                old_path = f"output_blocks.{i}.1"
                new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"

                assign_attention_to_checkpoint(
                    new_checkpoint=new_checkpoint,
                    unet_state_dict=unet_state_dict,
                    old_path=old_path,
                    new_path=new_path,
                    config=config,
                )

                paths = renew_attention_paths(maybe_attentions)
                meta_path = {
                    "old": old_path,
                    "new": new_path,
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )

            # Last submodule may be a resnet-style upsampler.
            if len(output_block_list) == 3 or (not has_attention and len(maybe_attentions) > 0):
                layer_id = len(output_block_list) - 1
                resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.{layer_id}" in key]
                paths = renew_resnet_paths(resnets)
                meta_path = {"old": f"output_blocks.{i}.{layer_id}", "new": f"up_blocks.{block_id}.upsamplers.0"}
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Single-submodule output block: copy resnet params key-by-key.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
def verify_param_count(orig_path, unet_diffusers_config):
    """Cross-check the converted diffusers config against the original DeepFloyd-IF
    model by comparing parameter counts block-by-block.

    Raises AssertionError on any count mismatch, ValueError if ``orig_path`` does
    not identify the stage (must contain ``-II-`` or ``-III-``).
    """
    # Imports are local because deepfloyd_if is an optional, heavy dependency
    # only needed for this verification step.
    if "-II-" in orig_path:
        from deepfloyd_if.modules import IFStageII

        if_stage = IFStageII(device="cpu", dir_or_name=orig_path)
    elif "-III-" in orig_path:
        from deepfloyd_if.modules import IFStageIII

        if_stage = IFStageIII(device="cpu", dir_or_name=orig_path)
    else:
        # BUG FIX: this was `assert f"..."`, which asserts a non-empty (always
        # truthy) string and never fires, then crashed below with a NameError.
        # Fail loudly with the intended message instead.
        raise ValueError(f"Weird name. Should have -II- or -III- in path: {orig_path}")

    unet = UNet2DConditionModel(**unet_diffusers_config)

    # in params
    assert_param_count(unet.time_embedding, if_stage.model.time_embed)
    assert_param_count(unet.conv_in, if_stage.model.input_blocks[:1])

    # downblocks — the input_blocks slices below encode each stage's layout.
    assert_param_count(unet.down_blocks[0], if_stage.model.input_blocks[1:4])
    assert_param_count(unet.down_blocks[1], if_stage.model.input_blocks[4:7])
    assert_param_count(unet.down_blocks[2], if_stage.model.input_blocks[7:11])

    if "-II-" in orig_path:
        assert_param_count(unet.down_blocks[3], if_stage.model.input_blocks[11:17])
        assert_param_count(unet.down_blocks[4], if_stage.model.input_blocks[17:])
    if "-III-" in orig_path:
        assert_param_count(unet.down_blocks[3], if_stage.model.input_blocks[11:15])
        assert_param_count(unet.down_blocks[4], if_stage.model.input_blocks[15:20])
        assert_param_count(unet.down_blocks[5], if_stage.model.input_blocks[20:])

    # mid block
    assert_param_count(unet.mid_block, if_stage.model.middle_block)

    # up block
    if "-II-" in orig_path:
        assert_param_count(unet.up_blocks[0], if_stage.model.output_blocks[:6])
        assert_param_count(unet.up_blocks[1], if_stage.model.output_blocks[6:12])
        assert_param_count(unet.up_blocks[2], if_stage.model.output_blocks[12:16])
        assert_param_count(unet.up_blocks[3], if_stage.model.output_blocks[16:19])
        assert_param_count(unet.up_blocks[4], if_stage.model.output_blocks[19:])
    if "-III-" in orig_path:
        assert_param_count(unet.up_blocks[0], if_stage.model.output_blocks[:5])
        assert_param_count(unet.up_blocks[1], if_stage.model.output_blocks[5:10])
        assert_param_count(unet.up_blocks[2], if_stage.model.output_blocks[10:14])
        assert_param_count(unet.up_blocks[3], if_stage.model.output_blocks[14:18])
        assert_param_count(unet.up_blocks[4], if_stage.model.output_blocks[18:21])
        assert_param_count(unet.up_blocks[5], if_stage.model.output_blocks[21:24])

    # out params
    assert_param_count(unet.conv_norm_out, if_stage.model.out[0])
    assert_param_count(unet.conv_out, if_stage.model.out[2])

    # make sure all model architecture has same param count
    assert_param_count(unet, if_stage.model)
|
| 1174 |
+
|
| 1175 |
+
|
| 1176 |
+
def assert_param_count(model_1, model_2):
    """Assert that two modules hold the same total number of parameters."""
    totals = []
    for module in (model_1, model_2):
        n = 0
        for param in module.parameters():
            n += param.numel()
        totals.append(n)
    assert totals[0] == totals[1], f"{model_1.__class__}: {totals[0]} != {model_2.__class__}: {totals[1]}"
|
| 1180 |
+
|
| 1181 |
+
|
| 1182 |
+
def superres_check_against_original(dump_path, unet_checkpoint_path):
    """Numerical smoke test: run the converted diffusers unet and the original
    DeepFloyd-IF stage model on identical random inputs and print the absolute
    difference of their outputs.

    Requires CUDA and the `deepfloyd_if` package; purely diagnostic (prints,
    returns nothing).
    """
    model_path = dump_path
    model = UNet2DConditionModel.from_pretrained(model_path)
    model.to("cuda")
    orig_path = unet_checkpoint_path

    if "-II-" in orig_path:
        from deepfloyd_if.modules import IFStageII

        if_II_model = IFStageII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
    elif "-III-" in orig_path:
        from deepfloyd_if.modules import IFStageIII

        if_II_model = IFStageIII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
    # NOTE(review): no else branch — a path without -II-/-III- hits a NameError
    # on `if_II_model` below.

    batch_size = 1
    # Super-res unets take noisy latents concatenated with the upscaled
    # conditioning image, hence half of in_channels per input.
    channels = model.config.in_channels // 2
    height = model.config.sample_size
    width = model.config.sample_size
    # Deliberately overrides the config sample size for the comparison run.
    height = 1024
    width = 1024

    # Fixed seed so both models see identical random inputs.
    torch.manual_seed(0)

    latents = torch.randn((batch_size, channels, height, width), device=model.device)
    image_small = torch.randn((batch_size, channels, height // 4, width // 4), device=model.device)

    # `antialias` only exists on newer torch versions; pass it conditionally.
    interpolate_antialias = {}
    if "antialias" in inspect.signature(F.interpolate).parameters:
        interpolate_antialias["antialias"] = True
    image_upscaled = F.interpolate(
        image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
    )

    latent_model_input = torch.cat([latents, image_upscaled], dim=1).to(model.dtype)
    t = torch.tensor([5], device=model.device).to(model.dtype)

    seq_len = 64
    encoder_hidden_states = torch.randn((batch_size, seq_len, model.config.encoder_hid_dim), device=model.device).to(
        model.dtype
    )

    fake_class_labels = torch.tensor([t], device=model.device).to(model.dtype)

    with torch.no_grad():
        out = if_II_model(latent_model_input, t, aug_steps=fake_class_labels, text_emb=encoder_hidden_states)

    # Free the original model before running the converted one so both fit on
    # a single GPU.
    if_II_model.to("cpu")
    del if_II_model
    import gc

    torch.cuda.empty_cache()
    gc.collect()
    print(50 * "=")

    with torch.no_grad():
        noise_pred = model(
            sample=latent_model_input,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=fake_class_labels,
            timestep=t,
        ).sample

    print("Out shape", noise_pred.shape)
    print("Diff", (out - noise_pred).abs().sum())
|
| 1247 |
+
|
| 1248 |
+
|
| 1249 |
+
if __name__ == "__main__":
    # CLI entry point: parse command-line arguments and run the conversion.
    main(parse_args())
|
diffusers/scripts/convert_lora_safetensor_to_diffusers.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2024, Haofan Wang, Qixun Wang, All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Conversion script for the LoRA's safetensors checkpoints."""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from safetensors.torch import load_file
|
| 22 |
+
|
| 23 |
+
from diffusers import StableDiffusionPipeline
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def convert(base_model_path, checkpoint_path, LORA_PREFIX_UNET, LORA_PREFIX_TEXT_ENCODER, alpha):
    """Merge a LoRA ``.safetensors`` checkpoint directly into the weights of a
    diffusers ``StableDiffusionPipeline`` and return the updated pipeline.

    For each up/down pair the target layer's weight is updated in place as
    ``W += alpha * (W_up @ W_down)``.
    """
    # load base model
    pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)

    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path)

    merged = set()

    # directly update weight in diffusers model
    for key in state_dict:
        # keys look like e.g.
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        # as we have set the alpha beforehand, so just skip
        if ".alpha" in key or key in merged:
            continue

        if "text" in key:
            segments = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            target = pipeline.text_encoder
        else:
            segments = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            target = pipeline.unet

        # Walk down the module tree. Attribute names may themselves contain
        # "_", so on a failed lookup we greedily join segments until getattr
        # succeeds.
        name = segments.pop(0)
        while True:
            try:
                target = target.__getattr__(name)
            except Exception:
                if len(name) > 0:
                    name += "_" + segments.pop(0)
                else:
                    name = segments.pop(0)
                continue
            if not segments:
                break
            name = segments.pop(0)

        # Derive the matching up/down key pair from whichever half we hit first.
        if "lora_down" in key:
            up_key = key.replace("lora_down", "lora_up")
            down_key = key
        else:
            up_key = key
            down_key = key.replace("lora_up", "lora_down")

        weight_up = state_dict[up_key].to(torch.float32)
        weight_down = state_dict[down_key].to(torch.float32)

        # update weight (conv LoRAs carry trailing 1x1 dims that must be
        # squeezed before the matmul and restored afterwards)
        if len(weight_up.shape) == 4:
            delta = torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2))
            target.weight.data += alpha * delta.unsqueeze(2).unsqueeze(3)
        else:
            target.weight.data += alpha * torch.mm(weight_up, weight_down)

        # update visited list
        merged.add(up_key)
        merged.add(down_key)

    return pipeline
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
    # CLI wrapper: merge a LoRA safetensors checkpoint into a base diffusers
    # pipeline and save the result.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
    )
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
    )
    parser.add_argument(
        "--lora_prefix_text_encoder",
        default="lora_te",
        type=str,
        help="The prefix of text encoder weight in safetensors",
    )
    parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
    parser.add_argument(
        "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
    )
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")

    args = parser.parse_args()

    base_model_path = args.base_model_path
    checkpoint_path = args.checkpoint_path
    dump_path = args.dump_path
    lora_prefix_unet = args.lora_prefix_unet
    lora_prefix_text_encoder = args.lora_prefix_text_encoder
    alpha = args.alpha

    # Merge on CPU/float32, then move to the requested device before saving.
    pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)

    pipe = pipe.to(args.device)
    pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
diffusers/scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Conversion script for the NCSNPP checkpoints."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def convert_ncsnpp_checkpoint(checkpoint, config):
    """
    Convert an original NCSNPP ``all_modules``-indexed state dict into the weights of a
    diffusers ``UNet2DModel``.

    The original checkpoint stores every layer under a flat ``all_modules.<i>.*`` prefix;
    this function walks the freshly built ``UNet2DModel`` in the same order the original
    model enumerated its modules, advancing ``module_index`` in lock-step, and copies each
    tensor into place. It finally returns the populated model's state dict.

    Args:
        checkpoint: flat original state dict keyed by ``all_modules.<i>.<param>``.
        config: kwargs dict used to instantiate ``UNet2DModel``.

    Returns:
        The ``state_dict()`` of the populated ``UNet2DModel``.
    """
    new_model_architecture = UNet2DModel(**config)
    # Gaussian Fourier time projection: the original stores its frozen weights as "W".
    # Both the legacy ``W`` attribute and ``weight`` are filled from the same tensor.
    new_model_architecture.time_proj.W.data = checkpoint["all_modules.0.W"].data
    new_model_architecture.time_proj.weight.data = checkpoint["all_modules.0.W"].data
    new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
    new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data

    new_model_architecture.time_embedding.linear_2.weight.data = checkpoint["all_modules.2.weight"].data
    new_model_architecture.time_embedding.linear_2.bias.data = checkpoint["all_modules.2.bias"].data

    new_model_architecture.conv_in.weight.data = checkpoint["all_modules.3.weight"].data
    new_model_architecture.conv_in.bias.data = checkpoint["all_modules.3.bias"].data

    # The final norm/conv pair sits at the very end of the (insertion-ordered) original
    # dict; they are re-assigned by index again at the bottom of this function.
    new_model_architecture.conv_norm_out.weight.data = checkpoint[list(checkpoint.keys())[-4]].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[list(checkpoint.keys())[-3]].data
    new_model_architecture.conv_out.weight.data = checkpoint[list(checkpoint.keys())[-2]].data
    new_model_architecture.conv_out.bias.data = checkpoint[list(checkpoint.keys())[-1]].data

    # Indices 0-3 were consumed above; every remaining module is visited in order.
    module_index = 4

    def set_attention_weights(new_layer, old_checkpoint, index):
        # The original attention uses 1x1 "NIN" layers whose W is stored transposed
        # relative to a torch Linear, hence the ``.T`` on every weight.
        new_layer.query.weight.data = old_checkpoint[f"all_modules.{index}.NIN_0.W"].data.T
        new_layer.key.weight.data = old_checkpoint[f"all_modules.{index}.NIN_1.W"].data.T
        new_layer.value.weight.data = old_checkpoint[f"all_modules.{index}.NIN_2.W"].data.T

        new_layer.query.bias.data = old_checkpoint[f"all_modules.{index}.NIN_0.b"].data
        new_layer.key.bias.data = old_checkpoint[f"all_modules.{index}.NIN_1.b"].data
        new_layer.value.bias.data = old_checkpoint[f"all_modules.{index}.NIN_2.b"].data

        new_layer.proj_attn.weight.data = old_checkpoint[f"all_modules.{index}.NIN_3.W"].data.T
        new_layer.proj_attn.bias.data = old_checkpoint[f"all_modules.{index}.NIN_3.b"].data

        new_layer.group_norm.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
        new_layer.group_norm.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data

    def set_resnet_weights(new_layer, old_checkpoint, index):
        # Copy one ResNet block: (norm, conv) x 2 plus the time-embedding projection.
        new_layer.conv1.weight.data = old_checkpoint[f"all_modules.{index}.Conv_0.weight"].data
        new_layer.conv1.bias.data = old_checkpoint[f"all_modules.{index}.Conv_0.bias"].data
        new_layer.norm1.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
        new_layer.norm1.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data

        new_layer.conv2.weight.data = old_checkpoint[f"all_modules.{index}.Conv_1.weight"].data
        new_layer.conv2.bias.data = old_checkpoint[f"all_modules.{index}.Conv_1.bias"].data
        new_layer.norm2.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.weight"].data
        new_layer.norm2.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.bias"].data

        new_layer.time_emb_proj.weight.data = old_checkpoint[f"all_modules.{index}.Dense_0.weight"].data
        new_layer.time_emb_proj.bias.data = old_checkpoint[f"all_modules.{index}.Dense_0.bias"].data

        # A shortcut conv only exists when channel counts differ or the block resamples.
        if new_layer.in_channels != new_layer.out_channels or new_layer.up or new_layer.down:
            new_layer.conv_shortcut.weight.data = old_checkpoint[f"all_modules.{index}.Conv_2.weight"].data
            new_layer.conv_shortcut.bias.data = old_checkpoint[f"all_modules.{index}.Conv_2.bias"].data

    # --- down path ---
    for i, block in enumerate(new_model_architecture.downsample_blocks):
        has_attentions = hasattr(block, "attentions")
        for j in range(len(block.resnets)):
            set_resnet_weights(block.resnets[j], checkpoint, module_index)
            module_index += 1
            if has_attentions:
                set_attention_weights(block.attentions[j], checkpoint, module_index)
                module_index += 1

        if hasattr(block, "downsamplers") and block.downsamplers is not None:
            set_resnet_weights(block.resnet_down, checkpoint, module_index)
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.Conv_0.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
            module_index += 1

    # --- middle block: resnet, attention, resnet ---
    set_resnet_weights(new_model_architecture.mid_block.resnets[0], checkpoint, module_index)
    module_index += 1
    set_attention_weights(new_model_architecture.mid_block.attentions[0], checkpoint, module_index)
    module_index += 1
    set_resnet_weights(new_model_architecture.mid_block.resnets[1], checkpoint, module_index)
    module_index += 1

    # --- up path ---
    for i, block in enumerate(new_model_architecture.up_blocks):
        has_attentions = hasattr(block, "attentions")
        for j in range(len(block.resnets)):
            set_resnet_weights(block.resnets[j], checkpoint, module_index)
            module_index += 1
            if has_attentions:
                set_attention_weights(
                    block.attentions[0], checkpoint, module_index
                )  # why can there only be a single attention layer for up?
                module_index += 1

        if hasattr(block, "resnet_up") and block.resnet_up is not None:
            # skip connection: norm, conv, then the upsampling resnet.
            block.skip_norm.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_norm.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
            block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
            module_index += 1
            set_resnet_weights(block.resnet_up, checkpoint, module_index)
            module_index += 1

    # Output head (same tensors already assigned by negative index above).
    new_model_architecture.conv_norm_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_norm_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
    module_index += 1
    new_model_architecture.conv_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
    new_model_architecture.conv_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data

    return new_model_architecture.state_dict()
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
if __name__ == "__main__":
    # CLI entry point: load an original NCSNPP checkpoint + config, convert the
    # weights, and save either a full ScoreSdeVePipeline or the bare UNet2DModel.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the checkpoint to convert.",
    )

    parser.add_argument(
        "--config_file",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/config.json",
        type=str,
        required=False,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument(
        "--dump_path",
        default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model_new.pt",
        type=str,
        required=False,
        help="Path to the output model.",
    )

    args = parser.parse_args()

    # NOTE(review): torch.load un-pickles arbitrary objects — only run this script
    # on checkpoints from a trusted source.
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")

    with open(args.config_file) as f:
        config = json.load(f)  # read straight from the handle instead of loads(read())

    converted_checkpoint = convert_ncsnpp_checkpoint(
        checkpoint,
        config,
    )

    # The "sde" entry configures the scheduler, not the UNet, so it must not be
    # forwarded to the UNet2DModel constructor.
    if "sde" in config:
        del config["sde"]

    model = UNet2DModel(**config)
    model.load_state_dict(converted_checkpoint)

    try:
        # The scheduler config is expected next to the checkpoint file.
        scheduler = ScoreSdeVeScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))

        pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
        pipe.save_pretrained(args.dump_path)
    except Exception:
        # No scheduler config available: fall back to saving the bare model.
        # (Was a bare ``except:``, which would also swallow KeyboardInterrupt/SystemExit.)
        model.save_pretrained(args.dump_path)
|
diffusers/scripts/convert_omnigen_to_diffusers.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from huggingface_hub import snapshot_download
|
| 6 |
+
from safetensors.torch import load_file
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
|
| 9 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def main(args):
    """
    Convert the original OmniGen checkpoint (https://huggingface.co/Shitao/OmniGen-v1)
    into a diffusers ``OmniGenPipeline`` and save it to ``args.dump_path``.

    Args:
        args: parsed CLI namespace with ``origin_ckpt_path`` (local dir or hub repo id)
            and ``dump_path`` (output directory for the converted pipeline).
    """
    # checkpoint from https://huggingface.co/Shitao/OmniGen-v1

    # If no local directory exists at the given path, treat it as a hub repo id
    # and download a snapshot (skipping non-PyTorch weight formats).
    if not os.path.exists(args.origin_ckpt_path):
        print("Model not found, downloading...")
        cache_folder = os.getenv("HF_HUB_CACHE")
        args.origin_ckpt_path = snapshot_download(
            repo_id=args.origin_ckpt_path,
            cache_dir=cache_folder,
            ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"],
        )
        print(f"Downloaded model to {args.origin_ckpt_path}")

    ckpt = os.path.join(args.origin_ckpt_path, "model.safetensors")
    ckpt = load_file(ckpt, device="cpu")

    # Direct 1:1 key renames from the original layout to the diffusers layout.
    mapping_dict = {
        "pos_embed": "patch_embedding.pos_embed",
        "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
        "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
        "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
        "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
        "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
        "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
        "final_layer.linear.weight": "proj_out.weight",
        "final_layer.linear.bias": "proj_out.bias",
        "time_token.mlp.0.weight": "time_token.linear_1.weight",
        "time_token.mlp.0.bias": "time_token.linear_1.bias",
        "time_token.mlp.2.weight": "time_token.linear_2.weight",
        "time_token.mlp.2.bias": "time_token.linear_2.bias",
        "t_embedder.mlp.0.weight": "t_embedder.linear_1.weight",
        "t_embedder.mlp.0.bias": "t_embedder.linear_1.bias",
        "t_embedder.mlp.2.weight": "t_embedder.linear_2.weight",
        "t_embedder.mlp.2.bias": "t_embedder.linear_2.bias",
        "llm.embed_tokens.weight": "embed_tokens.weight",
    }

    converted_state_dict = {}
    for k, v in ckpt.items():
        if k in mapping_dict:
            converted_state_dict[mapping_dict[k]] = v
        elif "qkv" in k:
            # Fused qkv projections are split into separate q/k/v weights;
            # k.split('.')[2] extracts the layer index from the original key.
            to_q, to_k, to_v = v.chunk(3)
            converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_q.weight"] = to_q
            converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_k.weight"] = to_k
            converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_v.weight"] = to_v
        elif "o_proj" in k:
            converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_out.0.weight"] = v
        else:
            # Remaining keys keep their name minus a 4-character prefix
            # (presumably the "llm." prefix — TODO confirm against the checkpoint).
            converted_state_dict[k[4:]] = v

    # RoPE long/short scaling factors copied verbatim from the original
    # Phi-3-based model config ("su" = short/long scaled rotary embeddings).
    transformer = OmniGenTransformer2DModel(
        rope_scaling={
            "long_factor": [
                1.0299999713897705,
                1.0499999523162842,
                1.0499999523162842,
                1.0799999237060547,
                1.2299998998641968,
                1.2299998998641968,
                1.2999999523162842,
                1.4499999284744263,
                1.5999999046325684,
                1.6499998569488525,
                1.8999998569488525,
                2.859999895095825,
                3.68999981880188,
                5.419999599456787,
                5.489999771118164,
                5.489999771118164,
                9.09000015258789,
                11.579999923706055,
                15.65999984741211,
                15.769999504089355,
                15.789999961853027,
                18.360000610351562,
                21.989999771118164,
                23.079999923706055,
                30.009998321533203,
                32.35000228881836,
                32.590003967285156,
                35.56000518798828,
                39.95000457763672,
                53.840003967285156,
                56.20000457763672,
                57.95000457763672,
                59.29000473022461,
                59.77000427246094,
                59.920005798339844,
                61.190006256103516,
                61.96000671386719,
                62.50000762939453,
                63.3700065612793,
                63.48000717163086,
                63.48000717163086,
                63.66000747680664,
                63.850006103515625,
                64.08000946044922,
                64.760009765625,
                64.80001068115234,
                64.81001281738281,
                64.81001281738281,
            ],
            "short_factor": [
                1.05,
                1.05,
                1.05,
                1.1,
                1.1,
                1.1,
                1.2500000000000002,
                1.2500000000000002,
                1.4000000000000004,
                1.4500000000000004,
                1.5500000000000005,
                1.8500000000000008,
                1.9000000000000008,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.1000000000000005,
                2.1000000000000005,
                2.2,
                2.3499999999999996,
                2.3499999999999996,
                2.3499999999999996,
                2.3499999999999996,
                2.3999999999999995,
                2.3999999999999995,
                2.6499999999999986,
                2.6999999999999984,
                2.8999999999999977,
                2.9499999999999975,
                3.049999999999997,
                3.049999999999997,
                3.049999999999997,
            ],
            "type": "su",
        },
        patch_size=2,
        in_channels=4,
        pos_embed_max_size=192,
    )
    # strict=True ensures every converted key matched the diffusers model exactly.
    transformer.load_state_dict(converted_state_dict, strict=True)
    transformer.to(torch.bfloat16)

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)

    # The VAE is kept in fp32 while the transformer runs in bf16.
    vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)

    pipeline = OmniGenPipeline(tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler)
    pipeline.save_pretrained(args.dump_path)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
if __name__ == "__main__":
    # CLI entry point: both flags are optional and default to the official hub
    # repo id and a local output folder, respectively.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--origin_ckpt_path",
        type=str,
        default="Shitao/OmniGen-v1",
        required=False,
        help="Path to the checkpoint to convert.",
    )
    parser.add_argument(
        "--dump_path",
        type=str,
        default="OmniGen-v1-diffusers",
        required=False,
        help="Path to the output pipeline.",
    )
    main(parser.parse_args())
|
diffusers/scripts/convert_original_audioldm2_to_diffusers.py
ADDED
|
@@ -0,0 +1,1135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Conversion script for the AudioLDM2 checkpoints."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import re
|
| 19 |
+
from typing import List, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import yaml
|
| 23 |
+
from transformers import (
|
| 24 |
+
AutoFeatureExtractor,
|
| 25 |
+
AutoTokenizer,
|
| 26 |
+
ClapConfig,
|
| 27 |
+
ClapModel,
|
| 28 |
+
GPT2Config,
|
| 29 |
+
GPT2Model,
|
| 30 |
+
SpeechT5HifiGan,
|
| 31 |
+
SpeechT5HifiGanConfig,
|
| 32 |
+
T5Config,
|
| 33 |
+
T5EncoderModel,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
from diffusers import (
|
| 37 |
+
AudioLDM2Pipeline,
|
| 38 |
+
AudioLDM2ProjectionModel,
|
| 39 |
+
AudioLDM2UNet2DConditionModel,
|
| 40 |
+
AutoencoderKL,
|
| 41 |
+
DDIMScheduler,
|
| 42 |
+
DPMSolverMultistepScheduler,
|
| 43 |
+
EulerAncestralDiscreteScheduler,
|
| 44 |
+
EulerDiscreteScheduler,
|
| 45 |
+
HeunDiscreteScheduler,
|
| 46 |
+
LMSDiscreteScheduler,
|
| 47 |
+
PNDMScheduler,
|
| 48 |
+
)
|
| 49 |
+
from diffusers.utils import is_safetensors_available
|
| 50 |
+
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
|
| 54 |
+
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drop dot-separated segments from *path*.

    A non-negative ``n_shave_prefix_segments`` removes that many leading
    segments; a negative value removes segments from the end instead.
    """
    segments = path.split(".")
    if n_shave_prefix_segments >= 0:
        kept = segments[n_shave_prefix_segments:]
    else:
        kept = segments[:n_shave_prefix_segments]
    return ".".join(kept)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths
|
| 65 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    # Substring renames applied sequentially, in this exact order.
    renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )
    mapping = []
    for old_item in old_list:
        new_item = old_item
        for src, dst in renames:
            new_item = new_item.replace(src, dst)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths
|
| 88 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        # Only one VAE-specific rename applies before the prefix shave.
        renamed = shave_segments(
            old_item.replace("nin_shortcut", "conv_shortcut"),
            n_shave_prefix_segments=n_shave_prefix_segments,
        )
        mapping.append({"old": old_item, "new": renamed})
    return mapping
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths
|
| 105 |
+
def renew_attention_paths(old_list):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # Attention keys already match the target layout, so old == new for each entry
    # (the historical renames here were disabled upstream).
    return [{"old": item, "new": item} for item in old_list]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # Substring renames applied sequentially, in this exact order.
    renames = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "to_q.weight"),
        ("q.bias", "to_q.bias"),
        ("k.weight", "to_k.weight"),
        ("k.bias", "to_k.bias"),
        ("v.weight", "to_v.weight"),
        ("v.bias", "to_v.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
    )
    mapping = []
    for old_item in old_list:
        renamed = old_item
        for src, dst in renames:
            renamed = renamed.replace(src, dst)
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": renamed})
    return mapping
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.

    Args:
        paths: list of dicts with "old"/"new" key names produced by the renew_* helpers.
        checkpoint: destination state dict; mutated in place.
        old_checkpoint: source state dict the weights are read from.
        attention_paths_to_split: optional mapping from a fused qkv source key to a
            {"query": ..., "key": ..., "value": ...} map of destination keys.
        additional_replacements: optional list of {"old": ..., "new": ...} substring
            replacements applied to each destination key.
        config: model config dict; only "num_head_channels" is read here, and only
            when attention_paths_to_split is given.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            # Fused tensor stacks query/key/value along dim 0.
            channels = old_tensor.shape[0] // 3

            # NOTE(review): `(-1)` is the plain int -1 (full flatten), not a 1-tuple;
            # appears intentional for the non-3D (bias) case — confirm.
            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            # Regroup per attention head before slicing out q/k/v along dim 1.
            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def conv_attn_to_linear(checkpoint):
    """
    Squeezes converted VAE attention projection weights from conv shape
    (out, in, 1, 1) down to linear shape (out, in), in place.
    """
    qkv_suffixes = ("to_q.weight", "to_k.weight", "to_v.weight")
    proj_suffix = "to_out.0.weight"

    for name in list(checkpoint.keys()):
        parts = name.split(".")
        is_qkv_weight = ".".join(parts[-2:]) in qkv_suffixes
        is_proj_weight = ".".join(parts[-3:]) == proj_suffix
        # Only touch tensors that still carry trailing conv dimensions.
        if (is_qkv_weight or is_proj_weight) and checkpoint[name].ndim > 2:
            checkpoint[name] = checkpoint[name].squeeze()
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def create_unet_diffusers_config(original_config, image_size: int):
    """
    Creates a UNet config for diffusers based on the config of the original AudioLDM2 model.
    """
    unet_params = original_config["model"]["params"]["unet_config"]["params"]
    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]

    base_channels = unet_params["model_channels"]
    block_out_channels = [base_channels * mult for mult in unet_params["channel_mult"]]
    attention_resolutions = unet_params["attention_resolutions"]

    # Walk down the encoder: a block gets cross-attention when its resolution
    # divisor appears in attention_resolutions. The divisor doubles after every
    # block except the last.
    down_block_types = []
    resolution = 1
    last_index = len(block_out_channels) - 1
    for index in range(len(block_out_channels)):
        if resolution in attention_resolutions:
            down_block_types.append("CrossAttnDownBlock2D")
        else:
            down_block_types.append("DownBlock2D")
        if index != last_index:
            resolution *= 2

    # Mirror the walk back up for the decoder, halving the divisor each step.
    up_block_types = []
    for _ in block_out_channels:
        if resolution in attention_resolutions:
            up_block_types.append("CrossAttnUpBlock2D")
        else:
            up_block_types.append("UpBlock2D")
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

    if "context_dim" in unet_params:
        cross_attention_dim = list(unet_params["context_dim"])
    else:
        cross_attention_dim = block_out_channels
    if len(cross_attention_dim) > 1:
        # require two or more cross-attention layers per-block, each of different dimension
        cross_attention_dim = [cross_attention_dim for _ in range(len(block_out_channels))]

    return {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params["in_channels"],
        "out_channels": unet_params["out_channels"],
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params["num_res_blocks"],
        "transformer_layers_per_block": unet_params["transformer_depth"],
        "cross_attention_dim": tuple(cross_attention_dim),
    }
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config
|
| 258 |
+
def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
    """
    Creates a VAE config for diffusers based on the config of the original AudioLDM2 model. Compared to the original
    Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
    """
    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
    # Accessed (and discarded) to keep the original KeyError behaviour when embed_dim is missing.
    _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]

    num_blocks = len(vae_params["ch_mult"])
    block_out_channels = tuple(vae_params["ch"] * mult for mult in vae_params["ch_mult"])

    # Use the learnt scale factor from the checkpoint when the model was trained
    # with scale_by_std; otherwise fall back to the SD default.
    if "scale_by_std" in original_config["model"]["params"]:
        scaling_factor = checkpoint["scale_factor"]
    else:
        scaling_factor = 0.18215

    return {
        "sample_size": image_size,
        "in_channels": vae_params["in_channels"],
        "out_channels": vae_params["out_ch"],
        "down_block_types": ("DownEncoderBlock2D",) * num_blocks,
        "up_block_types": ("UpDecoderBlock2D",) * num_blocks,
        "block_out_channels": block_out_channels,
        "latent_channels": vae_params["z_channels"],
        "layers_per_block": vae_params["num_res_blocks"],
        "scaling_factor": float(scaling_factor),
    }
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
|
| 287 |
+
def create_diffusers_schedular(original_config):
    """
    Builds a DDIMScheduler from the diffusion hyperparameters of the original config.
    """
    diffusion_params = original_config["model"]["params"]
    return DDIMScheduler(
        num_train_timesteps=diffusion_params["timesteps"],
        beta_start=diffusion_params["linear_start"],
        beta_end=diffusion_params["linear_end"],
        beta_schedule="scaled_linear",
    )
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted UNet checkpoint.

    Args:
        checkpoint: full original state dict; UNet keys (prefixed with
            "model.diffusion_model.") are popped from it as they are converted.
        config: diffusers UNet config as produced by create_unet_diffusers_config.
        path: checkpoint path, used only in log messages.
        extract_ema: when True and the checkpoint carries EMA weights, convert the
            EMA copy instead of the training weights.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."
    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                # EMA keys are stored "flat": the dots of the original name are removed.
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

    # strip the unet prefix from the weight names
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # Time embedding: two linear layers.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    # Input and output convolutions.
    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
        for layer_id in range(num_output_blocks)
    }

    # Check how many Transformer blocks we have per layer
    if isinstance(config.get("cross_attention_dim"), (list, tuple)):
        if isinstance(config["cross_attention_dim"][0], (list, tuple)):
            # in this case we have multiple cross-attention layers per-block
            num_attention_layers = len(config.get("cross_attention_dim")[0])
        else:
            num_attention_layers = 1

    if config.get("extra_self_attn_layer"):
        num_attention_layers += 1

    # Down blocks: input_blocks.0 is conv_in (handled above), so start at 1.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        # Sub-module 0 is the resnet; ".op" keys are the downsampler conv.
        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.0" not in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            # One replacement per attention layer in this block.
            meta_path = [
                {
                    "old": f"input_blocks.{i}.{1 + layer_id}",
                    "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id * num_attention_layers + layer_id}",
                }
                for layer_id in range(num_attention_layers)
            ]
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=meta_path, config=config
            )

    # Mid block: first and last middle blocks are resnets, everything between is attention.
    resnet_0 = middle_blocks[0]
    resnet_1 = middle_blocks[num_middle_blocks - 1]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"}
    assign_to_checkpoint(
        resnet_0_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    resnet_1_paths = renew_resnet_paths(resnet_1)
    meta_path = {"old": f"middle_block.{len(middle_blocks) - 1}", "new": "mid_block.resnets.1"}
    assign_to_checkpoint(
        resnet_1_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(1, num_middle_blocks - 1):
        attentions = middle_blocks[i]
        attentions_paths = renew_attention_paths(attentions)
        meta_path = {"old": f"middle_block.{i}", "new": f"mid_block.attentions.{i - 1}"}
        assign_to_checkpoint(
            attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

    # Up blocks.
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group the block's parameter names by sub-module index.
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.0" not in key]

            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            # A sub-module containing exactly ["conv.bias", "conv.weight"] is the upsampler.
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                attentions.remove(f"output_blocks.{i}.{index}.conv.bias")
                attentions.remove(f"output_blocks.{i}.{index}.conv.weight")

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = [
                    {
                        "old": f"output_blocks.{i}.{1 + layer_id}",
                        "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id * num_attention_layers + layer_id}",
                    }
                    for layer_id in range(num_attention_layers)
                ]
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=meta_path, config=config
                )
        else:
            # Block consists of a single resnet: rename its parameters directly.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Takes the full original state dict and a diffusers VAE config, and returns the
    VAE weights ("first_stage_model." keys) renamed to the diffusers layout.
    """
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    # Encoder input/output convolutions and final norm.
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    # Decoder input/output convolutions and final norm.
    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    # NOTE(review): the substring match f"down.{layer_id}" would also match
    # two-digit block ids (e.g. "down.1" matches "down.10") — fine for the
    # block counts used here, but confirm if ever extended.
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    # Encoder down blocks: resnets plus optional downsampler conv.
    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid block: two resnets and one attention.
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    # Decoder up blocks: stored in reverse order relative to diffusers.
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid block: two resnets and one attention.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
# Substring renames applied to every CLAP parameter name when converting from the
# original (open_clip-style) layout to the transformers layout; used by
# substring replacement, so order of application follows dict insertion order.
CLAP_KEYS_TO_MODIFY_MAPPING = {
    "text_branch": "text_model",
    "audio_branch": "audio_model.audio_encoder",
    "attn": "attention.self",
    "self.proj": "output.dense",
    "attention.self_mask": "attn_mask",
    "mlp.fc1": "intermediate.dense",
    "mlp.fc2": "output.dense",
    "norm1": "layernorm_before",
    "norm2": "layernorm_after",
    "bn0": "batch_norm",
}

# Any key containing one of these substrings is dropped during conversion
# (feature-extraction and classification-head weights not needed downstream).
CLAP_KEYS_TO_IGNORE = [
    "text_transform",
    "audio_transform",
    "stft",
    "logmel_extractor",
    "tscam_conv",
    "head",
    "attn_mask",
]

# NOTE(review): presumably the keys allowed to be missing when loading the
# converted weights into the transformers model — confirm at the load site.
CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"]
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def convert_open_clap_checkpoint(checkpoint):
    """
    Takes a state dict and returns a converted CLAP checkpoint.

    Strips the "clap.model." prefix, applies the CLAP_KEYS_TO_MODIFY_MAPPING
    renames, drops keys matched by CLAP_KEYS_TO_IGNORE, remaps sequential /
    projection layer indices to the transformers layout, and splits fused
    audio-branch qkv weights into separate query/key/value tensors.
    """
    # extract state dict for CLAP text embedding model, discarding the audio component
    model_state_dict = {}
    model_key = "clap.model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(model_key):
            model_state_dict[key.replace(model_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    sequential_layers_pattern = r".*sequential.(\d+).*"
    text_projection_pattern = r".*_projection.(\d+).*"

    for key, value in model_state_dict.items():
        # check if key should be ignored in mapping - if so map it to a key name that we'll filter out at the end
        for key_to_ignore in CLAP_KEYS_TO_IGNORE:
            if key_to_ignore in key:
                key = "spectrogram"

        # check if any key needs to be modified
        for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        sequential_match = re.match(sequential_layers_pattern, key)
        if sequential_match:
            # replace sequential layers with list; every third index starts a new layer
            sequential_layer = sequential_match.group(1)
            key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
        else:
            projection_match = re.match(text_projection_pattern, key)
            if projection_match:
                projecton_layer = int(projection_match.group(1))

                # Because in CLAP they use `nn.Sequential`...
                transformers_projection_layer = 1 if projecton_layer == 0 else 2

                key = key.replace(
                    f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}."
                )

        # Fix: the original condition was `if "audio" and "qkv" in key`, where the
        # string literal "audio" is always truthy — the intended membership test
        # on the audio branch never ran. Only audio-branch attention stores qkv fused.
        if "audio" in key and "qkv" in key:
            # split qkv into query key and value
            mixed_qkv = value
            qkv_dim = mixed_qkv.size(0) // 3

            query_layer = mixed_qkv[:qkv_dim]
            key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
            value_layer = mixed_qkv[qkv_dim * 2 :]

            new_checkpoint[key.replace("qkv", "query")] = query_layer
            new_checkpoint[key.replace("qkv", "key")] = key_layer
            new_checkpoint[key.replace("qkv", "value")] = value_layer
        elif key != "spectrogram":
            new_checkpoint[key] = value

    return new_checkpoint
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
def create_transformers_vocoder_config(original_config):
    """
    Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
    """
    vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]

    # Nested dilation sizes are normalized to plain lists of lists.
    dilation_sizes = [list(dilations) for dilations in vocoder_params["resblock_dilation_sizes"]]

    return {
        "model_in_dim": vocoder_params["num_mels"],
        "sampling_rate": vocoder_params["sampling_rate"],
        "upsample_initial_channel": vocoder_params["upsample_initial_channel"],
        "upsample_rates": list(vocoder_params["upsample_rates"]),
        "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
        "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
        "resblock_dilation_sizes": dilation_sizes,
        "normalize_before": False,
    }
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def extract_sub_model(checkpoint, key_prefix):
    """
    Takes a state dict and returns the state dict for a particular sub-model:
    every entry whose key starts with ``key_prefix``, with the prefix removed.
    """
    # `str.replace` removes every occurrence of the prefix, not just the leading
    # one — kept deliberately to match the established conversion behaviour.
    return {
        key.replace(key_prefix, ""): checkpoint.get(key)
        for key in list(checkpoint.keys())
        if key.startswith(key_prefix)
    }
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def convert_hifigan_checkpoint(checkpoint, config):
    """
    Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint.
    """
    # extract state dict for vocoder
    vocoder_state_dict = extract_sub_model(checkpoint, key_prefix="first_stage_model.vocoder.")

    # Rename `ups.N.*` to the transformers `upsampler.N.*` layout; every other
    # key already matches.
    for idx in range(len(config.upsample_rates)):
        for suffix in ("weight", "bias"):
            vocoder_state_dict[f"upsampler.{idx}.{suffix}"] = vocoder_state_dict.pop(f"ups.{idx}.{suffix}")

    if not config.normalize_before:
        # if we don't set normalize_before then these variables are unused, so we set them to their initialised values
        vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim)
        vocoder_state_dict["scale"] = torch.ones(config.model_in_dim)

    return vocoder_state_dict
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
def convert_projection_checkpoint(checkpoint):
    """
    Extracts the AudioLDM2 projection-model weights (SOS/EOS embeddings and the two
    linear projections) from the first conditioner's state dict.
    """
    conditioner_state_dict = extract_sub_model(checkpoint, key_prefix="cond_stage_models.0.")

    projection_state_dict = {}
    # index 0 maps to the un-suffixed keys, index 1 to the "_1" keys
    encoders = tuple(enumerate(("", "_1")))

    for index, suffix in encoders:
        projection_state_dict[f"sos_embed{suffix}"] = conditioner_state_dict["start_of_sequence_tokens.weight"][index]
    for index, suffix in encoders:
        projection_state_dict[f"eos_embed{suffix}"] = conditioner_state_dict["end_of_sequence_tokens.weight"][index]
    for index, suffix in encoders:
        projection_state_dict[f"projection{suffix}.weight"] = conditioner_state_dict[
            f"input_sequence_embed_linear.{index}.weight"
        ]
        projection_state_dict[f"projection{suffix}.bias"] = conditioner_state_dict[
            f"input_sequence_embed_linear.{index}.bias"
        ]

    return projection_state_dict
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
# Adapted from https://github.com/haoheliu/AudioLDM2/blob/81ad2c6ce015c1310387695e2dae975a7d2ed6fd/audioldm2/utils.py#L143
|
| 768 |
+
DEFAULT_CONFIG = {
    "model": {
        "params": {
            # beta-schedule endpoints and number of training timesteps for the diffusion process
            "linear_start": 0.0015,
            "linear_end": 0.0195,
            "timesteps": 1000,
            # number of latent channels (same value as the VAE `z_channels` below)
            "channels": 8,
            "scale_by_std": True,
            # latent-diffusion UNet (converted to AudioLDM2UNet2DConditionModel)
            "unet_config": {
                "target": "audioldm2.latent_diffusion.openaimodel.UNetModel",
                "params": {
                    # one entry per conditioning source; `None` denotes no cross-attention for that source
                    "context_dim": [None, 768, 1024],
                    "in_channels": 8,
                    "out_channels": 8,
                    "model_channels": 128,
                    "attention_resolutions": [8, 4, 2],
                    "num_res_blocks": 2,
                    "channel_mult": [1, 2, 3, 5],
                    "num_head_channels": 32,
                    "transformer_depth": 1,
                },
            },
            # first-stage VAE (converted to diffusers AutoencoderKL)
            "first_stage_config": {
                "target": "audioldm2.variational_autoencoder.autoencoder.AutoencoderKL",
                "params": {
                    "embed_dim": 8,
                    "ddconfig": {
                        "z_channels": 8,
                        "resolution": 256,
                        "in_channels": 1,
                        "out_ch": 1,
                        "ch": 128,
                        "ch_mult": [1, 2, 4],
                        "num_res_blocks": 2,
                    },
                },
            },
            # conditioner that generates AudioMAE-style embedding sequences
            "cond_stage_config": {
                "crossattn_audiomae_generated": {
                    "target": "audioldm2.latent_diffusion.modules.encoders.modules.SequenceGenAudioMAECond",
                    "params": {
                        "sequence_gen_length": 8,
                        "sequence_input_embed_dim": [512, 1024],
                    },
                }
            },
            # HiFiGAN vocoder (converted to SpeechT5HifiGan)
            "vocoder_config": {
                "target": "audioldm2.first_stage_model.vocoder",
                "params": {
                    "upsample_rates": [5, 4, 2, 2, 2],
                    "upsample_kernel_sizes": [16, 16, 8, 4, 4],
                    "upsample_initial_channel": 1024,
                    "resblock_kernel_sizes": [3, 7, 11],
                    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                    "num_mels": 64,
                    "sampling_rate": 16000,
                },
            },
        },
    },
}
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
def load_pipeline_from_original_AudioLDM2_ckpt(
    checkpoint_path: str,
    original_config_file: str = None,
    image_size: int = 1024,
    prediction_type: str = None,
    extract_ema: bool = False,
    scheduler_type: str = "ddim",
    cross_attention_dim: Union[List, List[List]] = None,
    transformer_layers_per_block: int = None,
    device: str = None,
    from_safetensors: bool = False,
) -> AudioLDM2Pipeline:
    """
    Load an AudioLDM2 pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file.

    Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
    global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
    recommended that you override the default values and/or supply an `original_config_file` wherever possible.

    Args:
        checkpoint_path (`str`): Path to `.ckpt` file.
        original_config_file (`str`):
            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
            set to the AudioLDM2 base config.
        image_size (`int`, *optional*, defaults to 1024):
            The image size that the model was trained on.
        prediction_type (`str`, *optional*):
            The prediction type that the model was trained on. If `None`, will be automatically
            inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
        scheduler_type (`str`, *optional*, defaults to 'ddim'):
            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
            "ddim"]`.
        cross_attention_dim (`list`, *optional*, defaults to `None`):
            The dimension of the cross-attention layers. If `None`, the cross-attention dimension will be
            automatically inferred. Set to `[768, 1024]` for the base model, or `[768, 1024, None]` for the large model.
        transformer_layers_per_block (`int`, *optional*, defaults to `None`):
            The number of transformer layers in each transformer block. If `None`, number of layers will be
            automatically inferred. Set to `1` for the base model, or `2` for the large model.
        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
            inference. Non-EMA weights are usually better to continue fine-tuning.
        device (`str`, *optional*, defaults to `None`):
            The device to use. Pass `None` to determine automatically.
        from_safetensors (`str`, *optional*, defaults to `False`):
            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
    return: An AudioLDM2Pipeline object representing the passed-in `.ckpt`/`.safetensors` file.
    """

    # -------- load the raw state dict (safetensors or pickled torch checkpoint) --------
    if from_safetensors:
        if not is_safetensors_available():
            raise ValueError(BACKENDS_MAPPING["safetensors"][1])

        from safetensors import safe_open

        checkpoint = {}
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
    else:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            checkpoint = torch.load(checkpoint_path, map_location=device)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device)

    # Lightning-style checkpoints nest the weights under "state_dict"
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    # -------- resolve the original (LDM-style) model config --------
    if original_config_file is None:
        original_config = DEFAULT_CONFIG
    else:
        original_config = yaml.safe_load(original_config_file)

    # caller-supplied overrides take precedence over the config file / defaults
    if image_size is not None:
        original_config["model"]["params"]["unet_config"]["params"]["image_size"] = image_size

    if cross_attention_dim is not None:
        original_config["model"]["params"]["unet_config"]["params"]["context_dim"] = cross_attention_dim

    if transformer_layers_per_block is not None:
        original_config["model"]["params"]["unet_config"]["params"]["transformer_depth"] = transformer_layers_per_block

    # infer the prediction type from the config unless explicitly given
    if (
        "parameterization" in original_config["model"]["params"]
        and original_config["model"]["params"]["parameterization"] == "v"
    ):
        if prediction_type is None:
            prediction_type = "v_prediction"
    else:
        if prediction_type is None:
            prediction_type = "epsilon"

    num_train_timesteps = original_config["model"]["params"]["timesteps"]
    beta_start = original_config["model"]["params"]["linear_start"]
    beta_end = original_config["model"]["params"]["linear_end"]

    # -------- build the scheduler (DDIM first, then convert if another type was requested) --------
    scheduler = DDIMScheduler(
        beta_end=beta_end,
        beta_schedule="scaled_linear",
        beta_start=beta_start,
        num_train_timesteps=num_train_timesteps,
        steps_offset=1,
        clip_sample=False,
        set_alpha_to_one=False,
        prediction_type=prediction_type,
    )
    # make sure scheduler works correctly with DDIM
    scheduler.register_to_config(clip_sample=False)

    if scheduler_type == "pndm":
        config = dict(scheduler.config)
        config["skip_prk_steps"] = True
        scheduler = PNDMScheduler.from_config(config)
    elif scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "heun":
        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
    elif scheduler_type == "ddim":
        scheduler = scheduler
    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

    # Convert the UNet2DModel
    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
    unet = AudioLDM2UNet2DConditionModel(**unet_config)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
    )

    unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model
    vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the joint audio-text encoding model
    clap_config = ClapConfig.from_pretrained("laion/clap-htsat-unfused")
    # AudioLDM2 uses a larger CLAP audio tower than the pretrained config describes
    clap_config.audio_config.update(
        {
            "patch_embeds_hidden_size": 128,
            "hidden_size": 1024,
            "depths": [2, 2, 12, 2],
        }
    )
    # AudioLDM2 uses the same tokenizer and feature extractor as the original CLAP model
    clap_tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
    clap_feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused")

    converted_clap_model = convert_open_clap_checkpoint(checkpoint)
    clap_model = ClapModel(clap_config)

    missing_keys, unexpected_keys = clap_model.load_state_dict(converted_clap_model, strict=False)
    # we expect not to have token_type_ids in our original state dict so let's ignore them
    missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS))

    if len(unexpected_keys) > 0:
        raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}")

    if len(missing_keys) > 0:
        raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}")

    # Convert the vocoder model
    vocoder_config = create_transformers_vocoder_config(original_config)
    vocoder_config = SpeechT5HifiGanConfig(**vocoder_config)
    converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config)

    vocoder = SpeechT5HifiGan(vocoder_config)
    vocoder.load_state_dict(converted_vocoder_checkpoint)

    # Convert the Flan-T5 encoder model: AudioLDM2 uses the same configuration and tokenizer as the original Flan-T5 large model
    t5_config = T5Config.from_pretrained("google/flan-t5-large")
    converted_t5_checkpoint = extract_sub_model(checkpoint, key_prefix="cond_stage_models.1.model.")

    t5_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
    # hard-coded in the original implementation (i.e. not retrievable from the config)
    t5_tokenizer.model_max_length = 128
    t5_model = T5EncoderModel(t5_config)
    t5_model.load_state_dict(converted_t5_checkpoint)

    # Convert the GPT2 encoder model: AudioLDM2 uses the same configuration as the original GPT2 base model
    gpt2_config = GPT2Config.from_pretrained("gpt2")
    gpt2_model = GPT2Model(gpt2_config)
    # generation length of the language model comes from the original conditioner config
    gpt2_model.config.max_new_tokens = original_config["model"]["params"]["cond_stage_config"][
        "crossattn_audiomae_generated"
    ]["params"]["sequence_gen_length"]

    converted_gpt2_checkpoint = extract_sub_model(checkpoint, key_prefix="cond_stage_models.0.model.")
    gpt2_model.load_state_dict(converted_gpt2_checkpoint)

    # Convert the extra embedding / projection layers
    projection_model = AudioLDM2ProjectionModel(clap_config.projection_dim, t5_config.d_model, gpt2_config.n_embd)

    converted_projection_checkpoint = convert_projection_checkpoint(checkpoint)
    projection_model.load_state_dict(converted_projection_checkpoint)

    # Instantiate the diffusers pipeline
    pipe = AudioLDM2Pipeline(
        vae=vae,
        text_encoder=clap_model,
        text_encoder_2=t5_model,
        projection_model=projection_model,
        language_model=gpt2_model,
        tokenizer=clap_tokenizer,
        tokenizer_2=t5_tokenizer,
        feature_extractor=clap_feature_extractor,
        unet=unet,
        scheduler=scheduler,
        vocoder=vocoder,
    )

    return pipe
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        help="The YAML config file corresponding to the original architecture.",
    )
    parser.add_argument(
        "--cross_attention_dim",
        default=None,
        type=int,
        nargs="+",
        help="The dimension of the cross-attention layers. If `None`, the cross-attention dimension will be "
        "automatically inferred. Set to `768+1024` for the base model, or `768+1024+640` for the large model",
    )
    parser.add_argument(
        "--transformer_layers_per_block",
        default=None,
        type=int,
        help="The number of transformer layers in each transformer block. If `None`, number of layers will be "
        "automatically inferred. Set to `1` for the base model, or `2` for the large model.",
    )
    parser.add_argument(
        "--scheduler_type",
        default="ddim",
        type=str,
        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
    )
    parser.add_argument(
        "--image_size",
        # 1024 matches the default of `load_pipeline_from_original_AudioLDM2_ckpt` and its docstring
        # (the previous CLI default, 1048, contradicted both and looks like a typo).
        default=1024,
        type=int,
        help="The image size that the model was trained on.",
    )
    parser.add_argument(
        "--prediction_type",
        default=None,
        type=str,
        help=("The prediction type that the model was trained on."),
    )
    parser.add_argument(
        "--extract_ema",
        action="store_true",
        help=(
            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
        ),
    )
    parser.add_argument(
        "--from_safetensors",
        action="store_true",
        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
    )
    parser.add_argument(
        "--to_safetensors",
        action="store_true",
        help="Whether to store pipeline in safetensors format or not.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
    args = parser.parse_args()

    # convert the original checkpoint into a diffusers pipeline and serialize it
    pipe = load_pipeline_from_original_AudioLDM2_ckpt(
        checkpoint_path=args.checkpoint_path,
        original_config_file=args.original_config_file,
        image_size=args.image_size,
        prediction_type=args.prediction_type,
        extract_ema=args.extract_ema,
        scheduler_type=args.scheduler_type,
        cross_attention_dim=args.cross_attention_dim,
        transformer_layers_per_block=args.transformer_layers_per_block,
        from_safetensors=args.from_safetensors,
        device=args.device,
    )
    pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
diffusers/scripts/convert_original_musicldm_to_diffusers.py
ADDED
|
@@ -0,0 +1,1056 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Conversion script for the MusicLDM checkpoints."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import re
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import yaml
|
| 22 |
+
from transformers import (
|
| 23 |
+
AutoFeatureExtractor,
|
| 24 |
+
AutoTokenizer,
|
| 25 |
+
ClapConfig,
|
| 26 |
+
ClapModel,
|
| 27 |
+
SpeechT5HifiGan,
|
| 28 |
+
SpeechT5HifiGanConfig,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
from diffusers import (
|
| 32 |
+
AutoencoderKL,
|
| 33 |
+
DDIMScheduler,
|
| 34 |
+
DPMSolverMultistepScheduler,
|
| 35 |
+
EulerAncestralDiscreteScheduler,
|
| 36 |
+
EulerDiscreteScheduler,
|
| 37 |
+
HeunDiscreteScheduler,
|
| 38 |
+
LMSDiscreteScheduler,
|
| 39 |
+
MusicLDMPipeline,
|
| 40 |
+
PNDMScheduler,
|
| 41 |
+
UNet2DConditionModel,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
|
| 46 |
+
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    segments = path.split(".")
    kept = (
        segments[n_shave_prefix_segments:]
        if n_shave_prefix_segments >= 0
        else segments[:n_shave_prefix_segments]
    )
    return ".".join(kept)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths
|
| 57 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    # substring replacements, applied to each path in this order
    replacements = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    mapping = []
    for old_item in old_list:
        new_item = old_item
        for source, target in replacements:
            new_item = new_item.replace(source, target)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths
|
| 80 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    # only the shortcut convolution is renamed; everything else already matches
    return [
        {
            "old": old_item,
            "new": shave_segments(
                old_item.replace("nin_shortcut", "conv_shortcut"),
                n_shave_prefix_segments=n_shave_prefix_segments,
            ),
        }
        for old_item in old_list
    ]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths
|
| 97 |
+
def renew_attention_paths(old_list):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # the attention paths already match the diffusers naming, so the mapping is the identity
    return [{"old": item, "new": item} for item in old_list]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    # substring replacements, applied to each path in this order (q/k/v projections,
    # group norm, and the output projection)
    replacements = (
        ("norm.weight", "group_norm.weight"),
        ("norm.bias", "group_norm.bias"),
        ("q.weight", "to_q.weight"),
        ("q.bias", "to_q.bias"),
        ("k.weight", "to_k.weight"),
        ("k.bias", "to_k.bias"),
        ("v.weight", "to_v.weight"),
        ("v.bias", "to_v.bias"),
        ("proj_out.weight", "to_out.0.weight"),
        ("proj_out.bias", "to_out.0.bias"),
    )

    mapping = []
    for old_item in old_list:
        new_item = old_item
        for source, target in replacements:
            new_item = new_item.replace(source, target)
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({"old": old_item, "new": new_item})
    return mapping
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for src_path, split_map in attention_paths_to_split.items():
            fused = old_checkpoint[src_path]
            channels = fused.shape[0] // 3

            # 3-D (conv-style) tensors keep a trailing channel axis; others are flattened.
            target_shape = (-1, channels) if len(fused.shape) == 3 else (-1)

            num_heads = fused.shape[0] // config["num_head_channels"] // 3

            fused = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
            query, key, value = fused.split(channels // num_heads, dim=1)

            checkpoint[split_map["query"]] = query.reshape(target_shape)
            checkpoint[split_map["key"]] = key.reshape(target_shape)
            checkpoint[split_map["value"]] = value.reshape(target_shape)

    # Fixed global renames for the UNet mid block.
    mid_block_renames = (
        ("middle_block.0", "mid_block.resnets.0"),
        ("middle_block.1", "mid_block.attentions.0"),
        ("middle_block.2", "mid_block.resnets.1"),
    )

    for path in paths:
        new_path = path["new"]

        # These have already been assigned by the qkv split above.
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        for old_fragment, new_fragment in mid_block_renames:
            new_path = new_path.replace(old_fragment, new_fragment)

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
| 199 |
+
|
| 200 |
+
def conv_attn_to_linear(checkpoint):
    """
    Squeeze 1x1-conv attention projection weights (> 2-D) down to linear 2-D weights, in place.
    """
    qkv_suffixes = ("to_q.weight", "to_k.weight", "to_v.weight")
    out_suffix = "to_out.0.weight"
    for name in list(checkpoint.keys()):
        pieces = name.split(".")
        is_qkv = ".".join(pieces[-2:]) in qkv_suffixes
        is_out_proj = ".".join(pieces[-3:]) == out_suffix
        if (is_qkv or is_out_proj) and checkpoint[name].ndim > 2:
            checkpoint[name] = checkpoint[name].squeeze()
| 209 |
+
|
| 210 |
+
def create_unet_diffusers_config(original_config, image_size: int):
    """
    Creates a UNet config for diffusers based on the config of the original MusicLDM model.
    """
    unet_params = original_config["model"]["params"]["unet_config"]["params"]
    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]

    block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
    num_blocks = len(block_out_channels)
    attn_resolutions = unet_params["attention_resolutions"]

    # Descend through the UNet: a block gets cross-attention when its resolution
    # appears in `attention_resolutions`; resolution doubles after every block
    # except the deepest one.
    down_block_types = []
    resolution = 1
    for idx in range(num_blocks):
        down_block_types.append("CrossAttnDownBlock2D" if resolution in attn_resolutions else "DownBlock2D")
        if idx != num_blocks - 1:
            resolution *= 2

    # Ascend again, halving the resolution after every block.
    up_block_types = []
    for _ in range(num_blocks):
        up_block_types.append("CrossAttnUpBlock2D" if resolution in attn_resolutions else "UpBlock2D")
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

    # Without an explicit cross-attention dim, fall back to the per-level channel widths.
    cross_attention_dim = unet_params.get("cross_attention_dim", block_out_channels)

    # MusicLDM conditions the UNet through a learnt projection ("film") embedding,
    # mapped onto the diffusers class-embedding mechanism.
    has_film_conditioning = "extra_film_condition_dim" in unet_params
    class_embed_type = "simple_projection" if has_film_conditioning else None
    projection_class_embeddings_input_dim = unet_params["extra_film_condition_dim"] if has_film_conditioning else None
    class_embeddings_concat = unet_params.get("extra_film_use_concat")

    return {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params["in_channels"],
        "out_channels": unet_params["out_channels"],
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params["num_res_blocks"],
        "cross_attention_dim": cross_attention_dim,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "class_embeddings_concat": class_embeddings_concat,
    }
| 261 |
+
|
| 262 |
+
# Adapted from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config
def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
    """
    Creates a VAE config for diffusers based on the config of the original MusicLDM model. Compared to the original
    Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
    """
    first_stage_params = original_config["model"]["params"]["first_stage_config"]["params"]
    vae_params = first_stage_params["ddconfig"]
    # Accessed only for its KeyError side effect if the config is malformed.
    _ = first_stage_params["embed_dim"]

    block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
    num_blocks = len(block_out_channels)

    # When the model was trained with `scale_by_std`, the learnt scale factor is
    # stored in the checkpoint; otherwise use the standard SD constant.
    if "scale_by_std" in original_config["model"]["params"]:
        scaling_factor = checkpoint["scale_factor"]
    else:
        scaling_factor = 0.18215

    return {
        "sample_size": image_size,
        "in_channels": vae_params["in_channels"],
        "out_channels": vae_params["out_ch"],
        "down_block_types": ("DownEncoderBlock2D",) * num_blocks,
        "up_block_types": ("UpDecoderBlock2D",) * num_blocks,
        "block_out_channels": tuple(block_out_channels),
        "latent_channels": vae_params["z_channels"],
        "layers_per_block": vae_params["num_res_blocks"],
        "scaling_factor": float(scaling_factor),
    }
| 290 |
+
|
| 291 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
def create_diffusers_schedular(original_config):
    """Build a DDIM scheduler from the original LDM beta-schedule hyperparameters."""
    ldm_params = original_config["model"]["params"]
    schedular = DDIMScheduler(
        num_train_timesteps=ldm_params["timesteps"],
        beta_start=ldm_params["linear_start"],
        beta_end=ldm_params["linear_end"],
        beta_schedule="scaled_linear",
    )
    return schedular
| 301 |
+
|
| 302 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint. Compared to the original Stable Diffusion
    conversion, this function additionally converts the learnt film embedding linear layer.

    Args:
        checkpoint: full original state dict; keys under ``model.diffusion_model.`` are consumed (popped).
        config: diffusers UNet config dict; ``layers_per_block`` is read here and ``num_head_channels``
            inside ``assign_to_checkpoint``.
        path: checkpoint path, used only for log messages.
        extract_ema: if True and the checkpoint carries EMA weights, extract those instead of the
            non-EMA training weights.

    Returns:
        A new state dict with diffusers-style UNet parameter names.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    unet_key = "model.diffusion_model."
    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                # EMA weights are stored flat: "model_ema." + the original key with its dots removed.
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    # Timestep embedding MLP maps one-to-one.
    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    # MusicLDM-specific: the learnt film embedding becomes the diffusers class embedding.
    new_checkpoint["class_embedding.weight"] = unet_state_dict["film_emb.weight"]
    new_checkpoint["class_embedding.bias"] = unet_state_dict["film_emb.bias"]

    # input_blocks.0.0 is the stem convolution.
    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    # `out` is norm (index 0) followed by the output convolution (index 2).
    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    # Down path. input_blocks.{i}.0 holds resnets, .1 holds attentions; block 0
    # is the stem convolution handled above, hence the range starts at 1.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        # `op` layers are the downsampling convolutions; popped so they are not renamed twice.
        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    # Mid block: resnet (0) / attention (1) / resnet (2).
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    # Up path: output_blocks.{i}.0 holds resnets, .1 holds attentions or an upsampler conv.
    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        # Group this block's parameter names by their sub-layer index.
        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            # NOTE(review): resnet_0_paths is computed but unused on this branch; kept for parity with upstream.
            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            # A sub-layer whose (sorted) parameters are exactly conv.bias/conv.weight
            # is an upsampler convolution, not an attention.
            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            # Single-sub-layer output block: plain resnets, renamed directly without a meta path.
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint
|
| 478 |
+
|
| 479 |
+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Convert original first-stage (VAE) weights to the diffusers ``AutoencoderKL`` naming scheme.

    Args:
        checkpoint: full original state dict; only keys under ``first_stage_model.`` are read.
        config: diffusers VAE config dict, forwarded to ``assign_to_checkpoint``.

    Returns:
        A new state dict with diffusers-style VAE parameter names.
    """
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    # Encoder stem/head layers map one-to-one (norm_out -> conv_norm_out).
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    # Decoder stem/head layers.
    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    # Encoder down path: rename resnets; downsampler convolutions are popped and copied directly.
    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid block: two resnets (block_1, block_2) plus one attention (attn_1).
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    # Attention projections were 1x1 convs in the original model; squeeze them to linear weights.
    conv_attn_to_linear(new_checkpoint)

    # Decoder up path. Original decoder up blocks are stored in reverse order,
    # hence block_id counts down while i counts up.
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid block, same structure as the encoder's.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
| 586 |
+
|
| 587 |
+
# Substring renames from the original CLAP parameter names to the transformers
# ones; applied by `convert_open_clap_checkpoint` via `str.replace` in dict order.
CLAP_KEYS_TO_MODIFY_MAPPING = {
    "text_branch": "text_model",
    "audio_branch": "audio_model.audio_encoder",
    "attn": "attention.self",
    "self.proj": "output.dense",
    "attention.self_mask": "attn_mask",
    "mlp.fc1": "intermediate.dense",
    "mlp.fc2": "output.dense",
    "norm1": "layernorm_before",
    "norm2": "layernorm_after",
    "bn0": "batch_norm",
}

# Any key containing one of these fragments is dropped by
# `convert_open_clap_checkpoint` (it is remapped to the sentinel key
# "spectrogram", which is filtered out of the converted checkpoint).
CLAP_KEYS_TO_IGNORE = [
    "text_transform",
    "audio_transform",
    "stft",
    "logmel_extractor",
    "tscam_conv",
    "head",
    "attn_mask",
]

# Keys allowed to be missing when the converted state dict is loaded.
# NOTE(review): not referenced in this chunk — presumably consumed by the
# load/validation step elsewhere in this script.
CLAP_EXPECTED_MISSING_KEYS = ["text_model.embeddings.token_type_ids"]
|
| 612 |
+
|
| 613 |
+
def convert_open_clap_checkpoint(checkpoint):
    """
    Takes a state dict and returns a converted CLAP checkpoint.

    Keys under ``cond_stage_model.model.`` are renamed via `CLAP_KEYS_TO_MODIFY_MAPPING`,
    keys matching `CLAP_KEYS_TO_IGNORE` are dropped, sequential/projection layer indices
    are remapped to the transformers layout, and fused audio ``qkv`` tensors are split
    into separate query/key/value entries.
    """
    # extract state dict for CLAP text embedding model, discarding the audio component
    model_state_dict = {}
    model_key = "cond_stage_model.model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(model_key):
            model_state_dict[key.replace(model_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    sequential_layers_pattern = r".*sequential.(\d+).*"
    text_projection_pattern = r".*_projection.(\d+).*"

    for key, value in model_state_dict.items():
        # check if key should be ignored in mapping - if so map it to a key name that we'll filter out at the end
        for key_to_ignore in CLAP_KEYS_TO_IGNORE:
            if key_to_ignore in key:
                key = "spectrogram"

        # check if any key needs to be modified
        for key_to_modify, new_key in CLAP_KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        if re.match(sequential_layers_pattern, key):
            # replace sequential layers with list: every third sequential index is a linear layer
            sequential_layer = re.match(sequential_layers_pattern, key).group(1)

            key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
        elif re.match(text_projection_pattern, key):
            projecton_layer = int(re.match(text_projection_pattern, key).group(1))

            # Because in CLAP they use `nn.Sequential`...
            transformers_projection_layer = 1 if projecton_layer == 0 else 2

            key = key.replace(f"_projection.{projecton_layer}.", f"_projection.linear{transformers_projection_layer}.")

        # FIX: upstream wrote `if "audio" and "qkv" in key:` — the string literal
        # "audio" is always truthy, so the test degenerated to `"qkv" in key` alone.
        # Require both substrings, as intended, so only fused audio-branch attention
        # tensors are split.
        if "audio" in key and "qkv" in key:
            # split qkv into query key and value
            mixed_qkv = value
            qkv_dim = mixed_qkv.size(0) // 3

            query_layer = mixed_qkv[:qkv_dim]
            key_layer = mixed_qkv[qkv_dim : qkv_dim * 2]
            value_layer = mixed_qkv[qkv_dim * 2 :]

            new_checkpoint[key.replace("qkv", "query")] = query_layer
            new_checkpoint[key.replace("qkv", "key")] = key_layer
            new_checkpoint[key.replace("qkv", "value")] = value_layer
        elif key != "spectrogram":
            # drop the sentinel produced by the ignore-list above
            new_checkpoint[key] = value

    return new_checkpoint
| 671 |
+
|
| 672 |
+
def create_transformers_vocoder_config(original_config):
    """
    Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
    """
    vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]

    # Copy the nested dilation lists element-by-element so the config owns plain lists.
    resblock_dilation_sizes = [list(dilations) for dilations in vocoder_params["resblock_dilation_sizes"]]

    config = {
        "model_in_dim": vocoder_params["num_mels"],
        "sampling_rate": vocoder_params["sampling_rate"],
        "upsample_initial_channel": vocoder_params["upsample_initial_channel"],
        "upsample_rates": list(vocoder_params["upsample_rates"]),
        "upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
        "resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
        "resblock_dilation_sizes": resblock_dilation_sizes,
        "normalize_before": False,
    }

    return config
|
| 693 |
+
|
| 694 |
+
def convert_hifigan_checkpoint(checkpoint, config):
    """
    Takes a state dict and config, and returns a converted HiFiGAN vocoder checkpoint.
    """
    # Keep only the vocoder weights, stripping the "first_stage_model.vocoder." prefix.
    vocoder_key = "first_stage_model.vocoder."
    vocoder_state_dict = {
        name.replace(vocoder_key, ""): checkpoint.get(name)
        for name in list(checkpoint.keys())
        if name.startswith(vocoder_key)
    }

    # fix upsampler keys, everything else is correct already
    for idx in range(len(config.upsample_rates)):
        vocoder_state_dict[f"upsampler.{idx}.weight"] = vocoder_state_dict.pop(f"ups.{idx}.weight")
        vocoder_state_dict[f"upsampler.{idx}.bias"] = vocoder_state_dict.pop(f"ups.{idx}.bias")

    if not config.normalize_before:
        # if we don't set normalize_before then these variables are unused, so we set them to their initialised values
        vocoder_state_dict["mean"] = torch.zeros(config.model_in_dim)
        vocoder_state_dict["scale"] = torch.ones(config.model_in_dim)

    return vocoder_state_dict
|
| 718 |
+
|
| 719 |
+
# Adapted from https://huggingface.co/spaces/haoheliu/MusicLDM-text-to-audio-generation/blob/84a0384742a22bd80c44e903e241f0623e874f1d/MusicLDM/utils.py#L72-L73
# Fallback hyperparameters used when no `original_config_file` is supplied to
# `load_pipeline_from_original_MusicLDM_ckpt`.
DEFAULT_CONFIG = {
    "model": {
        "params": {
            # Diffusion beta schedule (consumed by `create_diffusers_schedular`).
            "linear_start": 0.0015,
            "linear_end": 0.0195,
            "timesteps": 1000,
            "channels": 8,
            # Presence of this key makes `create_vae_diffusers_config` read the
            # learnt scale factor from the checkpoint instead of 0.18215.
            "scale_by_std": True,
            # Latent-diffusion UNet (consumed by `create_unet_diffusers_config`).
            "unet_config": {
                "target": "MusicLDM.latent_diffusion.openaimodel.UNetModel",
                "params": {
                    # Dim of the learnt film conditioning embedding; mapped onto
                    # the diffusers "simple_projection" class embedding.
                    "extra_film_condition_dim": 512,
                    "extra_film_use_concat": True,
                    "in_channels": 8,
                    "out_channels": 8,
                    "model_channels": 128,
                    "attention_resolutions": [8, 4, 2],
                    "num_res_blocks": 2,
                    "channel_mult": [1, 2, 3, 5],
                    "num_head_channels": 32,
                },
            },
            # First-stage VAE (consumed by `create_vae_diffusers_config`).
            "first_stage_config": {
                "target": "MusicLDM.variational_autoencoder.autoencoder.AutoencoderKL",
                "params": {
                    "embed_dim": 8,
                    "ddconfig": {
                        "z_channels": 8,
                        "resolution": 256,
                        "in_channels": 1,
                        "out_ch": 1,
                        "ch": 128,
                        "ch_mult": [1, 2, 4],
                        "num_res_blocks": 2,
                    },
                },
            },
            # HiFiGAN vocoder (consumed by `create_transformers_vocoder_config`).
            "vocoder_config": {
                "target": "MusicLDM.first_stage_model.vocoder",
                "params": {
                    "upsample_rates": [5, 4, 2, 2, 2],
                    "upsample_kernel_sizes": [16, 16, 8, 4, 4],
                    "upsample_initial_channel": 1024,
                    "resblock_kernel_sizes": [3, 7, 11],
                    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                    "num_mels": 64,
                    "sampling_rate": 16000,
                },
            },
        },
    },
}
|
| 773 |
+
|
| 774 |
+
def load_pipeline_from_original_MusicLDM_ckpt(
    checkpoint_path: str,
    original_config_file: str = None,
    image_size: int = 1024,
    prediction_type: str = None,
    extract_ema: bool = False,
    scheduler_type: str = "ddim",
    num_in_channels: int = None,
    model_channels: int = None,
    num_head_channels: int = None,
    device: str = None,
    from_safetensors: bool = False,
) -> MusicLDMPipeline:
    """
    Load an MusicLDM pipeline object from a `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file.

    Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
    global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
    recommended that you override the default values and/or supply an `original_config_file` wherever possible.

    Args:
        checkpoint_path (`str`): Path to `.ckpt` file.
        original_config_file (`str`):
            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
            set to the MusicLDM-s-full-v2 config.
        image_size (`int`, *optional*, defaults to 1024):
            The image size that the model was trained on.
        prediction_type (`str`, *optional*):
            The prediction type that the model was trained on. If `None`, will be automatically
            inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
        num_in_channels (`int`, *optional*, defaults to None):
            The number of UNet input channels. If `None`, it will be automatically inferred from the config.
        model_channels (`int`, *optional*, defaults to None):
            The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override
            to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.
        num_head_channels (`int`, *optional*, defaults to None):
            The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override
            to 32 for the small and medium checkpoints, and 64 for the large.
        scheduler_type (`str`, *optional*, defaults to 'ddim'):
            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
            "ddim"]`.
        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
            inference. Non-EMA weights are usually better to continue fine-tuning.
        device (`str`, *optional*, defaults to `None`):
            The device to use. Pass `None` to determine automatically.
        from_safetensors (`str`, *optional*, defaults to `False`):
            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
    return: An MusicLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
    """
    # Load the raw state dict, either via safetensors or torch.load.
    if from_safetensors:
        from safetensors import safe_open

        checkpoint = {}
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
    else:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Single load call (previously duplicated in both branches of the device check).
        checkpoint = torch.load(checkpoint_path, map_location=device)

    # Lightning-style checkpoints nest the weights under "state_dict".
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if original_config_file is None:
        original_config = DEFAULT_CONFIG
    else:
        original_config = yaml.safe_load(original_config_file)

    # Explicit CLI overrides take precedence over the config values.
    if num_in_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    if model_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels

    if num_head_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels

    # Infer the prediction type from the config when not supplied by the caller.
    if (
        "parameterization" in original_config["model"]["params"]
        and original_config["model"]["params"]["parameterization"] == "v"
    ):
        if prediction_type is None:
            prediction_type = "v_prediction"
    else:
        if prediction_type is None:
            prediction_type = "epsilon"

    if image_size is None:
        image_size = 512

    num_train_timesteps = original_config["model"]["params"]["timesteps"]
    beta_start = original_config["model"]["params"]["linear_start"]
    beta_end = original_config["model"]["params"]["linear_end"]

    scheduler = DDIMScheduler(
        beta_end=beta_end,
        beta_schedule="scaled_linear",
        beta_start=beta_start,
        num_train_timesteps=num_train_timesteps,
        steps_offset=1,
        clip_sample=False,
        set_alpha_to_one=False,
        prediction_type=prediction_type,
    )
    # make sure scheduler works correctly with DDIM
    scheduler.register_to_config(clip_sample=False)

    # Optionally convert the DDIM config into the requested scheduler type.
    if scheduler_type == "pndm":
        config = dict(scheduler.config)
        config["skip_prk_steps"] = True
        scheduler = PNDMScheduler.from_config(config)
    elif scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "heun":
        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
    elif scheduler_type == "ddim":
        # Already a DDIMScheduler; nothing to convert (replaces the old no-op `scheduler = scheduler`).
        pass
    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

    # Convert the UNet2DModel
    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
    unet = UNet2DConditionModel(**unet_config)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
    )

    unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model
    vae_config = create_vae_diffusers_config(original_config, checkpoint=checkpoint, image_size=image_size)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the text model
    # MusicLDM uses the same tokenizer as the original CLAP model, but a slightly different configuration
    config = ClapConfig.from_pretrained("laion/clap-htsat-unfused")
    config.audio_config.update(
        {
            "patch_embeds_hidden_size": 128,
            "hidden_size": 1024,
            "depths": [2, 2, 12, 2],
        }
    )
    tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
    feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused")

    converted_text_model = convert_open_clap_checkpoint(checkpoint)
    text_model = ClapModel(config)

    missing_keys, unexpected_keys = text_model.load_state_dict(converted_text_model, strict=False)
    # we expect not to have token_type_ids in our original state dict so let's ignore them
    missing_keys = list(set(missing_keys) - set(CLAP_EXPECTED_MISSING_KEYS))

    if len(unexpected_keys) > 0:
        raise ValueError(f"Unexpected keys when loading CLAP model: {unexpected_keys}")

    if len(missing_keys) > 0:
        raise ValueError(f"Missing keys when loading CLAP model: {missing_keys}")

    # Convert the vocoder model
    vocoder_config = create_transformers_vocoder_config(original_config)
    vocoder_config = SpeechT5HifiGanConfig(**vocoder_config)
    converted_vocoder_checkpoint = convert_hifigan_checkpoint(checkpoint, vocoder_config)

    vocoder = SpeechT5HifiGan(vocoder_config)
    vocoder.load_state_dict(converted_vocoder_checkpoint)

    # Instantiate the diffusers pipeline
    pipe = MusicLDMPipeline(
        vae=vae,
        text_encoder=text_model,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        vocoder=vocoder,
        feature_extractor=feature_extractor,
    )

    return pipe
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
if __name__ == "__main__":
    # CLI entry point: convert an original MusicLDM checkpoint into a diffusers
    # MusicLDMPipeline and write the result to --dump_path.
    arg_parser = argparse.ArgumentParser()

    arg_parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    arg_parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        help="The YAML config file corresponding to the original architecture.",
    )
    arg_parser.add_argument(
        "--num_in_channels",
        default=None,
        type=int,
        help="The number of input channels. If `None` number of input channels will be automatically inferred.",
    )
    arg_parser.add_argument(
        "--model_channels",
        default=None,
        type=int,
        help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
        " to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
    )
    arg_parser.add_argument(
        "--num_head_channels",
        default=None,
        type=int,
        help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
        " to 32 for the small and medium checkpoints, and 64 for the large.",
    )
    arg_parser.add_argument(
        "--scheduler_type",
        default="ddim",
        type=str,
        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
    )
    arg_parser.add_argument(
        "--image_size",
        default=None,
        type=int,
        help=("The image size that the model was trained on."),
    )
    arg_parser.add_argument(
        "--prediction_type",
        default=None,
        type=str,
        help=("The prediction type that the model was trained on."),
    )
    arg_parser.add_argument(
        "--extract_ema",
        action="store_true",
        help=(
            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
        ),
    )
    arg_parser.add_argument(
        "--from_safetensors",
        action="store_true",
        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
    )
    arg_parser.add_argument(
        "--to_safetensors",
        action="store_true",
        help="Whether to store pipeline in safetensors format or not.",
    )
    arg_parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    arg_parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
    cli_args = arg_parser.parse_args()

    # Build the pipeline from the original checkpoint, then serialize it.
    pipeline = load_pipeline_from_original_MusicLDM_ckpt(
        checkpoint_path=cli_args.checkpoint_path,
        original_config_file=cli_args.original_config_file,
        image_size=cli_args.image_size,
        prediction_type=cli_args.prediction_type,
        extract_ema=cli_args.extract_ema,
        scheduler_type=cli_args.scheduler_type,
        num_in_channels=cli_args.num_in_channels,
        model_channels=cli_args.model_channels,
        num_head_channels=cli_args.num_head_channels,
        from_safetensors=cli_args.from_safetensors,
        device=cli_args.device,
    )
    pipeline.save_pretrained(cli_args.dump_path, safe_serialization=cli_args.to_safetensors)
|
diffusers/scripts/convert_pixart_sigma_to_diffusers.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 6 |
+
|
| 7 |
+
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtSigmaPipeline, Transformer2DModel
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Hugging Face organization hosting the PixArt checkpoints.
ckpt_id = "PixArt-alpha"
# https://github.com/PixArt-alpha/PixArt-sigma/blob/dd087141864e30ec44f12cb7448dd654be065e88/scripts/inference.py#L158
# Positional-embedding interpolation scale keyed by training resolution (see link above).
interpolation_scale = {256: 0.5, 512: 1, 1024: 2, 2048: 4}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def main(args):
    """Convert an original PixArt-Sigma checkpoint into diffusers format.

    Maps the original state-dict keys onto a diffusers `Transformer2DModel` and either
    saves the transformer alone (`--only_transformer`) or assembles and saves a full
    `PixArtSigmaPipeline` at ``args.dump_path``.
    """
    all_state_dict = torch.load(args.orig_ckpt_path)
    state_dict = all_state_dict.pop("state_dict")
    converted_state_dict = {}

    # Patch embeddings.
    converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
    converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

    # Caption projection.
    converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
    converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
    converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
    converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

    # AdaLN-single LN
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
        "t_embedder.mlp.0.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
        "t_embedder.mlp.2.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")

    if args.micro_condition:
        # Resolution.
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop(
            "csize_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop(
            "csize_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop(
            "csize_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop(
            "csize_embedder.mlp.2.bias"
        )
        # Aspect ratio.
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop(
            "ar_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop(
            "ar_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop(
            "ar_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop(
            "ar_embedder.mlp.2.bias"
        )
    # Shared norm.
    converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
    converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")

    for depth in range(28):
        # Transformer blocks.
        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
            f"blocks.{depth}.scale_shift_table"
        )
        # Attention is all you need 🤘

        # Self attention: original fuses q/k/v into one projection; split into thirds.
        q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
        # Projection.
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.bias"
        )
        if args.qk_norm:
            converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.q_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.bias"] = state_dict.pop(
                f"blocks.{depth}.attn.q_norm.bias"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.k_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.bias"] = state_dict.pop(
                f"blocks.{depth}.attn.k_norm.bias"
            )

        # Feed-forward.
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.bias"
        )

        # Cross-attention: q is separate; k/v are fused and split into halves.
        q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
        q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
        k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.bias"
        )

    # Final block.
    converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
    converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")

    # PixArt XL/2
    transformer = Transformer2DModel(
        sample_size=args.image_size // 8,
        num_layers=28,
        attention_head_dim=72,
        in_channels=4,
        out_channels=8,
        patch_size=2,
        attention_bias=True,
        num_attention_heads=16,
        cross_attention_dim=1152,
        activation_fn="gelu-approximate",
        num_embeds_ada_norm=1000,
        norm_type="ada_norm_single",
        norm_elementwise_affine=False,
        norm_eps=1e-6,
        caption_channels=4096,
        interpolation_scale=interpolation_scale[args.image_size],
        use_additional_conditions=args.micro_condition,
    )
    transformer.load_state_dict(converted_state_dict, strict=True)

    assert transformer.pos_embed.pos_embed is not None
    # These keys may be absent depending on the checkpoint; dict.pop only raises KeyError,
    # so catch exactly that instead of the previous blanket `except Exception`.
    try:
        state_dict.pop("y_embedder.y_embedding")
        state_dict.pop("pos_embed")
    except KeyError as e:
        print(f"Skipping {str(e)}")
    assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    if args.only_transformer:
        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
    else:
        # pixart-Sigma vae link: https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae
        vae = AutoencoderKL.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="vae")

        scheduler = DPMSolverMultistepScheduler()

        tokenizer = T5Tokenizer.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(
            f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="text_encoder"
        )

        pipeline = PixArtSigmaPipeline(
            tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
        )

        pipeline.save_pretrained(args.dump_path)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--micro_condition", action="store_true", help="If use Micro-condition in PixArtMS structure during training."
    )
    parser.add_argument("--qk_norm", action="store_true", help="If use qk norm during training.")
    parser.add_argument(
        "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--image_size",
        default=1024,
        type=int,
        choices=[256, 512, 1024, 2048],
        required=False,
        help="Image size of pretrained model, 256, 512, 1024, or 2048.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")

    def _str2bool(value):
        # Fix for the classic argparse pitfall: `type=bool` treats ANY non-empty string
        # (including "False") as True. Parse common boolean spellings explicitly.
        lowered = value.lower()
        if lowered in ("true", "1", "yes", "y"):
            return True
        if lowered in ("false", "0", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"Invalid boolean value: {value!r}")

    parser.add_argument("--only_transformer", default=True, type=_str2bool, required=True)

    args = parser.parse_args()
    main(args)
|
diffusers/scripts/convert_sana_to_diffusers.py
ADDED
|
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
from contextlib import nullcontext
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from accelerate import init_empty_weights
|
| 10 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 11 |
+
from termcolor import colored
|
| 12 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 13 |
+
|
| 14 |
+
from diffusers import (
|
| 15 |
+
AutoencoderDC,
|
| 16 |
+
DPMSolverMultistepScheduler,
|
| 17 |
+
FlowMatchEulerDiscreteScheduler,
|
| 18 |
+
SanaPipeline,
|
| 19 |
+
SanaSprintPipeline,
|
| 20 |
+
SanaTransformer2DModel,
|
| 21 |
+
SCMScheduler,
|
| 22 |
+
)
|
| 23 |
+
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
| 24 |
+
from diffusers.utils.import_utils import is_accelerate_available
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Use init_empty_weights only when accelerate is actually installed. The availability
# check must be *called* — the bare function object is always truthy, so the original
# `if is_accelerate_available` unconditionally selected init_empty_weights.
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
| 28 |
+
|
| 29 |
+
# Known Sana checkpoint paths ("org/repo/path-in-repo"); the first entry is the default.
# NOTE: the first two entries previously lacked trailing commas, so Python's implicit
# string concatenation silently fused the first three paths into one invalid entry.
ckpt_ids = [
    "Efficient-Large-Model/Sana_Sprint_0.6B_1024px/checkpoints/Sana_Sprint_0.6B_1024px.pth",
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth",
    "Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
    "Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth",
    "Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
    "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
    "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
    "Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
    "Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
    "Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
    "Efficient-Large-Model/Sana_1600M_512px/checkpoints/Sana_1600M_512px.pth",
    "Efficient-Large-Model/Sana_600M_1024px/checkpoints/Sana_600M_1024px_MultiLing.pth",
    "Efficient-Large-Model/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth",
]
|
| 44 |
+
# https://github.com/NVlabs/Sana/blob/main/scripts/inference.py
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main(args):
    """Convert an original Sana / Sana-Sprint ``.pth`` checkpoint to diffusers format.

    Downloads the checkpoint from the Hub when ``args.orig_ckpt_path`` is unset or
    is one of the known ``ckpt_ids``, remaps every weight key into the
    ``SanaTransformer2DModel`` layout, and saves either the transformer alone or a
    full ``SanaPipeline`` / ``SanaSprintPipeline`` under ``args.dump_path``.

    Relies on module-level globals defined before the call: ``ckpt_ids``, ``CTX``,
    ``model_kwargs`` and ``weight_dtype`` (the latter two are set in the
    ``__main__`` block).
    """
    cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")

    # Resolve the checkpoint file: either a local path or a known Hub id.
    if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
        ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
        snapshot_download(
            repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
            cache_dir=cache_dir_path,
            repo_type="model",
        )
        file_path = hf_hub_download(
            repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
            filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
            cache_dir=cache_dir_path,
            repo_type="model",
        )
    else:
        file_path = args.orig_ckpt_path

    print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
    all_state_dict = torch.load(file_path, weights_only=True)
    state_dict = all_state_dict.pop("state_dict")
    converted_state_dict = {}

    # Patch embeddings.
    converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
    converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

    # Caption projection.
    converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
    converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
    converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
    converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

    # Handle different time embedding structure based on model type
    if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
        # For Sana Sprint, the time embedding structure is different
        converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
            "t_embedder.mlp.0.weight"
        )
        converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
        converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop(
            "t_embedder.mlp.2.weight"
        )
        converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")

        # Guidance embedder for Sana Sprint
        converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop(
            "cfg_embedder.mlp.0.weight"
        )
        converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias")
        converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop(
            "cfg_embedder.mlp.2.weight"
        )
        converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias")
    else:
        # Original Sana time embedding structure (extra ".emb." level in the key path).
        converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
            "t_embedder.mlp.0.weight"
        )
        converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop(
            "t_embedder.mlp.0.bias"
        )
        converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
            "t_embedder.mlp.2.weight"
        )
        converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop(
            "t_embedder.mlp.2.bias"
        )

    # Shared norm.
    converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
    converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")

    # y norm
    converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")

    # scheduler: 4K checkpoints use a larger flow shift.
    if args.image_size == 4096:
        flow_shift = 6.0
    else:
        flow_shift = 3.0

    # model config: number of transformer blocks per model type.
    if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]:
        layer_num = 20
    elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]:
        layer_num = 28
    # BUGFIX: this previously tested "SanaMS_4800M_P1_D60", a name that neither the
    # argparse choices nor `model_kwargs` contain, so the 4.8B model always raised.
    elif args.model_type == "SanaMS1.5_4800M_P1_D60":
        layer_num = 60
    else:
        raise ValueError(f"{args.model_type} is not supported.")
    # Positional embedding interpolation scale.
    interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
    qk_norm = (
        "rms_norm_across_heads"
        if args.model_type
        in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"]
        else None
    )

    for depth in range(layer_num):
        # Transformer blocks.
        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
            f"blocks.{depth}.scale_shift_table"
        )

        # Linear Attention is all you need 🤘
        # Self attention: original stores q/k/v fused; diffusers keeps them separate.
        q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        if qk_norm is not None:
            # Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
            converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.q_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.k_norm.weight"
            )
        # Projection.
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.bias"
        )

        # Feed-forward (GLUMBConv: inverted / depthwise / pointwise convs; the
        # pointwise conv has no bias in the original checkpoint).
        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.inverted_conv.conv.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.inverted_conv.conv.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.depth_conv.conv.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.depth_conv.conv.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.point_conv.conv.weight"
        )

        # Cross-attention: q separate, k/v fused in the original.
        q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
        q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
        k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
        if qk_norm is not None:
            # Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
            converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
                f"blocks.{depth}.cross_attn.q_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
                f"blocks.{depth}.cross_attn.k_norm.weight"
            )

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.bias"
        )

    # Final block.
    converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
    converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")

    # Transformer — built on the meta device when accelerate is available.
    with CTX():
        transformer_kwargs = {
            "in_channels": 32,
            "out_channels": 32,
            "num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"],
            "attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"],
            "num_layers": model_kwargs[args.model_type]["num_layers"],
            "num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"],
            "cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"],
            "cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"],
            "caption_channels": 2304,
            "mlp_ratio": 2.5,
            "attention_bias": False,
            "sample_size": args.image_size // 32,
            "patch_size": 1,
            "norm_elementwise_affine": False,
            "norm_eps": 1e-6,
            "interpolation_scale": interpolation_scale[args.image_size],
        }

        # Add qk_norm parameter for Sana-1.5 / Sana Sprint
        if args.model_type in [
            "SanaMS1.5_1600M_P1_D20",
            "SanaMS1.5_4800M_P1_D60",
            "SanaSprint_600M_P1_D28",
            "SanaSprint_1600M_P1_D20",
        ]:
            transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
        if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
            transformer_kwargs["guidance_embeds"] = True

        transformer = SanaTransformer2DModel(**transformer_kwargs)

    if is_accelerate_available():
        load_model_dict_into_meta(transformer, converted_state_dict)
    else:
        transformer.load_state_dict(converted_state_dict, strict=True, assign=True)

    # Discard keys that have no diffusers counterpart. Pop each one individually
    # so one missing key doesn't abort the removal of the others (the old
    # try/except KeyError stopped at the first miss and made the assert below
    # fail with a confusing leftover-key message).
    for unused_key in ("y_embedder.y_embedding", "pos_embed", "logvar_linear.weight", "logvar_linear.bias"):
        if state_dict.pop(unused_key, None) is None:
            print(f"{unused_key} not found in the state_dict")

    assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    transformer = transformer.to(weight_dtype)

    if not args.save_full_pipeline:
        print(
            colored(
                f"Only saving transformer model of {args.model_type}. "
                f"Set --save_full_pipeline to save the whole Pipeline",
                "green",
                attrs=["bold"],
            )
        )
        transformer.save_pretrained(
            os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
        )
    else:
        print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
        # VAE
        ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)

        # Text Encoder (Gemma-2 decoder only)
        text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
        tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
        tokenizer.padding_side = "right"
        text_encoder = AutoModelForCausalLM.from_pretrained(
            text_encoder_model_path, torch_dtype=torch.bfloat16
        ).get_decoder()

        # Choose the appropriate pipeline and scheduler based on model type
        if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
            # Force SCM Scheduler for Sana Sprint regardless of scheduler_type
            if args.scheduler_type != "scm":
                print(
                    colored(
                        f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model",
                        "yellow",
                        attrs=["bold"],
                    )
                )

            # SCM Scheduler for Sana Sprint
            scheduler_config = {
                "prediction_type": "trigflow",
                "sigma_data": 0.5,
            }
            scheduler = SCMScheduler(**scheduler_config)
            pipe = SanaSprintPipeline(
                tokenizer=tokenizer,
                text_encoder=text_encoder,
                transformer=transformer,
                vae=ae,
                scheduler=scheduler,
            )
        else:
            # Original Sana scheduler
            if args.scheduler_type == "flow-dpm_solver":
                scheduler = DPMSolverMultistepScheduler(
                    flow_shift=flow_shift,
                    use_flow_sigmas=True,
                    prediction_type="flow_prediction",
                )
            elif args.scheduler_type == "flow-euler":
                scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
            else:
                raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")

            pipe = SanaPipeline(
                tokenizer=tokenizer,
                text_encoder=text_encoder,
                transformer=transformer,
                vae=ae,
                scheduler=scheduler,
            )

        pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# Maps the --dtype CLI flag to its torch dtype.
DTYPE_MAPPING = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
if __name__ == "__main__":
|
| 363 |
+
parser = argparse.ArgumentParser()
|
| 364 |
+
|
| 365 |
+
parser.add_argument(
|
| 366 |
+
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
|
| 367 |
+
)
|
| 368 |
+
parser.add_argument(
|
| 369 |
+
"--image_size",
|
| 370 |
+
default=1024,
|
| 371 |
+
type=int,
|
| 372 |
+
choices=[512, 1024, 2048, 4096],
|
| 373 |
+
required=False,
|
| 374 |
+
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
|
| 375 |
+
)
|
| 376 |
+
parser.add_argument(
|
| 377 |
+
"--model_type",
|
| 378 |
+
default="SanaMS_1600M_P1_D20",
|
| 379 |
+
type=str,
|
| 380 |
+
choices=[
|
| 381 |
+
"SanaMS_1600M_P1_D20",
|
| 382 |
+
"SanaMS_600M_P1_D28",
|
| 383 |
+
"SanaMS1.5_1600M_P1_D20",
|
| 384 |
+
"SanaMS1.5_4800M_P1_D60",
|
| 385 |
+
"SanaSprint_1600M_P1_D20",
|
| 386 |
+
"SanaSprint_600M_P1_D28",
|
| 387 |
+
],
|
| 388 |
+
)
|
| 389 |
+
parser.add_argument(
|
| 390 |
+
"--scheduler_type",
|
| 391 |
+
default="flow-dpm_solver",
|
| 392 |
+
type=str,
|
| 393 |
+
choices=["flow-dpm_solver", "flow-euler", "scm"],
|
| 394 |
+
help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
|
| 395 |
+
)
|
| 396 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
| 397 |
+
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
|
| 398 |
+
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
|
| 399 |
+
|
| 400 |
+
args = parser.parse_args()
|
| 401 |
+
|
| 402 |
+
model_kwargs = {
|
| 403 |
+
"SanaMS_1600M_P1_D20": {
|
| 404 |
+
"num_attention_heads": 70,
|
| 405 |
+
"attention_head_dim": 32,
|
| 406 |
+
"num_cross_attention_heads": 20,
|
| 407 |
+
"cross_attention_head_dim": 112,
|
| 408 |
+
"cross_attention_dim": 2240,
|
| 409 |
+
"num_layers": 20,
|
| 410 |
+
},
|
| 411 |
+
"SanaMS_600M_P1_D28": {
|
| 412 |
+
"num_attention_heads": 36,
|
| 413 |
+
"attention_head_dim": 32,
|
| 414 |
+
"num_cross_attention_heads": 16,
|
| 415 |
+
"cross_attention_head_dim": 72,
|
| 416 |
+
"cross_attention_dim": 1152,
|
| 417 |
+
"num_layers": 28,
|
| 418 |
+
},
|
| 419 |
+
"SanaMS1.5_1600M_P1_D20": {
|
| 420 |
+
"num_attention_heads": 70,
|
| 421 |
+
"attention_head_dim": 32,
|
| 422 |
+
"num_cross_attention_heads": 20,
|
| 423 |
+
"cross_attention_head_dim": 112,
|
| 424 |
+
"cross_attention_dim": 2240,
|
| 425 |
+
"num_layers": 20,
|
| 426 |
+
},
|
| 427 |
+
"SanaMS1.5_4800M_P1_D60": {
|
| 428 |
+
"num_attention_heads": 70,
|
| 429 |
+
"attention_head_dim": 32,
|
| 430 |
+
"num_cross_attention_heads": 20,
|
| 431 |
+
"cross_attention_head_dim": 112,
|
| 432 |
+
"cross_attention_dim": 2240,
|
| 433 |
+
"num_layers": 60,
|
| 434 |
+
},
|
| 435 |
+
"SanaSprint_600M_P1_D28": {
|
| 436 |
+
"num_attention_heads": 36,
|
| 437 |
+
"attention_head_dim": 32,
|
| 438 |
+
"num_cross_attention_heads": 16,
|
| 439 |
+
"cross_attention_head_dim": 72,
|
| 440 |
+
"cross_attention_dim": 1152,
|
| 441 |
+
"num_layers": 28,
|
| 442 |
+
},
|
| 443 |
+
"SanaSprint_1600M_P1_D20": {
|
| 444 |
+
"num_attention_heads": 70,
|
| 445 |
+
"attention_head_dim": 32,
|
| 446 |
+
"num_cross_attention_heads": 20,
|
| 447 |
+
"cross_attention_head_dim": 112,
|
| 448 |
+
"cross_attention_dim": 2240,
|
| 449 |
+
"num_layers": 20,
|
| 450 |
+
},
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 454 |
+
weight_dtype = DTYPE_MAPPING[args.dtype]
|
| 455 |
+
|
| 456 |
+
main(args)
|
diffusers/scripts/convert_stable_cascade.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
#
# Stage C weights become the "prior" pipeline and stage B weights the "decoder"
# pipeline; with --save_combined both are additionally packaged into a single
# StableCascadeCombinedPipeline.
import argparse
from contextlib import nullcontext

import torch
from safetensors.torch import load_file
from transformers import (
    AutoTokenizer,
    CLIPConfig,
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    DDPMWuerstchenScheduler,
    StableCascadeCombinedPipeline,
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available


if is_accelerate_available():
    from accelerate import init_empty_weights

parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
    "--prior_output_path", default="stable-cascade-prior", type=str, help="Hub organization to save the pipelines to"
)
parser.add_argument(
    "--decoder_output_path",
    type=str,
    default="stable-cascade-decoder",
    help="Hub organization to save the pipelines to",
)
parser.add_argument(
    "--combined_output_path",
    type=str,
    default="stable-cascade-combined",
    help="Hub organization to save the pipelines to",
)
parser.add_argument("--save_combined", action="store_true")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")

args = parser.parse_args()

# The combined pipeline needs both converted stages, so skipping is incompatible
# with --save_combined.
if args.skip_stage_b and args.skip_stage_c:
    raise ValueError("At least one stage should be converted")
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
    raise ValueError("Cannot skip stages when creating a combined pipeline")

model_path = args.model_path

device = "cpu"
if args.variant == "bf16":
    dtype = torch.bfloat16
else:
    dtype = torch.float32

# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"

# Clip Text encoder and tokenizer (shared by prior and decoder pipelines).
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# image processor
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()
# Build models on the meta device when accelerate is available, to avoid a full
# CPU allocation before the converted weights are loaded in.
ctx = init_empty_weights if is_accelerate_available() else nullcontext

if not args.skip_stage_c:
    # Prior: load stage C weights (safetensors or torch pickle) and remap keys.
    if args.use_safetensors:
        prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
    else:
        prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)

    prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)

    with ctx():
        prior_model = StableCascadeUNet(
            in_channels=16,
            out_channels=16,
            timestep_ratio_embedding_dim=64,
            patch_size=1,
            conditioning_dim=2048,
            block_out_channels=[2048, 2048],
            num_attention_heads=[32, 32],
            down_num_layers_per_block=[8, 24],
            up_num_layers_per_block=[24, 8],
            down_blocks_repeat_mappers=[1, 1],
            up_blocks_repeat_mappers=[1, 1],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_in_channels=1280,
            clip_text_pooled_in_channels=1280,
            clip_image_in_channels=768,
            clip_seq=4,
            kernel_size=3,
            dropout=[0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca", "crp"],
            switch_level=[False],
        )
    if is_accelerate_available():
        load_model_dict_into_meta(prior_model, prior_state_dict)
    else:
        prior_model.load_state_dict(prior_state_dict)

    # Prior pipeline
    prior_pipeline = StableCascadePriorPipeline(
        prior=prior_model,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
    )
    prior_pipeline.to(dtype).save_pretrained(
        args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )

if not args.skip_stage_b:
    # Decoder: load stage B weights and remap keys.
    if args.use_safetensors:
        decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
    else:
        decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)

    decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
    with ctx():
        decoder = StableCascadeUNet(
            in_channels=4,
            out_channels=4,
            timestep_ratio_embedding_dim=64,
            patch_size=2,
            conditioning_dim=1280,
            block_out_channels=[320, 640, 1280, 1280],
            down_num_layers_per_block=[2, 6, 28, 6],
            up_num_layers_per_block=[6, 28, 6, 2],
            down_blocks_repeat_mappers=[1, 1, 1, 1],
            up_blocks_repeat_mappers=[3, 3, 2, 2],
            num_attention_heads=[0, 0, 20, 20],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_pooled_in_channels=1280,
            clip_seq=4,
            effnet_in_channels=16,
            pixel_mapper_in_channels=3,
            kernel_size=3,
            dropout=[0, 0, 0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca"],
        )

    if is_accelerate_available():
        load_model_dict_into_meta(decoder, decoder_state_dict)
    else:
        decoder.load_state_dict(decoder_state_dict)

    # VQGAN from Wuerstchen-V2
    vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

    # Decoder pipeline
    decoder_pipeline = StableCascadeDecoderPipeline(
        decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
    )
    decoder_pipeline.to(dtype).save_pretrained(
        args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )

if args.save_combined:
    # Stable Cascade combined pipeline (reuses the models converted above;
    # reachable only when neither stage was skipped — see the validation at the top).
    stable_cascade_pipeline = StableCascadeCombinedPipeline(
        # Decoder
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        decoder=decoder,
        scheduler=scheduler,
        vqgan=vqmodel,
        # Prior
        prior_text_encoder=text_encoder,
        prior_tokenizer=tokenizer,
        prior_prior=prior_model,
        prior_scheduler=scheduler,
        prior_image_encoder=image_encoder,
        prior_feature_extractor=feature_extractor,
    )
    stable_cascade_pipeline.to(dtype).save_pretrained(
        args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )
|
diffusers/scripts/convert_vae_pt_to_diffusers.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import io
|
| 3 |
+
|
| 4 |
+
import requests
|
| 5 |
+
import torch
|
| 6 |
+
import yaml
|
| 7 |
+
|
| 8 |
+
from diffusers import AutoencoderKL
|
| 9 |
+
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
| 10 |
+
assign_to_checkpoint,
|
| 11 |
+
conv_attn_to_linear,
|
| 12 |
+
create_vae_diffusers_config,
|
| 13 |
+
renew_vae_attention_paths,
|
| 14 |
+
renew_vae_resnet_paths,
|
| 15 |
+
)
|
| 16 |
+
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
    """Convert an LDM/CompVis-format VAE state dict into the diffusers ``AutoencoderKL`` layout.

    Args:
        checkpoint: Mapping from LDM VAE parameter names (``encoder.down.*``,
            ``decoder.up.*``, ``*.mid.*``, ...) to tensors. NOTE: this mapping is
            mutated in place — encoder downsampler keys are ``pop``-ed and
            ``assign_to_checkpoint`` consumes the keys it remaps.
        config: Diffusers VAE config dict, forwarded to ``assign_to_checkpoint``
            (used when reshaping attention projection weights).

    Returns:
        dict: A new state dict with diffusers-style keys, suitable for
        ``AutoencoderKL.load_state_dict``.
    """
    vae_state_dict = checkpoint

    new_checkpoint = {}

    # Top-level encoder convolutions map 1:1; only the output norm is renamed
    # (LDM "norm_out" -> diffusers "conv_norm_out").
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    # Same 1:1 mapping for the decoder side.
    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    # Quantization convolutions keep their names unchanged.
    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only.
    # Block count is inferred from the distinct "encoder.down.<i>" prefixes.
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only (same inference scheme).
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        # Resnet keys: everything in the block except downsamplers and attention.
        resnets = [
            key
            for key in down_blocks[i]
            if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key
        ]
        attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key]

        # The last down block has no downsampler, hence the existence check.
        # ``pop`` removes the keys so they cannot be matched again below.
        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        # Encoder down blocks keep their index ordering (no reversal needed).
        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

        paths = renew_vae_attention_paths(attentions)
        meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid block: LDM numbers the two resnets block_1/block_2;
    # diffusers numbers them resnets.0/resnets.1.
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    # Attention projections are stored as 1x1 convs in LDM; reshape to linear weights.
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        # LDM decoder up blocks are indexed in reverse order relative to diffusers,
        # so source block ``block_id`` maps to destination block ``i``.
        block_id = num_up_blocks - 1 - i
        resnets = [
            key
            for key in up_blocks[block_id]
            if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
        ]
        attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key]

        # Upsampler keys are read (not popped) here, unlike the encoder side;
        # they were already excluded from ``resnets`` above, so no double-mapping.
        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

        paths = renew_vae_attention_paths(attentions)
        meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid block: same block_1/block_2 -> resnets.0/resnets.1 renumbering.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    # Second pass covers the decoder attention weights added since the first call.
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def vae_pt_to_vae_diffuser(
    checkpoint_path: str,
    output_path: str,
):
    """Convert a standalone Stable Diffusion v1 VAE checkpoint to diffusers format.

    Downloads the reference ``v1-inference.yaml`` config, builds a diffusers VAE
    config from it, remaps the checkpoint keys, and saves an ``AutoencoderKL``
    to ``output_path``.

    Args:
        checkpoint_path: Path to the source VAE weights; ``.safetensors`` files
            are loaded with ``safetensors``, anything else via ``torch.load``
            (expected to hold a ``"state_dict"`` entry).
        output_path: Directory where the converted ``AutoencoderKL`` is saved.

    Raises:
        requests.HTTPError: If the reference config cannot be downloaded.
    """
    # Only support V1
    r = requests.get(
        # Bug fix: the URL previously contained a leading space inside the string.
        "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
        timeout=DIFFUSERS_REQUEST_TIMEOUT,
    )
    # Fail fast on a bad download instead of feeding an HTML error page to yaml.
    r.raise_for_status()
    io_obj = io.BytesIO(r.content)

    original_config = yaml.safe_load(io_obj)
    image_size = 512  # SD v1 native resolution; determines the VAE sample size.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if checkpoint_path.endswith("safetensors"):
        # Imported lazily so the script works without safetensors for .pt inputs.
        from safetensors import safe_open

        checkpoint = {}
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
    converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    vae.save_pretrained(output_path)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.")
    # Bug fix: the help text was copy-pasted from --vae_pt_path; this is the output path.
    parser.add_argument(
        "--dump_path", default=None, type=str, required=True, help="Path to save the converted diffusers VAE to."
    )

    args = parser.parse_args()

    vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path)
|
diffusers/scripts/convert_wuerstchen.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoTokenizer, CLIPTextModel
|
| 6 |
+
from vqgan import VQModel
|
| 7 |
+
|
| 8 |
+
from diffusers import (
|
| 9 |
+
DDPMWuerstchenScheduler,
|
| 10 |
+
WuerstchenCombinedPipeline,
|
| 11 |
+
WuerstchenDecoderPipeline,
|
| 12 |
+
WuerstchenPriorPipeline,
|
| 13 |
+
)
|
| 14 |
+
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _convert_attn_keys(orig_state_dict):
    """Split fused attention projections into diffusers to_q/to_k/to_v/to_out.0 keys.

    The original checkpoints store torch ``MultiheadAttention``-style fused
    ``in_proj_weight``/``in_proj_bias`` tensors; diffusers expects three separate
    projections plus a renamed output projection. All other keys pass through
    unchanged. Returns a new dict; the input is not modified.
    """
    state_dict = {}
    for key, value in orig_state_dict.items():
        if key.endswith("in_proj_weight"):
            q, k, v = value.chunk(3, 0)
            state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = q
            state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = k
            state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = v
        elif key.endswith("in_proj_bias"):
            q, k, v = value.chunk(3, 0)
            state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = q
            state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = k
            state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = v
        elif key.endswith("out_proj.weight"):
            state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = value
        elif key.endswith("out_proj.bias"):
            state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = value
        else:
            state_dict[key] = value
    return state_dict


model_path = "models/"
device = "cpu"

# VQGAN: load the original checkpoint first, then rename the codebook key to
# the diffusers naming before loading into PaellaVQModel.
paella_vqmodel = VQModel()
state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"]
paella_vqmodel.load_state_dict(state_dict)

state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"]
state_dict.pop("vquantizer.codebook.weight")
vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent)
vqmodel.load_state_dict(state_dict)

# Clip Text encoder and tokenizer (used by the prior / Stage C)
text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# Generator (decoder / Stage B) text encoder and tokenizer
gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu")
gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

# Decoder (Stage B): deduplicated key conversion via the helper above.
orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"]
state_dict = _convert_attn_keys(orig_state_dict)
decoder = WuerstchenDiffNeXt()
decoder.load_state_dict(state_dict)

# Prior (Stage C): note this checkpoint ships EMA weights under "ema_state_dict".
orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"]
state_dict = _convert_attn_keys(orig_state_dict)
prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device)
prior_model.load_state_dict(state_dict)

# scheduler (shared between prior and decoder pipelines)
scheduler = DDPMWuerstchenScheduler()

# Prior pipeline
prior_pipeline = WuerstchenPriorPipeline(
    prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)

prior_pipeline.save_pretrained("warp-ai/wuerstchen-prior")

decoder_pipeline = WuerstchenDecoderPipeline(
    text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=decoder, scheduler=scheduler
)
decoder_pipeline.save_pretrained("warp-ai/wuerstchen")

# Wuerstchen pipeline (combined prior + decoder)
wuerstchen_pipeline = WuerstchenCombinedPipeline(
    # Decoder
    text_encoder=gen_text_encoder,
    tokenizer=gen_tokenizer,
    decoder=decoder,
    scheduler=scheduler,
    vqgan=vqmodel,
    # Prior
    prior_tokenizer=tokenizer,
    prior_text_encoder=text_encoder,
    prior=prior_model,
    prior_scheduler=scheduler,
)
wuerstchen_pipeline.save_pretrained("warp-ai/WuerstchenCombinedPipeline")
|
illustrious_generated/low_quality_images.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
illustrious_generated/natural_caption_generation_report.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
=== Natural Caption Generation Report ===
|
| 3 |
+
|
| 4 |
+
Processing Statistics:
|
| 5 |
+
- Total images processed: 9618
|
| 6 |
+
- Successfully captioned: 9618
|
| 7 |
+
- Errors encountered: 0
|
| 8 |
+
- Success rate: 100.0%
|
| 9 |
+
|
| 10 |
+
Time Statistics:
|
| 11 |
+
- Total processing time: 533.1 minutes
|
| 12 |
+
- Average time per image: 3.33 seconds
|
| 13 |
+
|
| 14 |
+
Completion time: 2025-07-29 20:42:34
|
illustrious_generated/optimization_final_results.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
illustrious_generated/optimization_summary_report.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
=== 图像质量优化总结报告 ===
|
| 3 |
+
|
| 4 |
+
处理统计:
|
| 5 |
+
- 总图像数: 9618
|
| 6 |
+
- 检测到低质量图像: 181
|
| 7 |
+
- 重新生成处理: 100
|
| 8 |
+
- 成功改善质量: 17
|
| 9 |
+
- 改善成功率: 17.0%
|
| 10 |
+
|
| 11 |
+
质量提升:
|
| 12 |
+
- 平均质量提升: 2.3 分
|
| 13 |
+
- 改善图像保存位置: /home/ubuntu/lyl/QwenIllustrious/illustrious_generated/improved
|
| 14 |
+
|
| 15 |
+
详细结果文件:
|
| 16 |
+
- 低质量图像记录: low_quality_images.json
|
| 17 |
+
- 重新生成结果: regeneration_results.json
|
| 18 |
+
- 最终优化结果: optimization_final_results.json
|
| 19 |
+
|
| 20 |
+
优化完成时间: 2025-07-29 08:24:41
|
illustrious_generated/regeneration_results.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
peft/.gitignore
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
# Usually these files are written by a python script from a template
|
| 32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
.coverage
|
| 45 |
+
.coverage.*
|
| 46 |
+
.cache
|
| 47 |
+
nosetests.xml
|
| 48 |
+
coverage.xml
|
| 49 |
+
*.cover
|
| 50 |
+
*.py,cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
target/
|
| 76 |
+
|
| 77 |
+
# Jupyter Notebook
|
| 78 |
+
.ipynb_checkpoints
|
| 79 |
+
|
| 80 |
+
# IPython
|
| 81 |
+
profile_default/
|
| 82 |
+
ipython_config.py
|
| 83 |
+
|
| 84 |
+
# pyenv
|
| 85 |
+
.python-version
|
| 86 |
+
|
| 87 |
+
# pipenv
|
| 88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 91 |
+
# install all needed dependencies.
|
| 92 |
+
#Pipfile.lock
|
| 93 |
+
|
| 94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 95 |
+
__pypackages__/
|
| 96 |
+
|
| 97 |
+
# Celery stuff
|
| 98 |
+
celerybeat-schedule
|
| 99 |
+
celerybeat.pid
|
| 100 |
+
|
| 101 |
+
# SageMath parsed files
|
| 102 |
+
*.sage.py
|
| 103 |
+
|
| 104 |
+
# Environments
|
| 105 |
+
.env
|
| 106 |
+
.venv
|
| 107 |
+
env/
|
| 108 |
+
venv/
|
| 109 |
+
ENV/
|
| 110 |
+
env.bak/
|
| 111 |
+
venv.bak/
|
| 112 |
+
|
| 113 |
+
# Spyder project settings
|
| 114 |
+
.spyderproject
|
| 115 |
+
.spyproject
|
| 116 |
+
|
| 117 |
+
# Rope project settings
|
| 118 |
+
.ropeproject
|
| 119 |
+
|
| 120 |
+
# mkdocs documentation
|
| 121 |
+
/site
|
| 122 |
+
|
| 123 |
+
# mypy
|
| 124 |
+
.mypy_cache/
|
| 125 |
+
.dmypy.json
|
| 126 |
+
dmypy.json
|
| 127 |
+
|
| 128 |
+
# Pyre type checker
|
| 129 |
+
.pyre/
|
| 130 |
+
|
| 131 |
+
# VSCode
|
| 132 |
+
.vscode
|
| 133 |
+
|
| 134 |
+
# IntelliJ
|
| 135 |
+
.idea
|
| 136 |
+
|
| 137 |
+
# Mac .DS_Store
|
| 138 |
+
.DS_Store
|
| 139 |
+
|
| 140 |
+
# More test things
|
| 141 |
+
wandb
|
| 142 |
+
|
| 143 |
+
# method_comparison logs
|
| 144 |
+
method_comparison/MetaMathQA/cancelled_results/
|
| 145 |
+
method_comparison/MetaMathQA/temporary_results/
|
peft/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 3 |
+
rev: v0.9.2
|
| 4 |
+
hooks:
|
| 5 |
+
- id: ruff
|
| 6 |
+
args:
|
| 7 |
+
- --fix
|
| 8 |
+
- id: ruff-format
|
| 9 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 10 |
+
rev: v4.6.0
|
| 11 |
+
hooks:
|
| 12 |
+
- id: check-merge-conflict
|
| 13 |
+
- id: check-yaml
|
peft/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
peft/Makefile
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: quality style test docs
|
| 2 |
+
|
| 3 |
+
check_dirs := src tests examples docs scripts docker
|
| 4 |
+
|
| 5 |
+
# Check that source code meets quality standards
|
| 6 |
+
|
| 7 |
+
# this target runs checks on all files
|
| 8 |
+
quality:
|
| 9 |
+
ruff check $(check_dirs)
|
| 10 |
+
ruff format --check $(check_dirs)
|
| 11 |
+
doc-builder style src/peft tests docs/source --max_len 119 --check_only
|
| 12 |
+
|
| 13 |
+
# Format source code automatically and check is there are any problems left that need manual fixing
|
| 14 |
+
style:
|
| 15 |
+
ruff check --fix $(check_dirs)
|
| 16 |
+
ruff format $(check_dirs)
|
| 17 |
+
doc-builder style src/peft tests docs/source --max_len 119
|
| 18 |
+
|
| 19 |
+
test:
|
| 20 |
+
python -m pytest -n 3 tests/ $(if $(IS_GITHUB_CI),--report-log "ci_tests.log",)
|
| 21 |
+
|
| 22 |
+
tests_examples_multi_gpu:
|
| 23 |
+
python -m pytest -m multi_gpu_tests tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "multi_gpu_examples.log",)
|
| 24 |
+
|
| 25 |
+
tests_examples_single_gpu:
|
| 26 |
+
python -m pytest -m single_gpu_tests tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "single_gpu_examples.log",)
|
| 27 |
+
|
| 28 |
+
tests_core_multi_gpu:
|
| 29 |
+
python -m pytest -m multi_gpu_tests tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_multi_gpu.log",)
|
| 30 |
+
|
| 31 |
+
tests_core_single_gpu:
|
| 32 |
+
python -m pytest -m single_gpu_tests tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_single_gpu.log",)
|
| 33 |
+
|
| 34 |
+
# exclude gemma tests, as generation fails with torch.compile, these failures
|
| 35 |
+
# trigger side effects that make other tests fail with 'RuntimeError: Offset
|
| 36 |
+
# increment outside graph capture encountered unexpectedly.'
|
| 37 |
+
# TODO re-enable gemma once/if it is fixed
|
| 38 |
+
tests_common_gpu:
|
| 39 |
+
python -m pytest tests/test_decoder_models.py -k "not gemma" $(if $(IS_GITHUB_CI),--report-log "common_decoder.log",)
|
| 40 |
+
python -m pytest tests/test_encoder_decoder_models.py $(if $(IS_GITHUB_CI),--report-log "common_encoder_decoder.log",)
|
| 41 |
+
python -m pytest tests/test_gptqmodel.py $(if $(IS_GITHUB_CI),--report-log "gptqmodel_gpu.log",)
|
| 42 |
+
|
| 43 |
+
tests_examples_multi_gpu_bnb:
|
| 44 |
+
python -m pytest -m "multi_gpu_tests and bitsandbytes" tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "multi_gpu_examples.log",)
|
| 45 |
+
|
| 46 |
+
tests_examples_single_gpu_bnb:
|
| 47 |
+
python -m pytest -m "single_gpu_tests and bitsandbytes" tests/test_gpu_examples.py $(if $(IS_GITHUB_CI),--report-log "single_gpu_examples.log",)
|
| 48 |
+
|
| 49 |
+
tests_core_multi_gpu_bnb:
|
| 50 |
+
python -m pytest -m "multi_gpu_tests and bitsandbytes" tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_multi_gpu.log",)
|
| 51 |
+
|
| 52 |
+
tests_core_single_gpu_bnb:
|
| 53 |
+
python -m pytest -m "single_gpu_tests and bitsandbytes" tests/test_common_gpu.py $(if $(IS_GITHUB_CI),--report-log "core_single_gpu.log",)
|
| 54 |
+
|
| 55 |
+
tests_gpu_bnb_regression:
|
| 56 |
+
python -m pytest tests/bnb/test_bnb_regression.py $(if $(IS_GITHUB_CI),--report-log "bnb_regression_gpu.log",)
|
| 57 |
+
|
| 58 |
+
# For testing transformers tests for bnb runners
|
| 59 |
+
transformers_tests:
|
| 60 |
+
RUN_SLOW=1 python -m pytest transformers-clone/tests/quantization/bnb $(if $(IS_GITHUB_CI),--report-log "transformers_tests.log",)
|
| 61 |
+
|
| 62 |
+
tests_regression:
|
| 63 |
+
python -m pytest -s --regression tests/regression/ $(if $(IS_GITHUB_CI),--report-log "regression_tests.log",)
|
| 64 |
+
|
| 65 |
+
tests_torch_compile:
|
| 66 |
+
python -m pytest tests/test_torch_compile.py $(if $(IS_GITHUB_CI),--report-log "compile_tests.log",)
|
peft/README.md
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!---
|
| 2 |
+
Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
you may not use this file except in compliance with the License.
|
| 6 |
+
You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
See the License for the specific language governing permissions and
|
| 14 |
+
limitations under the License.
|
| 15 |
+
-->
|
| 16 |
+
|
| 17 |
+
<h1 align="center"> <p>🤗 PEFT</p></h1>
|
| 18 |
+
<h3 align="center">
|
| 19 |
+
<p>State-of-the-art Parameter-Efficient Fine-Tuning (PEFT) methods</p>
|
| 20 |
+
</h3>
|
| 21 |
+
|
| 22 |
+
Fine-tuning large pretrained models is often prohibitively costly due to their scale. Parameter-Efficient Fine-Tuning (PEFT) methods enable efficient adaptation of large pretrained models to various downstream applications by only fine-tuning a small number of (extra) model parameters instead of all the model's parameters. This significantly decreases the computational and storage costs. Recent state-of-the-art PEFT techniques achieve performance comparable to fully fine-tuned models.
|
| 23 |
+
|
| 24 |
+
PEFT is integrated with Transformers for easy model training and inference, Diffusers for conveniently managing different adapters, and Accelerate for distributed training and inference for really big models.
|
| 25 |
+
|
| 26 |
+
> [!TIP]
|
| 27 |
+
> Visit the [PEFT](https://huggingface.co/PEFT) organization to read about the PEFT methods implemented in the library and to see notebooks demonstrating how to apply these methods to a variety of downstream tasks. Click the "Watch repos" button on the organization page to be notified of newly implemented methods and notebooks!
|
| 28 |
+
|
| 29 |
+
Check the PEFT Adapters API Reference section for a list of supported PEFT methods, and read the [Adapters](https://huggingface.co/docs/peft/en/conceptual_guides/adapter), [Soft prompts](https://huggingface.co/docs/peft/en/conceptual_guides/prompting), and [IA3](https://huggingface.co/docs/peft/en/conceptual_guides/ia3) conceptual guides to learn more about how these methods work.
|
| 30 |
+
|
| 31 |
+
## Quickstart
|
| 32 |
+
|
| 33 |
+
Install PEFT from pip:
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
pip install peft
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
Prepare a model for training with a PEFT method such as LoRA by wrapping the base model and PEFT configuration with `get_peft_model`. For the bigscience/mt0-large model, you're only training 0.19% of the parameters!
|
| 40 |
+
|
| 41 |
+
```python
|
| 42 |
+
from transformers import AutoModelForCausalLM
|
| 43 |
+
from peft import LoraConfig, TaskType, get_peft_model
|
| 44 |
+
|
| 45 |
+
device = "cuda"
|
| 46 |
+
model_id = "Qwen/Qwen2.5-3B-Instruct"
|
| 47 |
+
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
|
| 48 |
+
peft_config = LoraConfig(
|
| 49 |
+
r=16,
|
| 50 |
+
lora_alpha=32,
|
| 51 |
+
task_type=TaskType.CAUSAL_LM,
|
| 52 |
+
# target_modules=["q_proj", "v_proj", ...] # optionally indicate target modules
|
| 53 |
+
)
|
| 54 |
+
model = get_peft_model(model, peft_config)
|
| 55 |
+
model.print_trainable_parameters()
|
| 56 |
+
# prints: trainable params: 3,686,400 || all params: 3,089,625,088 || trainable%: 0.1193
|
| 57 |
+
|
| 58 |
+
# now perform training on your dataset, e.g. using transformers Trainer, then save the model
|
| 59 |
+
model.save_pretrained("qwen2.5-3b-lora")
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
To load a PEFT model for inference:
|
| 63 |
+
|
| 64 |
+
```python
|
| 65 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 66 |
+
from peft import PeftModel
|
| 67 |
+
|
| 68 |
+
device = "cuda"
|
| 69 |
+
model_id = "Qwen/Qwen2.5-3B-Instruct"
|
| 70 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 71 |
+
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
|
| 72 |
+
model = PeftModel.from_pretrained(model, "qwen2.5-3b-lora")
|
| 73 |
+
|
| 74 |
+
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors="pt")
|
| 75 |
+
outputs = model.generate(**inputs.to(device), max_new_tokens=50)
|
| 76 |
+
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 77 |
+
|
| 78 |
+
# prints something like: Preheat the oven to 350 degrees and place the cookie dough in a baking dish [...]
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
## Why you should use PEFT
|
| 82 |
+
|
| 83 |
+
There are many benefits of using PEFT but the main one is the huge savings in compute and storage, making PEFT applicable to many different use cases.
|
| 84 |
+
|
| 85 |
+
### High performance on consumer hardware
|
| 86 |
+
|
| 87 |
+
Consider the memory requirements for training the following models on the [ought/raft/twitter_complaints](https://huggingface.co/datasets/ought/raft/viewer/twitter_complaints) dataset with an A100 80GB GPU with more than 64GB of CPU RAM.
|
| 88 |
+
|
| 89 |
+
| Model | Full Finetuning | PEFT-LoRA PyTorch | PEFT-LoRA DeepSpeed with CPU Offloading |
|
| 90 |
+
| --------- | ---- | ---- | ---- |
|
| 91 |
+
| bigscience/T0_3B (3B params) | 47.14GB GPU / 2.96GB CPU | 14.4GB GPU / 2.96GB CPU | 9.8GB GPU / 17.8GB CPU |
|
| 92 |
+
| bigscience/mt0-xxl (12B params) | OOM GPU | 56GB GPU / 3GB CPU | 22GB GPU / 52GB CPU |
|
| 93 |
+
| bigscience/bloomz-7b1 (7B params) | OOM GPU | 32GB GPU / 3.8GB CPU | 18.1GB GPU / 35GB CPU |
|
| 94 |
+
|
| 95 |
+
With LoRA you can fully finetune a 12B parameter model that would've otherwise run out of memory on the 80GB GPU, and comfortably fit and train a 3B parameter model. When you look at the 3B parameter model's performance, it is comparable to a fully finetuned model at a fraction of the GPU memory.
|
| 96 |
+
|
| 97 |
+
| Submission Name | Accuracy |
|
| 98 |
+
| --------- | ---- |
|
| 99 |
+
| Human baseline (crowdsourced) | 0.897 |
|
| 100 |
+
| Flan-T5 | 0.892 |
|
| 101 |
+
| lora-t0-3b | 0.863 |
|
| 102 |
+
|
| 103 |
+
> [!TIP]
|
| 104 |
+
> The bigscience/T0_3B model performance isn't optimized in the table above. You can squeeze even more performance out of it by playing around with the input instruction templates, LoRA hyperparameters, and other training related hyperparameters. The final checkpoint size of this model is just 19MB compared to 11GB of the full bigscience/T0_3B model. Learn more about the advantages of finetuning with PEFT in this [blog post](https://www.philschmid.de/fine-tune-flan-t5-peft).
|
| 105 |
+
|
| 106 |
+
### Quantization
|
| 107 |
+
|
| 108 |
+
Quantization is another method for reducing the memory requirements of a model by representing the data in a lower precision. It can be combined with PEFT methods to make it even easier to train and load LLMs for inference.
|
| 109 |
+
|
| 110 |
+
* Learn how to finetune [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) with QLoRA and the [TRL](https://huggingface.co/docs/trl/index) library on a 16GB GPU in the [Finetune LLMs on your own consumer hardware using tools from PyTorch and Hugging Face ecosystem](https://pytorch.org/blog/finetune-llms/) blog post.
|
| 111 |
+
* Learn how to finetune a [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2) model for multilingual automatic speech recognition with LoRA and 8-bit quantization in this [notebook](https://colab.research.google.com/drive/1DOkD_5OUjFa0r5Ik3SgywJLJtEo2qLxO?usp=sharing) (see this [notebook](https://colab.research.google.com/drive/1vhF8yueFqha3Y3CpTHN6q9EVcII9EYzs?usp=sharing) instead for an example of streaming a dataset).
|
| 112 |
+
|
| 113 |
+
### Save compute and storage
|
| 114 |
+
|
| 115 |
+
PEFT can help you save storage by avoiding full finetuning of models on each of downstream task or dataset. In many cases, you're only finetuning a very small fraction of a model's parameters and each checkpoint is only a few MBs in size (instead of GBs). These smaller PEFT adapters demonstrate performance comparable to a fully finetuned model. If you have many datasets, you can save a lot of storage with a PEFT model and not have to worry about catastrophic forgetting or overfitting the backbone or base model.
|
| 116 |
+
|
| 117 |
+
## PEFT integrations
|
| 118 |
+
|
| 119 |
+
PEFT is widely supported across the Hugging Face ecosystem because of the massive efficiency it brings to training and inference.
|
| 120 |
+
|
| 121 |
+
### Diffusers
|
| 122 |
+
|
| 123 |
+
The iterative diffusion process consumes a lot of memory which can make it difficult to train. PEFT can help reduce the memory requirements and reduce the storage size of the final model checkpoint. For example, consider the memory required for training a Stable Diffusion model with LoRA on an A100 80GB GPU with more than 64GB of CPU RAM. The final model checkpoint size is only 8.8MB!
|
| 124 |
+
|
| 125 |
+
| Model | Full Finetuning | PEFT-LoRA | PEFT-LoRA with Gradient Checkpointing |
|
| 126 |
+
| --------- | ---- | ---- | ---- |
|
| 127 |
+
| CompVis/stable-diffusion-v1-4 | 27.5GB GPU / 3.97GB CPU | 15.5GB GPU / 3.84GB CPU | 8.12GB GPU / 3.77GB CPU |
|
| 128 |
+
|
| 129 |
+
> [!TIP]
|
| 130 |
+
> Take a look at the [examples/lora_dreambooth/train_dreambooth.py](examples/lora_dreambooth/train_dreambooth.py) training script to try training your own Stable Diffusion model with LoRA, and play around with the [smangrul/peft-lora-sd-dreambooth](https://huggingface.co/spaces/smangrul/peft-lora-sd-dreambooth) Space which is running on a T4 instance. Learn more about the PEFT integration in Diffusers in this [tutorial](https://huggingface.co/docs/peft/main/en/tutorial/peft_integrations#diffusers).
|
| 131 |
+
|
| 132 |
+
### Transformers
|
| 133 |
+
|
| 134 |
+
PEFT is directly integrated with [Transformers](https://huggingface.co/docs/transformers/main/en/peft). After loading a model, call `add_adapter` to add a new PEFT adapter to the model:
|
| 135 |
+
|
| 136 |
+
```python
|
| 137 |
+
from peft import LoraConfig
|
| 138 |
+
model = ... # transformers model
|
| 139 |
+
peft_config = LoraConfig(...)
|
| 140 |
+
model.add_adapter(lora_config, adapter_name="lora_1")
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
To load a trained PEFT adapter, call `load_adapter`:
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
model = ... # transformers model
|
| 147 |
+
model.load_adapter(<path-to-adapter>, adapter_name="lora_1")
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
And to switch between different adapters, call `set_adapter`:
|
| 151 |
+
|
| 152 |
+
```python
|
| 153 |
+
model.set_adapter("lora_2")
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
The Transformers integration doesn't include all the functionalities offered in PEFT, such as methods for merging the adapter into the base model.
|
| 157 |
+
|
| 158 |
+
### Accelerate
|
| 159 |
+
|
| 160 |
+
[Accelerate](https://huggingface.co/docs/accelerate/index) is a library for distributed training and inference on various training setups and hardware (GPUs, TPUs, Apple Silicon, etc.). PEFT models work with Accelerate out of the box, making it really convenient to train really large models or use them for inference on consumer hardware with limited resources.
|
| 161 |
+
|
| 162 |
+
### TRL
|
| 163 |
+
|
| 164 |
+
PEFT can also be applied to training LLMs with RLHF components such as the ranker and policy. Get started by reading:
|
| 165 |
+
|
| 166 |
+
* [Fine-tune a Mistral-7b model with Direct Preference Optimization](https://towardsdatascience.com/fine-tune-a-mistral-7b-model-with-direct-preference-optimization-708042745aac) with PEFT and the [TRL](https://huggingface.co/docs/trl/index) library to learn more about the Direct Preference Optimization (DPO) method and how to apply it to a LLM.
|
| 167 |
+
* [Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU](https://huggingface.co/blog/trl-peft) with PEFT and the [TRL](https://huggingface.co/docs/trl/index) library, and then try out the [gpt2-sentiment_peft.ipynb](https://github.com/huggingface/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) notebook to optimize GPT2 to generate positive movie reviews.
|
| 168 |
+
* [StackLLaMA: A hands-on guide to train LLaMA with RLHF](https://huggingface.co/blog/stackllama) with PEFT, and then try out the [stack_llama/scripts](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama/scripts) for supervised finetuning, reward modeling, and RL finetuning.
|
| 169 |
+
|
| 170 |
+
## Model support
|
| 171 |
+
|
| 172 |
+
Use this [Space](https://stevhliu-peft-methods.hf.space) or check out the [docs](https://huggingface.co/docs/peft/main/en/index) to find which models officially support a PEFT method out of the box. Even if you don't see a model listed below, you can manually configure the model config to enable PEFT for a model. Read the [New transformers architecture](https://huggingface.co/docs/peft/main/en/developer_guides/custom_models#new-transformers-architectures) guide to learn how.
|
| 173 |
+
|
| 174 |
+
## Contribute
|
| 175 |
+
|
| 176 |
+
If you would like to contribute to PEFT, please check out our [contribution guide](https://huggingface.co/docs/peft/developer_guides/contributing).
|
| 177 |
+
|
| 178 |
+
## Citing 🤗 PEFT
|
| 179 |
+
|
| 180 |
+
To use 🤗 PEFT in your publication, please cite it by using the following BibTeX entry.
|
| 181 |
+
|
| 182 |
+
```bibtex
|
| 183 |
+
@Misc{peft,
|
| 184 |
+
title = {PEFT: State-of-the-art Parameter-Efficient Fine-Tuning methods},
|
| 185 |
+
author = {Sourab Mangrulkar and Sylvain Gugger and Lysandre Debut and Younes Belkada and Sayak Paul and Benjamin Bossan},
|
| 186 |
+
howpublished = {\url{https://github.com/huggingface/peft}},
|
| 187 |
+
year = {2022}
|
| 188 |
+
}
|
| 189 |
+
```
|
peft/pyproject.toml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.black]
|
| 2 |
+
# Only used by `hf-doc-builder´.
|
| 3 |
+
line-length = 119
|
| 4 |
+
target-version = ['py38']
|
| 5 |
+
|
| 6 |
+
[tool.ruff]
|
| 7 |
+
target-version = "py39"
|
| 8 |
+
line-length = 119
|
| 9 |
+
extend-exclude = ["*.ipynb"]
|
| 10 |
+
|
| 11 |
+
[tool.ruff.lint]
|
| 12 |
+
preview = true
|
| 13 |
+
explicit-preview-rules = true
|
| 14 |
+
extend-select = [
|
| 15 |
+
"C", # Complexity
|
| 16 |
+
"E", # PEP8 errors
|
| 17 |
+
"F", # PEP8 formatting
|
| 18 |
+
"I", # Import sorting
|
| 19 |
+
"UP", # Pyupgrade upgrades
|
| 20 |
+
"W", # PEP8 warnings
|
| 21 |
+
"PT009", # Pytest assertions
|
| 22 |
+
"RUF022", # Sorting of __all__
|
| 23 |
+
]
|
| 24 |
+
ignore = [
|
| 25 |
+
"C901", # Function too complex
|
| 26 |
+
"E501", # Line length (handled by ruff-format)
|
| 27 |
+
"F841", # unused variable
|
| 28 |
+
"UP007", # X | Y style Unions
|
| 29 |
+
"C420", # dict.fromkeys
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
[tool.ruff.lint.isort]
|
| 33 |
+
lines-after-imports = 2
|
| 34 |
+
known-first-party = ["peft"]
|
| 35 |
+
|
| 36 |
+
[tool.pytest]
|
| 37 |
+
doctest_optionflags = [
|
| 38 |
+
"NORMALIZE_WHITESPACE",
|
| 39 |
+
"ELLIPSIS",
|
| 40 |
+
"NUMBER",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
[tool.pytest.ini_options]
|
| 44 |
+
addopts = "--cov=src/peft --cov-report=term-missing --durations=10"
|
| 45 |
+
markers = [
|
| 46 |
+
"single_gpu_tests: tests that run on a single GPU",
|
| 47 |
+
"multi_gpu_tests: tests that run on multiple GPUs",
|
| 48 |
+
"regression: whether to run regression suite test",
|
| 49 |
+
"bitsandbytes: select bitsandbytes integration tests"
|
| 50 |
+
]
|
peft/requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate
|
| 2 |
+
torch
|
| 3 |
+
safetensors
|
| 4 |
+
bitsandbytes
|
| 5 |
+
scipy
|
| 6 |
+
peft
|
| 7 |
+
transformers
|
| 8 |
+
tqdm
|
| 9 |
+
packaging
|
| 10 |
+
pytest
|
| 11 |
+
numpy
|
| 12 |
+
pyyaml
|
| 13 |
+
datasets
|
| 14 |
+
psutil
|
| 15 |
+
setuptools
|
peft/setup.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from setuptools import find_packages, setup
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
VERSION = "0.16.1.dev0"
|
| 19 |
+
|
| 20 |
+
extras = {}
|
| 21 |
+
extras["quality"] = [
|
| 22 |
+
"black", # doc-builder has an implicit dependency on Black, see huggingface/doc-builder#434
|
| 23 |
+
"hf-doc-builder",
|
| 24 |
+
"ruff~=0.9.2",
|
| 25 |
+
]
|
| 26 |
+
extras["docs_specific"] = [
|
| 27 |
+
"black", # doc-builder has an implicit dependency on Black, see huggingface/doc-builder#434
|
| 28 |
+
"hf-doc-builder",
|
| 29 |
+
]
|
| 30 |
+
extras["dev"] = extras["quality"] + extras["docs_specific"]
|
| 31 |
+
extras["test"] = extras["dev"] + [
|
| 32 |
+
"pytest",
|
| 33 |
+
"pytest-cov",
|
| 34 |
+
"pytest-xdist",
|
| 35 |
+
"parameterized",
|
| 36 |
+
"datasets",
|
| 37 |
+
"diffusers",
|
| 38 |
+
"scipy",
|
| 39 |
+
"protobuf",
|
| 40 |
+
"sentencepiece",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
setup(
|
| 44 |
+
name="peft",
|
| 45 |
+
version=VERSION,
|
| 46 |
+
description="Parameter-Efficient Fine-Tuning (PEFT)",
|
| 47 |
+
license_files=["LICENSE"],
|
| 48 |
+
long_description=open("README.md", encoding="utf-8").read(),
|
| 49 |
+
long_description_content_type="text/markdown",
|
| 50 |
+
keywords="deep learning",
|
| 51 |
+
license="Apache",
|
| 52 |
+
author="The HuggingFace team",
|
| 53 |
+
author_email="benjamin@huggingface.co",
|
| 54 |
+
url="https://github.com/huggingface/peft",
|
| 55 |
+
package_dir={"": "src"},
|
| 56 |
+
packages=find_packages("src"),
|
| 57 |
+
package_data={"peft": ["py.typed", "tuners/boft/fbd/fbd_cuda.cpp", "tuners/boft/fbd/fbd_cuda_kernel.cu"]},
|
| 58 |
+
entry_points={},
|
| 59 |
+
python_requires=">=3.9.0",
|
| 60 |
+
install_requires=[
|
| 61 |
+
"numpy>=1.17",
|
| 62 |
+
"packaging>=20.0",
|
| 63 |
+
"psutil",
|
| 64 |
+
"pyyaml",
|
| 65 |
+
"torch>=1.13.0",
|
| 66 |
+
"transformers",
|
| 67 |
+
"tqdm",
|
| 68 |
+
"accelerate>=0.21.0",
|
| 69 |
+
"safetensors",
|
| 70 |
+
"huggingface_hub>=0.25.0",
|
| 71 |
+
],
|
| 72 |
+
extras_require=extras,
|
| 73 |
+
classifiers=[
|
| 74 |
+
"Development Status :: 5 - Production/Stable",
|
| 75 |
+
"Intended Audience :: Developers",
|
| 76 |
+
"Intended Audience :: Education",
|
| 77 |
+
"Intended Audience :: Science/Research",
|
| 78 |
+
"License :: OSI Approved :: Apache Software License",
|
| 79 |
+
"Operating System :: OS Independent",
|
| 80 |
+
"Programming Language :: Python :: 3",
|
| 81 |
+
"Programming Language :: Python :: 3.9",
|
| 82 |
+
"Programming Language :: Python :: 3.10",
|
| 83 |
+
"Programming Language :: Python :: 3.11",
|
| 84 |
+
"Programming Language :: Python :: 3.12",
|
| 85 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 86 |
+
],
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Release checklist
|
| 90 |
+
# 1. Change the version in __init__.py and setup.py to the release version, e.g. from "0.6.1.dev0" to "0.7.0"
|
| 91 |
+
# 2. Check if there are any deprecations that need to be addressed for this release by searching for "# TODO" in the code
|
| 92 |
+
# 3. Commit these changes with the message: "Release: VERSION", create a PR and merge it.
|
| 93 |
+
# 4. Add a tag in git to mark the release: "git tag -a v<VERSION> -m 'Adds tag <VERSION> for pypi' "
|
| 94 |
+
# Push the tag to git:
|
| 95 |
+
# git push --tags origin main
|
| 96 |
+
# It is necessary to work on the original repository, not on a fork.
|
| 97 |
+
# 5. Run the following commands in the top-level directory:
|
| 98 |
+
# python setup.py bdist_wheel
|
| 99 |
+
# python setup.py sdist
|
| 100 |
+
# Ensure that you are on the clean and up-to-date main branch (git status --untracked-files=no should not list any
|
| 101 |
+
# files and show the main branch)
|
| 102 |
+
# 6. Upload the package to the pypi test server first:
|
| 103 |
+
# twine upload dist/* -r pypitest
|
| 104 |
+
# 7. Check that you can install it in a virtualenv by running:
|
| 105 |
+
# pip install -i https://testpypi.python.org/pypi --extra-index-url https://pypi.org/simple peft
|
| 106 |
+
# 8. Upload the final version to actual pypi:
|
| 107 |
+
# twine upload dist/* -r pypi
|
| 108 |
+
# 9. Add release notes to the tag on https://github.com/huggingface/peft/releases once everything is looking hunky-dory.
|
| 109 |
+
# Check the notes here: https://docs.google.com/document/d/1k-sOIfykuKjWcOIALqjhFKz4amFEp-myeJUJEzNgjoU/edit?usp=sharing
|
| 110 |
+
# 10. Update the version in __init__.py, setup.py to the bumped patch version + ".dev0" (e.g. from "0.7.0" to "0.7.1.dev0")
|
sentence-transformers/.gitignore
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Distribution / packaging
|
| 2 |
+
.Python
|
| 3 |
+
build/
|
| 4 |
+
develop-eggs/
|
| 5 |
+
dist/
|
| 6 |
+
downloads/
|
| 7 |
+
eggs/
|
| 8 |
+
.eggs/
|
| 9 |
+
lib/
|
| 10 |
+
lib64/
|
| 11 |
+
parts/
|
| 12 |
+
sdist/
|
| 13 |
+
var/
|
| 14 |
+
wheels/
|
| 15 |
+
share/python-wheels/
|
| 16 |
+
*.egg-info/
|
| 17 |
+
.installed.cfg
|
| 18 |
+
*.egg
|
| 19 |
+
MANIFEST
|
| 20 |
+
|
| 21 |
+
# Docs
|
| 22 |
+
/docs/_build/
|
| 23 |
+
/docs/make.bat
|
| 24 |
+
|
| 25 |
+
# Editors
|
| 26 |
+
.idea
|
| 27 |
+
.vscode
|
| 28 |
+
|
| 29 |
+
# Coverage
|
| 30 |
+
htmlcov
|
| 31 |
+
.coverage*
|
| 32 |
+
coverage.xml
|
| 33 |
+
|
| 34 |
+
# Examples
|
| 35 |
+
/examples/**/output/*
|
| 36 |
+
/examples/datasets/
|
| 37 |
+
/examples/embeddings/
|
| 38 |
+
/examples/sentence_transformer/training/quora_duplicate_questions/quora-IR-dataset/
|
| 39 |
+
examples/datasets/*/
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Specific files and folders
|
| 43 |
+
/pretrained-models/
|
| 44 |
+
/cheatsheet.txt
|
| 45 |
+
/testsuite.txt
|
| 46 |
+
/TODO.txt
|
| 47 |
+
|
| 48 |
+
# Virtual environments
|
| 49 |
+
.env
|
| 50 |
+
.venv
|
| 51 |
+
env/
|
| 52 |
+
venv/
|
| 53 |
+
|
| 54 |
+
# Database
|
| 55 |
+
/qdrant_storage
|
| 56 |
+
/elastic-start-local
|
| 57 |
+
|
| 58 |
+
# Others
|
| 59 |
+
*.pyc
|
| 60 |
+
*.gz
|
| 61 |
+
*.tsv
|
| 62 |
+
tmp_*.py
|
| 63 |
+
nr_*/
|
| 64 |
+
wandb
|
| 65 |
+
checkpoints
|
| 66 |
+
tmp
|
| 67 |
+
.DS_Store
|
| 68 |
+
/runs
|
| 69 |
+
/tmp_trainer/
|