import torch from diffusers import UNet2DConditionModel import os from safetensors.torch import load_file as safe_load # --- КОНСТАНТЫ --- OLD_UNET_PATH = "/workspace/sdxs/unet_old" NEW_UNET_PATH = "/workspace/sdxs/unet" def transfer_unet_weights_fix(old_path: str, new_path: str): print(f"Загрузка новой UNet (Конфиг 2048D) из: {new_path}") new_unet = UNet2DConditionModel.from_pretrained(new_path, low_cpu_mem_usage=False) # 2. Прямая загрузка state dict старой модели (с обработкой формата) print(f"Прямая загрузка весов старой UNet из: {old_path}") weights_file = next((f for f in os.listdir(old_path) if f.endswith(('.safetensors', '.bin'))), None) if not weights_file: print("Ошибка: не найден файл весов (.safetensors или .bin) в старой папке.") return None if weights_file.endswith('.safetensors'): print(f"Обнаружен файл Safetensors ({weights_file}). Использую safe_load.") old_state_dict = safe_load(f"{old_path}/{weights_file}") elif weights_file.endswith('.bin'): print(f"Обнаружен файл PyTorch (.bin) ({weights_file}). Использую torch.load с weights_only=False.") old_state_dict = torch.load( f"{old_path}/{weights_file}", map_location='cpu', weights_only=False ) else: print(f"Ошибка: Не удалось загрузить файл {weights_file}. Проверьте формат.") return None if "state_dict" in old_state_dict: old_state_dict = old_state_dict["state_dict"] # 3. Перенос весов с обработкой RuntimeError print("Начало переноса весов с пропуском несовместимых слоев...") # --- НОВОЕ: Предварительный подсчет совпадающих ключей --- total_keys = len(old_state_dict) matching_keys = 0 size_mismatch_keys = 0 # Используем ключи новой модели для сравнения, чтобы убедиться, # что ключи, которые отсутствуют в старой, не учитываются в "перенесенных". new_keys = new_unet.state_dict().keys() for name, old_param in old_state_dict.items(): if name in new_keys: new_param = new_unet.state_dict()[name] # Считаем только те, где имя и размер совпали if old_param.shape == new_param.shape: matching_keys += 1 # Считаем те, где имя совпало, но размер изменился (mismatch) else: size_mismatch_keys += 1 # Ключи, которые есть в старой, но нет в новой (unexpected), # также будут пропущены. Но нас в первую очередь интересуют size mismatch. # Пытаемся загрузить веса, зная, что это вызовет RuntimeError try: # Запуск actual переноса. Ключи будут перенесены. new_unet.load_state_dict(old_state_dict, strict=False) except RuntimeError as e: # Это ожидаемый блок! Он ловит RuntimeError, вызванный mismatch size. # Веса, которые совпали, УЖЕ перенесены в new_unet. print("\n--- Отчет о переносе весов ---") print("⚠️ Обнаружен ожидаемый **RuntimeError** из-за несовпадения размеров. Это нормально!") print(f"💡 **УСПЕШНО перенесенных ключей (совпадающий размер): {matching_keys} шт.**") print(f"❌ **Пропущенных ключей (несовпадение размера): {size_mismatch_keys} шт.**") # Мы всё равно сохраняем UNet, так как большая часть весов перенесена new_unet.save_pretrained(new_path) print(f"\n✅ Новая UNet (с перенесенными весами) сохранена по пути: {new_path}") return new_unet # Блок на случай, если чудом ошибки не возникло (для полноты) print("\n--- Отчет о переносе весов ---") print(f"✅ Успешно перенесены совпадающие веса (основная часть UNet).") print(f"💡 **УСПЕШНО перенесенных ключей (совпадающий размер): {matching_keys} шт.**") print(f"❌ **Пропущенных ключей (несовпадение размера): {size_mismatch_keys} шт.**") #print(f"❌ Слои, требующие переобучения (измененный размер): {len(incompatible_keys.unexpected_keys)}") new_unet.save_pretrained(new_path) print(f"\n✅ Новая UNet сохранена по пути: {new_path}") return new_unet # --- ВЫПОЛНЕНИЕ --- transferred_unet = transfer_unet_weights_fix(OLD_UNET_PATH, NEW_UNET_PATH)