sdxs / src /conv2048.py
recoilme's picture
2512
9614bab
import torch
from diffusers import UNet2DConditionModel
import os
from safetensors.torch import load_file as safe_load
# --- КОНСТАНТЫ ---
OLD_UNET_PATH = "/workspace/sdxs/unet_old"
NEW_UNET_PATH = "/workspace/sdxs/unet"
def transfer_unet_weights_fix(old_path: str, new_path: str):
print(f"Загрузка новой UNet (Конфиг 2048D) из: {new_path}")
new_unet = UNet2DConditionModel.from_pretrained(new_path, low_cpu_mem_usage=False)
# 2. Прямая загрузка state dict старой модели (с обработкой формата)
print(f"Прямая загрузка весов старой UNet из: {old_path}")
weights_file = next((f for f in os.listdir(old_path) if f.endswith(('.safetensors', '.bin'))), None)
if not weights_file:
print("Ошибка: не найден файл весов (.safetensors или .bin) в старой папке.")
return None
if weights_file.endswith('.safetensors'):
print(f"Обнаружен файл Safetensors ({weights_file}). Использую safe_load.")
old_state_dict = safe_load(f"{old_path}/{weights_file}")
elif weights_file.endswith('.bin'):
print(f"Обнаружен файл PyTorch (.bin) ({weights_file}). Использую torch.load с weights_only=False.")
old_state_dict = torch.load(
f"{old_path}/{weights_file}",
map_location='cpu',
weights_only=False
)
else:
print(f"Ошибка: Не удалось загрузить файл {weights_file}. Проверьте формат.")
return None
if "state_dict" in old_state_dict:
old_state_dict = old_state_dict["state_dict"]
# 3. Перенос весов с обработкой RuntimeError
print("Начало переноса весов с пропуском несовместимых слоев...")
# --- НОВОЕ: Предварительный подсчет совпадающих ключей ---
total_keys = len(old_state_dict)
matching_keys = 0
size_mismatch_keys = 0
# Используем ключи новой модели для сравнения, чтобы убедиться,
# что ключи, которые отсутствуют в старой, не учитываются в "перенесенных".
new_keys = new_unet.state_dict().keys()
for name, old_param in old_state_dict.items():
if name in new_keys:
new_param = new_unet.state_dict()[name]
# Считаем только те, где имя и размер совпали
if old_param.shape == new_param.shape:
matching_keys += 1
# Считаем те, где имя совпало, но размер изменился (mismatch)
else:
size_mismatch_keys += 1
# Ключи, которые есть в старой, но нет в новой (unexpected),
# также будут пропущены. Но нас в первую очередь интересуют size mismatch.
# Пытаемся загрузить веса, зная, что это вызовет RuntimeError
try:
# Запуск actual переноса. Ключи будут перенесены.
new_unet.load_state_dict(old_state_dict, strict=False)
except RuntimeError as e:
# Это ожидаемый блок! Он ловит RuntimeError, вызванный mismatch size.
# Веса, которые совпали, УЖЕ перенесены в new_unet.
print("\n--- Отчет о переносе весов ---")
print("⚠️ Обнаружен ожидаемый **RuntimeError** из-за несовпадения размеров. Это нормально!")
print(f"💡 **УСПЕШНО перенесенных ключей (совпадающий размер): {matching_keys} шт.**")
print(f"❌ **Пропущенных ключей (несовпадение размера): {size_mismatch_keys} шт.**")
# Мы всё равно сохраняем UNet, так как большая часть весов перенесена
new_unet.save_pretrained(new_path)
print(f"\n✅ Новая UNet (с перенесенными весами) сохранена по пути: {new_path}")
return new_unet
# Блок на случай, если чудом ошибки не возникло (для полноты)
print("\n--- Отчет о переносе весов ---")
print(f"✅ Успешно перенесены совпадающие веса (основная часть UNet).")
print(f"💡 **УСПЕШНО перенесенных ключей (совпадающий размер): {matching_keys} шт.**")
print(f"❌ **Пропущенных ключей (несовпадение размера): {size_mismatch_keys} шт.**")
#print(f"❌ Слои, требующие переобучения (измененный размер): {len(incompatible_keys.unexpected_keys)}")
new_unet.save_pretrained(new_path)
print(f"\n✅ Новая UNet сохранена по пути: {new_path}")
return new_unet
# --- ВЫПОЛНЕНИЕ ---
transferred_unet = transfer_unet_weights_fix(OLD_UNET_PATH, NEW_UNET_PATH)