import torch import os from diffusers import UNet2DConditionModel from diffusers.configuration_utils import FrozenDict from typing import Dict # --- КОНСТАНТЫ И ПУТИ --- UNET_PATH = './unet_old' NEW_UNET_PATH = './unet' # --- ШАГ 1: Определение функции миграции весов --- def migrate_unet_weights(old_unet: UNet2DConditionModel, new_unet: UNet2DConditionModel) -> UNet2DConditionModel: """Копирует веса из старой UNet в новую, игнорируя новые слои.""" old_state_dict = old_unet.state_dict() new_state_dict = new_unet.state_dict() missing_keys = [] # 1. Копируем существующие веса for name, param in new_state_dict.items(): if name in old_state_dict: if param.shape == old_state_dict[name].shape: param.data.copy_(old_state_dict[name].data) else: print(f"⚠️ Пропуск ключа {name}: не совпадают формы ({old_state_dict[name].shape} vs {param.shape})") else: missing_keys.append(name) print(f"✅ Успешно перенесено {len(new_state_dict) - len(missing_keys)} весов.") print("\n--- Новые слои (случайные веса, требуют дообучения) ---") for key in missing_keys: print(f"🆕 {key}") print("----------------------------------------------------------") return new_unet # --- ШАГ 2: Загрузка и модификация конфигурации --- print(f"1. Загрузка UNet из {UNET_PATH}...") try: # Загружаем вашу исходную UNet old_unet = UNet2DConditionModel.from_pretrained(UNET_PATH, torch_dtype=torch.float32) old_config: Dict = old_unet.config print(" -> Исходная UNet успешно загружена.") except Exception as e: print(f"🛑 Ошибка при загрузке UNet: {e}") exit() # 2. Модификация конфигурации для добавления пулинг-слоя (SDXL-подобный стиль) print("2. Модификация конфигурации...") # Ключевые изменения для включения added_cond_kwargs new_config = dict(old_config) new_config.update({ # Активируем дополнительное встраивание для текста и времени "addition_embed_type": "text", # Размерность текстового пулинга (1024D для Qwen-0.6B) "addition_time_embed_dim": 1024, # Тип проекции для hid-эмбеддингов "encoder_hid_dim_type": "text_proj", # ПЕРВОЕ ИСПРАВЛЕНИЕ (из предыдущего шага) "encoder_hid_dim": 1024, # ВТОРОЕ ИСПРАВЛЕНИЕ: Определение входной размерности для add_embedding # Она должна соответствовать размерности пулинг-слоя, который мы передаем. "projection_class_embeddings_input_dim": 1024, # Размерность, в которую будут проецироваться пулинг-эмбеддинги (совместно с временем) "time_embedding_dim": 1024, # Убеждаемся, что исходный unet не имел этих фич, иначе могут быть конфликты "addition_embed_type_num_heads": 64, }) # 3. Создание новой UNet с модифицированной архитектурой print("3. Инициализация новой UNet с измененной архитектурой...") new_config_frozen = FrozenDict(new_config) # Теперь инициализация пройдет успешно, так как `encoder_hid_dim` определен new_unet = UNet2DConditionModel.from_config(new_config_frozen) print(" -> Новая UNet инициализирована (новые слои имеют случайные веса).") # --- ШАГ 3: Перенос весов --- print("4. Выполнение миграции весов...") migrated_unet = migrate_unet_weights(old_unet, new_unet) # --- ШАГ 4: Сохранение новой UNet --- print(f"5. Сохранение новой UNet в папку {NEW_UNET_PATH}...") # Создаем папку, если она не существует os.makedirs(NEW_UNET_PATH, exist_ok=True) # Сохраняем модель migrated_unet.save_pretrained(NEW_UNET_PATH) print("🎉 Готово! Новая UNet готова к использованию и дообучению.") print(f"\nСледующий шаг: Замените путь к UNet в вашем SdxsPipeline на '{NEW_UNET_PATH}' и запустите инференс, чтобы убедиться, что она принимает `added_cond_kwargs` без ошибок.")