|
|
import torch |
|
|
import os |
|
|
from diffusers import UNet2DConditionModel |
|
|
from diffusers.configuration_utils import FrozenDict |
|
|
from typing import Dict |
|
|
|
|
|
|
|
|
UNET_PATH = './unet_old' |
|
|
NEW_UNET_PATH = './unet' |
|
|
|
|
|
|
|
|
|
|
|
def migrate_unet_weights(old_unet: UNet2DConditionModel, new_unet: UNet2DConditionModel) -> UNet2DConditionModel: |
|
|
"""Копирует веса из старой UNet в новую, игнорируя новые слои.""" |
|
|
old_state_dict = old_unet.state_dict() |
|
|
new_state_dict = new_unet.state_dict() |
|
|
|
|
|
missing_keys = [] |
|
|
|
|
|
|
|
|
for name, param in new_state_dict.items(): |
|
|
if name in old_state_dict: |
|
|
if param.shape == old_state_dict[name].shape: |
|
|
param.data.copy_(old_state_dict[name].data) |
|
|
else: |
|
|
print(f"⚠️ Пропуск ключа {name}: не совпадают формы ({old_state_dict[name].shape} vs {param.shape})") |
|
|
else: |
|
|
missing_keys.append(name) |
|
|
|
|
|
print(f"✅ Успешно перенесено {len(new_state_dict) - len(missing_keys)} весов.") |
|
|
print("\n--- Новые слои (случайные веса, требуют дообучения) ---") |
|
|
for key in missing_keys: |
|
|
print(f"🆕 {key}") |
|
|
print("----------------------------------------------------------") |
|
|
|
|
|
return new_unet |
|
|
|
|
|
|
|
|
|
|
|
print(f"1. Загрузка UNet из {UNET_PATH}...") |
|
|
try: |
|
|
|
|
|
old_unet = UNet2DConditionModel.from_pretrained(UNET_PATH, torch_dtype=torch.float32) |
|
|
old_config: Dict = old_unet.config |
|
|
print(" -> Исходная UNet успешно загружена.") |
|
|
except Exception as e: |
|
|
print(f"🛑 Ошибка при загрузке UNet: {e}") |
|
|
exit() |
|
|
|
|
|
|
|
|
print("2. Модификация конфигурации...") |
|
|
|
|
|
|
|
|
new_config = dict(old_config) |
|
|
new_config.update({ |
|
|
|
|
|
"addition_embed_type": "text", |
|
|
|
|
|
|
|
|
"addition_time_embed_dim": 1024, |
|
|
|
|
|
|
|
|
"encoder_hid_dim_type": "text_proj", |
|
|
|
|
|
|
|
|
"encoder_hid_dim": 1024, |
|
|
|
|
|
|
|
|
|
|
|
"projection_class_embeddings_input_dim": 1024, |
|
|
|
|
|
|
|
|
"time_embedding_dim": 1024, |
|
|
|
|
|
|
|
|
"addition_embed_type_num_heads": 64, |
|
|
}) |
|
|
|
|
|
|
|
|
print("3. Инициализация новой UNet с измененной архитектурой...") |
|
|
new_config_frozen = FrozenDict(new_config) |
|
|
|
|
|
new_unet = UNet2DConditionModel.from_config(new_config_frozen) |
|
|
print(" -> Новая UNet инициализирована (новые слои имеют случайные веса).") |
|
|
|
|
|
|
|
|
|
|
|
print("4. Выполнение миграции весов...") |
|
|
migrated_unet = migrate_unet_weights(old_unet, new_unet) |
|
|
|
|
|
|
|
|
|
|
|
print(f"5. Сохранение новой UNet в папку {NEW_UNET_PATH}...") |
|
|
|
|
|
|
|
|
os.makedirs(NEW_UNET_PATH, exist_ok=True) |
|
|
|
|
|
|
|
|
migrated_unet.save_pretrained(NEW_UNET_PATH) |
|
|
|
|
|
print("🎉 Готово! Новая UNet готова к использованию и дообучению.") |
|
|
|
|
|
print(f"\nСледующий шаг: Замените путь к UNet в вашем SdxsPipeline на '{NEW_UNET_PATH}' и запустите инференс, чтобы убедиться, что она принимает `added_cond_kwargs` без ошибок.") |