File size: 5,534 Bytes
9614bab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from diffusers import UNet2DConditionModel
import os
from safetensors.torch import load_file as safe_load

# --- КОНСТАНТЫ ---
OLD_UNET_PATH = "/workspace/sdxs/unet_old" 
NEW_UNET_PATH = "/workspace/sdxs/unet"     

def transfer_unet_weights_fix(old_path: str, new_path: str):
    
    print(f"Загрузка новой UNet (Конфиг 2048D) из: {new_path}")
    new_unet = UNet2DConditionModel.from_pretrained(new_path, low_cpu_mem_usage=False)

    # 2. Прямая загрузка state dict старой модели (с обработкой формата)
    print(f"Прямая загрузка весов старой UNet из: {old_path}")
    
    weights_file = next((f for f in os.listdir(old_path) if f.endswith(('.safetensors', '.bin'))), None)
    if not weights_file:
        print("Ошибка: не найден файл весов (.safetensors или .bin) в старой папке.")
        return None
    
    if weights_file.endswith('.safetensors'):
        print(f"Обнаружен файл Safetensors ({weights_file}). Использую safe_load.")
        old_state_dict = safe_load(f"{old_path}/{weights_file}")
    elif weights_file.endswith('.bin'):
        print(f"Обнаружен файл PyTorch (.bin) ({weights_file}). Использую torch.load с weights_only=False.")
        old_state_dict = torch.load(
            f"{old_path}/{weights_file}", 
            map_location='cpu',
            weights_only=False 
        )
    else:
        print(f"Ошибка: Не удалось загрузить файл {weights_file}. Проверьте формат.")
        return None
    
    if "state_dict" in old_state_dict:
        old_state_dict = old_state_dict["state_dict"]
    
    # 3. Перенос весов с обработкой RuntimeError
    print("Начало переноса весов с пропуском несовместимых слоев...")
    
    # --- НОВОЕ: Предварительный подсчет совпадающих ключей ---
    total_keys = len(old_state_dict)
    matching_keys = 0
    size_mismatch_keys = 0
    
    # Используем ключи новой модели для сравнения, чтобы убедиться, 
    # что ключи, которые отсутствуют в старой, не учитываются в "перенесенных".
    new_keys = new_unet.state_dict().keys()

    for name, old_param in old_state_dict.items():
        if name in new_keys:
            new_param = new_unet.state_dict()[name]
            # Считаем только те, где имя и размер совпали
            if old_param.shape == new_param.shape:
                matching_keys += 1
            # Считаем те, где имя совпало, но размер изменился (mismatch)
            else:
                size_mismatch_keys += 1
        # Ключи, которые есть в старой, но нет в новой (unexpected), 
        # также будут пропущены. Но нас в первую очередь интересуют size mismatch.

    # Пытаемся загрузить веса, зная, что это вызовет RuntimeError
    try:
        # Запуск actual переноса. Ключи будут перенесены.
        new_unet.load_state_dict(old_state_dict, strict=False)
        
    except RuntimeError as e:
        # Это ожидаемый блок! Он ловит RuntimeError, вызванный mismatch size.
        # Веса, которые совпали, УЖЕ перенесены в new_unet.
        
        print("\n--- Отчет о переносе весов ---")
        print("⚠️ Обнаружен ожидаемый **RuntimeError** из-за несовпадения размеров. Это нормально!")
        
        print(f"💡 **УСПЕШНО перенесенных ключей (совпадающий размер): {matching_keys} шт.**")
        print(f"❌ **Пропущенных ключей (несовпадение размера): {size_mismatch_keys} шт.**")
        
        # Мы всё равно сохраняем UNet, так как большая часть весов перенесена
        new_unet.save_pretrained(new_path)
        print(f"\n✅ Новая UNet (с перенесенными весами) сохранена по пути: {new_path}")
        return new_unet
        
    # Блок на случай, если чудом ошибки не возникло (для полноты)
    print("\n--- Отчет о переносе весов ---")
    print(f"✅ Успешно перенесены совпадающие веса (основная часть UNet).")
    print(f"💡 **УСПЕШНО перенесенных ключей (совпадающий размер): {matching_keys} шт.**")
    print(f"❌ **Пропущенных ключей (несовпадение размера): {size_mismatch_keys} шт.**")
    #print(f"❌ Слои, требующие переобучения (измененный размер): {len(incompatible_keys.unexpected_keys)}")
    
    new_unet.save_pretrained(new_path)
    print(f"\n✅ Новая UNet сохранена по пути: {new_path}")
    
    return new_unet

# --- ВЫПОЛНЕНИЕ ---
transferred_unet = transfer_unet_weights_fix(OLD_UNET_PATH, NEW_UNET_PATH)