sdxs / src /pooling.py
recoilme's picture
2511
f08d8ce
import torch
import os
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import FrozenDict
from typing import Dict
# --- КОНСТАНТЫ И ПУТИ ---
UNET_PATH = './unet_old'
NEW_UNET_PATH = './unet'
# --- ШАГ 1: Определение функции миграции весов ---
def migrate_unet_weights(old_unet: UNet2DConditionModel, new_unet: UNet2DConditionModel) -> UNet2DConditionModel:
"""Копирует веса из старой UNet в новую, игнорируя новые слои."""
old_state_dict = old_unet.state_dict()
new_state_dict = new_unet.state_dict()
missing_keys = []
# 1. Копируем существующие веса
for name, param in new_state_dict.items():
if name in old_state_dict:
if param.shape == old_state_dict[name].shape:
param.data.copy_(old_state_dict[name].data)
else:
print(f"⚠️ Пропуск ключа {name}: не совпадают формы ({old_state_dict[name].shape} vs {param.shape})")
else:
missing_keys.append(name)
print(f"✅ Успешно перенесено {len(new_state_dict) - len(missing_keys)} весов.")
print("\n--- Новые слои (случайные веса, требуют дообучения) ---")
for key in missing_keys:
print(f"🆕 {key}")
print("----------------------------------------------------------")
return new_unet
# --- ШАГ 2: Загрузка и модификация конфигурации ---
print(f"1. Загрузка UNet из {UNET_PATH}...")
try:
# Загружаем вашу исходную UNet
old_unet = UNet2DConditionModel.from_pretrained(UNET_PATH, torch_dtype=torch.float32)
old_config: Dict = old_unet.config
print(" -> Исходная UNet успешно загружена.")
except Exception as e:
print(f"🛑 Ошибка при загрузке UNet: {e}")
exit()
# 2. Модификация конфигурации для добавления пулинг-слоя (SDXL-подобный стиль)
print("2. Модификация конфигурации...")
# Ключевые изменения для включения added_cond_kwargs
new_config = dict(old_config)
new_config.update({
# Активируем дополнительное встраивание для текста и времени
"addition_embed_type": "text",
# Размерность текстового пулинга (1024D для Qwen-0.6B)
"addition_time_embed_dim": 1024,
# Тип проекции для hid-эмбеддингов
"encoder_hid_dim_type": "text_proj",
# ПЕРВОЕ ИСПРАВЛЕНИЕ (из предыдущего шага)
"encoder_hid_dim": 1024,
# ВТОРОЕ ИСПРАВЛЕНИЕ: Определение входной размерности для add_embedding
# Она должна соответствовать размерности пулинг-слоя, который мы передаем.
"projection_class_embeddings_input_dim": 1024,
# Размерность, в которую будут проецироваться пулинг-эмбеддинги (совместно с временем)
"time_embedding_dim": 1024,
# Убеждаемся, что исходный unet не имел этих фич, иначе могут быть конфликты
"addition_embed_type_num_heads": 64,
})
# 3. Создание новой UNet с модифицированной архитектурой
print("3. Инициализация новой UNet с измененной архитектурой...")
new_config_frozen = FrozenDict(new_config)
# Теперь инициализация пройдет успешно, так как `encoder_hid_dim` определен
new_unet = UNet2DConditionModel.from_config(new_config_frozen)
print(" -> Новая UNet инициализирована (новые слои имеют случайные веса).")
# --- ШАГ 3: Перенос весов ---
print("4. Выполнение миграции весов...")
migrated_unet = migrate_unet_weights(old_unet, new_unet)
# --- ШАГ 4: Сохранение новой UNet ---
print(f"5. Сохранение новой UNet в папку {NEW_UNET_PATH}...")
# Создаем папку, если она не существует
os.makedirs(NEW_UNET_PATH, exist_ok=True)
# Сохраняем модель
migrated_unet.save_pretrained(NEW_UNET_PATH)
print("🎉 Готово! Новая UNet готова к использованию и дообучению.")
print(f"\nСледующий шаг: Замените путь к UNet в вашем SdxsPipeline на '{NEW_UNET_PATH}' и запустите инференс, чтобы убедиться, что она принимает `added_cond_kwargs` без ошибок.")