File size: 5,055 Bytes
0e02107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import os
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import FrozenDict
from typing import Dict

# --- КОНСТАНТЫ И ПУТИ ---
UNET_PATH = './unet_old'
NEW_UNET_PATH = './unet'

# --- ШАГ 1: Определение функции миграции весов ---

def migrate_unet_weights(old_unet: UNet2DConditionModel, new_unet: UNet2DConditionModel) -> UNet2DConditionModel:
    """Копирует веса из старой UNet в новую, игнорируя новые слои."""
    old_state_dict = old_unet.state_dict()
    new_state_dict = new_unet.state_dict()
    
    missing_keys = []
    
    # 1. Копируем существующие веса
    for name, param in new_state_dict.items():
        if name in old_state_dict:
            if param.shape == old_state_dict[name].shape:
                param.data.copy_(old_state_dict[name].data)
            else:
                print(f"⚠️ Пропуск ключа {name}: не совпадают формы ({old_state_dict[name].shape} vs {param.shape})")
        else:
            missing_keys.append(name)
            
    print(f"✅ Успешно перенесено {len(new_state_dict) - len(missing_keys)} весов.")
    print("\n--- Новые слои (случайные веса, требуют дообучения) ---")
    for key in missing_keys:
         print(f"🆕 {key}")
    print("----------------------------------------------------------")
    
    return new_unet

# --- ШАГ 2: Загрузка и модификация конфигурации ---

print(f"1. Загрузка UNet из {UNET_PATH}...")
try:
    # Загружаем вашу исходную UNet
    old_unet = UNet2DConditionModel.from_pretrained(UNET_PATH, torch_dtype=torch.float32)
    old_config: Dict = old_unet.config
    print("   -> Исходная UNet успешно загружена.")
except Exception as e:
    print(f"🛑 Ошибка при загрузке UNet: {e}")
    exit()

# 2. Модификация конфигурации для добавления пулинг-слоя (SDXL-подобный стиль)
print("2. Модификация конфигурации...")

# Ключевые изменения для включения added_cond_kwargs
new_config = dict(old_config)
new_config.update({
    # Активируем дополнительное встраивание для текста и времени
    "addition_embed_type": "text", 
    
    # Размерность текстового пулинга (1024D для Qwen-0.6B)
    "addition_time_embed_dim": 1024, 
    
    # Тип проекции для hid-эмбеддингов
    "encoder_hid_dim_type": "text_proj",
    
    # ПЕРВОЕ ИСПРАВЛЕНИЕ (из предыдущего шага)
    "encoder_hid_dim": 1024, 
    
    # ВТОРОЕ ИСПРАВЛЕНИЕ: Определение входной размерности для add_embedding
    # Она должна соответствовать размерности пулинг-слоя, который мы передаем.
    "projection_class_embeddings_input_dim": 1024, 
    
    # Размерность, в которую будут проецироваться пулинг-эмбеддинги (совместно с временем)
    "time_embedding_dim": 1024, 
    
    # Убеждаемся, что исходный unet не имел этих фич, иначе могут быть конфликты
    "addition_embed_type_num_heads": 64,
})

# 3. Создание новой UNet с модифицированной архитектурой
print("3. Инициализация новой UNet с измененной архитектурой...")
new_config_frozen = FrozenDict(new_config) 
# Теперь инициализация пройдет успешно, так как `encoder_hid_dim` определен
new_unet = UNet2DConditionModel.from_config(new_config_frozen)
print("   -> Новая UNet инициализирована (новые слои имеют случайные веса).")

# --- ШАГ 3: Перенос весов ---

print("4. Выполнение миграции весов...")
migrated_unet = migrate_unet_weights(old_unet, new_unet)

# --- ШАГ 4: Сохранение новой UNet ---

print(f"5. Сохранение новой UNet в папку {NEW_UNET_PATH}...")

# Создаем папку, если она не существует
os.makedirs(NEW_UNET_PATH, exist_ok=True)

# Сохраняем модель
migrated_unet.save_pretrained(NEW_UNET_PATH)

print("🎉 Готово! Новая UNet готова к использованию и дообучению.")

print(f"\nСледующий шаг: Замените путь к UNet в вашем SdxsPipeline на '{NEW_UNET_PATH}' и запустите инференс, чтобы убедиться, что она принимает `added_cond_kwargs` без ошибок.")