|
|
import torch |
|
|
from diffusers import UNet2DConditionModel |
|
|
import os |
|
|
from safetensors.torch import load_file as safe_load |
|
|
|
|
|
|
|
|
OLD_UNET_PATH = "/workspace/sdxs/unet_old" |
|
|
NEW_UNET_PATH = "/workspace/sdxs/unet" |
|
|
|
|
|
def transfer_unet_weights_fix(old_path: str, new_path: str): |
|
|
|
|
|
print(f"Загрузка новой UNet (Конфиг 2048D) из: {new_path}") |
|
|
new_unet = UNet2DConditionModel.from_pretrained(new_path, low_cpu_mem_usage=False) |
|
|
|
|
|
|
|
|
print(f"Прямая загрузка весов старой UNet из: {old_path}") |
|
|
|
|
|
weights_file = next((f for f in os.listdir(old_path) if f.endswith(('.safetensors', '.bin'))), None) |
|
|
if not weights_file: |
|
|
print("Ошибка: не найден файл весов (.safetensors или .bin) в старой папке.") |
|
|
return None |
|
|
|
|
|
if weights_file.endswith('.safetensors'): |
|
|
print(f"Обнаружен файл Safetensors ({weights_file}). Использую safe_load.") |
|
|
old_state_dict = safe_load(f"{old_path}/{weights_file}") |
|
|
elif weights_file.endswith('.bin'): |
|
|
print(f"Обнаружен файл PyTorch (.bin) ({weights_file}). Использую torch.load с weights_only=False.") |
|
|
old_state_dict = torch.load( |
|
|
f"{old_path}/{weights_file}", |
|
|
map_location='cpu', |
|
|
weights_only=False |
|
|
) |
|
|
else: |
|
|
print(f"Ошибка: Не удалось загрузить файл {weights_file}. Проверьте формат.") |
|
|
return None |
|
|
|
|
|
if "state_dict" in old_state_dict: |
|
|
old_state_dict = old_state_dict["state_dict"] |
|
|
|
|
|
|
|
|
print("Начало переноса весов с пропуском несовместимых слоев...") |
|
|
|
|
|
|
|
|
total_keys = len(old_state_dict) |
|
|
matching_keys = 0 |
|
|
size_mismatch_keys = 0 |
|
|
|
|
|
|
|
|
|
|
|
new_keys = new_unet.state_dict().keys() |
|
|
|
|
|
for name, old_param in old_state_dict.items(): |
|
|
if name in new_keys: |
|
|
new_param = new_unet.state_dict()[name] |
|
|
|
|
|
if old_param.shape == new_param.shape: |
|
|
matching_keys += 1 |
|
|
|
|
|
else: |
|
|
size_mismatch_keys += 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
new_unet.load_state_dict(old_state_dict, strict=False) |
|
|
|
|
|
except RuntimeError as e: |
|
|
|
|
|
|
|
|
|
|
|
print("\n--- Отчет о переносе весов ---") |
|
|
print("⚠️ Обнаружен ожидаемый **RuntimeError** из-за несовпадения размеров. Это нормально!") |
|
|
|
|
|
print(f"💡 **УСПЕШНО перенесенных ключей (совпадающий размер): {matching_keys} шт.**") |
|
|
print(f"❌ **Пропущенных ключей (несовпадение размера): {size_mismatch_keys} шт.**") |
|
|
|
|
|
|
|
|
new_unet.save_pretrained(new_path) |
|
|
print(f"\n✅ Новая UNet (с перенесенными весами) сохранена по пути: {new_path}") |
|
|
return new_unet |
|
|
|
|
|
|
|
|
print("\n--- Отчет о переносе весов ---") |
|
|
print(f"✅ Успешно перенесены совпадающие веса (основная часть UNet).") |
|
|
print(f"💡 **УСПЕШНО перенесенных ключей (совпадающий размер): {matching_keys} шт.**") |
|
|
print(f"❌ **Пропущенных ключей (несовпадение размера): {size_mismatch_keys} шт.**") |
|
|
|
|
|
|
|
|
new_unet.save_pretrained(new_path) |
|
|
print(f"\n✅ Новая UNet сохранена по пути: {new_path}") |
|
|
|
|
|
return new_unet |
|
|
|
|
|
|
|
|
transferred_unet = transfer_unet_weights_fix(OLD_UNET_PATH, NEW_UNET_PATH) |