In [7]:
from diffusers.models import AsymmetricAutoencoderKL
import torch

config = {
    "_class_name": "AsymmetricAutoencoderKL",
    "act_fn": "silu",
    "down_block_out_channels": [128, 256, 512, 512],
    "down_block_types": [
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
    ],
    "in_channels": 3,
    "latent_channels": 16,
    #"latents_mean": [0.2539, 0.1431, 0.1484, -0.3048, -0.0985, -0.162, 0.1403, 0.2034, -0.1419, 0.2646, 0.0655, 0.0061, 0.1555, 0.0506, 0.0129, -0.1948],
    #"latents_std": [0.8123, 0.7376, 0.7354, 1.1827, 0.8387, 0.8735, 0.8705, 0.8142, 0.8076, 0.7409, 0.7655, 0.8731, 0.8087, 0.7058, 0.8087, 0.7615],
    #"layers_per_block": 2,
    "norm_num_groups": 32,
    "out_channels": 3,
    "sample_size": 1024,
    "scaling_factor": 1,
    "shift_factor": 0,
    "up_block_out_channels": [128, 256, 512, 768, 768],
    "up_block_types": [
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ],
}

# Преобразуем списки mean и std в тензоры
#latents_mean = torch.tensor(config["latents_mean"])
#latents_std = torch.tensor(config["latents_std"])

# Создаем модель
vae = AsymmetricAutoencoderKL(
    act_fn=config["act_fn"],
    down_block_out_channels=config["down_block_out_channels"],
    down_block_types=config["down_block_types"],
    in_channels=config["in_channels"],
    latent_channels=config["latent_channels"],
    norm_num_groups=config["norm_num_groups"],
    out_channels=config["out_channels"],
    sample_size=config["sample_size"],
    scaling_factor=config["scaling_factor"],
    up_block_out_channels=config["up_block_out_channels"],
    up_block_types=config["up_block_types"],
    layers_per_down_block = 2,
    layers_per_up_block = 2
)

# Устанавливаем mean и std для латентов
#vae.latents_mean = latents_mean
#vae.latents_std = latents_std

vae.save_pretrained("simple_vae")
print(vae)

AsymmetricAutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
    

In [8]:
import torch
from diffusers import AsymmetricAutoencoderKL,AutoencoderKL
from tqdm import tqdm
import torch.nn.init as init

def log(message):
    print(message)

def initialize_mid_block_weights(state_dict, device, dtype):
    # Инициализация весов для mid block 0 с  размерностью 512
    state_dict['encoder.mid_block.attentions.0.group_norm.weight'] = torch.ones(512, device=device, dtype=dtype)
    state_dict['encoder.mid_block.attentions.0.group_norm.bias'] = torch.zeros(512, device=device, dtype=dtype)
    
    # Удаляем ключи для второго блока внимания, так как он не существует в архитектуре
    #if 'encoder.mid_block.attentions.1.group_norm.weight' in state_dict:
    #    del state_dict['encoder.mid_block.attentions.1.group_norm.weight']
    #if 'encoder.mid_block.attentions.1.group_norm.bias' in state_dict:
    #    del state_dict['encoder.mid_block.attentions.1.group_norm.bias']
    
    return state_dict

def main():
    checkpoint_path_old = "AiArtLab/sdxs"
    checkpoint_path_new = "simple_vae"
    device = "cuda"
    dtype = torch.float16

    # Загрузка моделей
    old_unet = AutoencoderKL.from_pretrained(checkpoint_path_old,subfolder="vae",variant="fp16").to(device, dtype=dtype)
    new_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)

    old_state_dict = old_unet.state_dict()
    new_state_dict = new_unet.state_dict()

    transferred_state_dict = {}
    transfer_stats = {
        "перенесено": 0,
        "несовпадение_размеров": 0,
        "пропущено": 0
    }

    transferred_keys = set()

    # Обрабатываем каждый ключ старой модели
    for old_key in tqdm(old_state_dict.keys(), desc="Перенос весов"):
        new_key = old_key

        if new_key in new_state_dict:
            if old_state_dict[old_key].shape == new_state_dict[new_key].shape:
                transferred_state_dict[new_key] = old_state_dict[old_key].clone()
                transferred_keys.add(new_key)
                transfer_stats["перенесено"] += 1
            else:
                log(f"✗ Несовпадение размеров: {old_key} ({old_state_dict[old_key].shape}) -> {new_key} ({new_state_dict[new_key].shape})")
                transfer_stats["несовпадение_размеров"] += 1
        else:
            log(f"? Ключ не найден в новой модели: {old_key} -> {old_state_dict[old_key].shape}")
            transfer_stats["пропущено"] += 1

    # Обновляем состояние новой модели перенесенными весами
    new_state_dict.update(transferred_state_dict)
    
    # Инициализируем веса для нового mid блока
    new_state_dict = initialize_mid_block_weights(new_state_dict, device, dtype)
    
    new_unet.load_state_dict(new_state_dict)
    new_unet.save_pretrained("vae")

    # Получаем список неперенесенных ключей
    non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)

    print("Статистика переноса:", transfer_stats)
    print("Неперенесенные ключи в новой модели:")
    for key in non_transferred_keys:
        print(key)

if __name__ == "__main__":
    main()

The config attributes {'block_out_channels': [128, 256, 512, 768, 768], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.
Перенос весов: 100%|██████████| 228/228 [00:00<00:00, 192647.32it/s]


✗ Несовпадение размеров: decoder.conv_in.weight (torch.Size([512, 16, 3, 3])) -> decoder.conv_in.weight (torch.Size([768, 16, 3, 3]))
✗ Несовпадение размеров: decoder.conv_in.bias (torch.Size([512])) -> decoder.conv_in.bias (torch.Size([768]))
✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.norm1.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.norm1.weight (torch.Size([768]))
✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.norm1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.norm1.bias (torch.Size([768]))
✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.conv1.weight (torch.Size([512, 512, 3, 3])) -> decoder.up_blocks.0.resnets.0.conv1.weight (torch.Size([768, 768, 3, 3]))
✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.conv1.bias (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.conv1.bias (torch.Size([768]))
✗ Несовпадение размеров: decoder.up_blocks.0.resnets.0.norm2.weight (torch.Size([512])) -> decoder.up_blocks.0.resnets.0.norm2

In [3]:

import torch
from diffusers import AsymmetricAutoencoderKL
from tqdm import tqdm

def normalize_weights3(state_dict, latents_mean, latents_std):
    device = next(iter(state_dict.values())).device
    dtype = next(iter(state_dict.values())).dtype
    
    # Преобразуем в тензоры
    latents_mean = torch.tensor(latents_mean, device=device, dtype=dtype)
    latents_std = torch.tensor(latents_std, device=device, dtype=dtype)

    # Нормализация для quant_conv (32 -> 32 каналов)
    if 'quant_conv.weight' in state_dict:
        weight = state_dict['quant_conv.weight']  # [32, 32, 1, 1]
        # Применяем нормализацию к выходным каналам
        for i in range(weight.size(0)):
            weight[i] = weight[i] / latents_std[i % len(latents_std)]
    
    if 'quant_conv.bias' in state_dict:
        bias = state_dict['quant_conv.bias']  # [32]
        for i in range(bias.size(0)):
            bias[i] = (bias[i] - latents_mean[i % len(latents_mean)]) / latents_std[i % len(latents_std)]

    # Нормализация для post_quant_conv (16 -> 16 каналов)
    if 'post_quant_conv.weight' in state_dict:
        weight = state_dict['post_quant_conv.weight']  # [16, 16, 1, 1]
        # Применяем нормализацию к входным каналам
        for i in range(weight.size(1)):
            weight[:, i] = weight[:, i] * latents_std[i]
    
    if 'post_quant_conv.bias' in state_dict:
        bias = state_dict['post_quant_conv.bias']  # [16]
        for i in range(bias.size(0)):
            bias[i] = bias[i] * latents_std[i] + latents_mean[i]

    return state_dict

def normalize_weights(state_dict, latents_mean, latents_std):
    device = next(iter(state_dict.values())).device
    dtype = next(iter(state_dict.values())).dtype
    
    # Преобразуем в тензоры
    latents_mean = torch.tensor(latents_mean, device=device, dtype=dtype)
    latents_std = torch.tensor(latents_std, device=device, dtype=dtype)

    # Нормализация для quant_conv (32 -> 32 каналов)
    # На выходе энкодера: (x - mean) / std
    if 'quant_conv.weight' in state_dict:
        weight = state_dict['quant_conv.weight']  # [32, 32, 1, 1]
        # Нормализуем выходные каналы
        for i in range(weight.size(0)):
            if i < len(latents_std):
                weight[i] = weight[i] / latents_std[i]
    
    if 'quant_conv.bias' in state_dict:
        bias = state_dict['quant_conv.bias']  # [32]
        for i in range(bias.size(0)):
            if i < len(latents_mean):
                # Сначала применяем сдвиг, потом масштабирование
                bias[i] = -latents_mean[i] / latents_std[i]

    # Нормализация для post_quant_conv (16 -> 16 каналов)
    # На входе декодера: x * std + mean
    if 'post_quant_conv.weight' in state_dict:
        weight = state_dict['post_quant_conv.weight']  # [16, 16, 1, 1]
        # Нормализуем входные каналы
        for i in range(weight.size(1)):
            if i < len(latents_std):
                weight[:, i] = weight[:, i] * latents_std[i]
    
    if 'post_quant_conv.bias' in state_dict:
        bias = state_dict['post_quant_conv.bias']  # [16]
        for i in range(bias.size(0)):
            if i < len(latents_mean):
                bias[i] = bias[i] + latents_mean[i]

    return state_dict

def main():
    # Путь к модели
    model_path = "vae"
    device = "cuda"
    dtype = torch.float16

    # Ваши mean и std
    latents_mean = [0.2539, 0.1431, 0.1484, -0.3048, -0.0985, -0.162, 0.1403, 0.2034, -0.1419, 0.2646, 0.0655, 0.0061, 0.1555, 0.0506, 0.0129, -0.1948]

    latents_std = [0.8123, 0.7376, 0.7354, 1.1827, 0.8387, 0.8735, 0.8705, 0.8142, 0.8076, 0.7409, 0.7655, 0.8731, 0.8087, 0.7058, 0.8087, 0.7615]

    # Загружаем модель
    model = AsymmetricAutoencoderKL.from_pretrained(model_path).to(device, dtype=dtype)
    
    # Получаем state dict
    state_dict = model.state_dict()

    # Выводим информацию о весах до нормализации
    print("\nWeights before normalization:")
    for key in ['quant_conv.weight', 'quant_conv.bias', 'post_quant_conv.weight', 'post_quant_conv.bias']:
        if key in state_dict:
            print(f"{key}: {state_dict[key].shape}")

    # Нормализуем веса
    normalized_state_dict = normalize_weights(state_dict, latents_mean, latents_std)
    normalized_state_dict = initialize_mid_block_weights(normalized_state_dict, device, dtype)

    # Загружаем нормализованные веса обратно в модель
    model.load_state_dict(normalized_state_dict)

    # Сохраняем модель
    model.save_pretrained("vaenorm")

if __name__ == "__main__":
    main()

The config attributes {'block_out_channels': [128, 256, 512, 512], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.



Weights before normalization:
quant_conv.weight: torch.Size([32, 32, 1, 1])
quant_conv.bias: torch.Size([32])
post_quant_conv.weight: torch.Size([16, 16, 1, 1])
post_quant_conv.bias: torch.Size([16])


In [6]:
import torch
from diffusers import AsymmetricAutoencoderKL,AutoencoderKL
from tqdm import tqdm
import torch.nn.init as init

def log(message):
    print(message)

def initialize_mid_block_weights(state_dict, device, dtype):
    # Инициализация весов для mid block 0 с  размерностью 512
    state_dict['encoder.mid_block.attentions.0.group_norm.weight'] = torch.ones(512, device=device, dtype=dtype)
    state_dict['encoder.mid_block.attentions.0.group_norm.bias'] = torch.zeros(512, device=device, dtype=dtype)
    
    # Удаляем ключи для второго блока внимания, так как он не существует в архитектуре
    #if 'encoder.mid_block.attentions.1.group_norm.weight' in state_dict:
    #    del state_dict['encoder.mid_block.attentions.1.group_norm.weight']
    #if 'encoder.mid_block.attentions.1.group_norm.bias' in state_dict:
    #    del state_dict['encoder.mid_block.attentions.1.group_norm.bias']
    
    return state_dict
    
def interpolate_tensor(tensor, target_shape):
    """Интерполяция тензора до целевой формы"""
    print(f"Interpolating tensor of shape {tensor.shape} to target shape {target_shape}")
    
    if len(tensor.shape) == 4:  # Для свёрточных слоев
        out_channels, in_channels, k1, k2 = target_shape
        
        # Создаем новый тензор нужного размера
        result = torch.zeros(target_shape, device=tensor.device, dtype=tensor.dtype)
        
        # Копируем существующие веса с масштабированием
        min_out = min(tensor.shape[0], out_channels)
        min_in = min(tensor.shape[1], in_channels)
        
        print(f"Copying existing weights: min_out={min_out}, min_in={min_in}")
        
        # Копируем существующие веса
        result[:min_out, :min_in, :, :] = tensor[:min_out, :min_in, :, :]
        
        # Заполняем новые выходные каналы
        if out_channels > min_out:
            print(f"Extending output channels from {min_out} to {out_channels}")
            result[min_out:, :min_in, :, :] = result[min_out-1:min_out, :min_in, :, :].repeat(out_channels-min_out, 1, 1, 1)
        
        # Заполняем новые входные каналы
        if in_channels > min_in:
            print(f"Extending input channels from {min_in} to {in_channels}")
            result[:, min_in:, :, :] = result[:, min_in-1:min_in, :, :].repeat(1, in_channels-min_in, 1, 1)
        
        return result
    
    else:  # Для bias и других 1D тензоров
        # Создаем новый тензор нужного размера
        result = torch.zeros(target_shape, device=tensor.device, dtype=tensor.dtype)
        
        # Копируем существующие значения
        min_size = min(tensor.shape[0], target_shape[0])
        result[:min_size] = tensor[:min_size]
        
        # Заполняем оставшиеся значения
        if target_shape[0] > min_size:
            print(f"Extending 1D tensor from {min_size} to {target_shape[0]}")
            result[min_size:] = result[min_size-1]
        
        return result

def should_interpolate(key):
    """Определяет, нужно ли интерполировать веса для данного ключа"""
    return any(x in key for x in [
        'conv1', 'conv2', 'conv_shortcut',  # свёрточные слои
        'norm1', 'norm2', 'group_norm',     # нормализационные слои
        'bias', 'weight'                    # веса и смещения
    ])
    
def main():
    checkpoint_path_old = "AiArtLab/sdxs"
    checkpoint_path_new = "simple_vae"
    device = "cuda"
    dtype = torch.float16

    # Загрузка моделей
    old_unet = AutoencoderKL.from_pretrained(checkpoint_path_old,subfolder="vae",variant="fp16").to(device, dtype=dtype)
    new_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)

    old_state_dict = old_unet.state_dict()
    new_state_dict = new_unet.state_dict()

    transferred_state_dict = {}
    transfer_stats = {
        "перенесено": 0,
        "несовпадение_размеров": 0,
        "пропущено": 0
    }

    transferred_keys = set()

    # Сначала найдем все ключи блока 3 и их интерполированные значения
    block3_interpolated = {}
    for key in old_state_dict:
        if 'decoder.up_blocks.3.' in key and should_interpolate(key):
            old_tensor = old_state_dict[key]
            new_key = key
            if new_key in new_state_dict and old_tensor.shape != new_state_dict[new_key].shape:
                print(f"\nProcessing {key}")
                print(f"Old shape: {old_tensor.shape}")
                print(f"Target shape: {new_state_dict[new_key].shape}")
                interpolated = interpolate_tensor(old_tensor, new_state_dict[new_key].shape)
                block3_interpolated[key] = interpolated

    # Обрабатываем каждый ключ новой модели
    for new_key in tqdm(new_state_dict.keys(), desc="Перенос весов"):
        # Случай 1: Прямое соответствие ключей и размеров
        if new_key in old_state_dict and old_state_dict[new_key].shape == new_state_dict[new_key].shape:
            transferred_state_dict[new_key] = old_state_dict[new_key].clone()
            transferred_keys.add(new_key)
            transfer_stats["перенесено"] += 1
            continue

        # Случай 2: Блоки 4 и 5 (копируем интерполированные веса блока 3)
        if ('decoder.up_blocks.4.' in new_key or 'decoder.up_blocks.5.' in new_key) and should_interpolate(new_key):
            source_key = new_key.replace('decoder.up_blocks.4.', 'decoder.up_blocks.3.')
            source_key = source_key.replace('decoder.up_blocks.5.', 'decoder.up_blocks.3.')
            
            if source_key in block3_interpolated:
                transferred_state_dict[new_key] = block3_interpolated[source_key].clone()
                transferred_keys.add(new_key)
                transfer_stats["перенесено"] += 1
                continue

        # Случай 3: Несовпадение размеров в блоке 3
        if 'decoder.up_blocks.3.' in new_key and new_key in block3_interpolated:
            transferred_state_dict[new_key] = block3_interpolated[new_key].clone()
            transferred_keys.add(new_key)
            transfer_stats["перенесено"] += 1
            continue

        # Если ключ не обработан - помечаем как пропущенный
        transfer_stats["пропущено"] += 1
        log(f"? Ключ пропущен: {new_key} -> {new_state_dict[new_key].shape}")

        # Если ключ не обработан - помечаем как пропущенный
        transfer_stats["пропущено"] += 1
        log(f"? Ключ пропущен: {new_key} -> {new_state_dict[new_key].shape}")

    # Обновляем состояние новой модели
    new_state_dict.update(transferred_state_dict)

    # Инициализируем веса для нового mid блока
    new_state_dict = initialize_mid_block_weights(new_state_dict, device, dtype)

    
    new_unet.load_state_dict(new_state_dict)
    new_unet.save_pretrained("vae")

    # Выводим статистику
    print("\nСтатистика переноса:", transfer_stats)
    print("\nНеперенесенные ключи:")
    non_transferred_keys = sorted(set(new_state_dict.keys()) - transferred_keys)
    for key in non_transferred_keys:
        print(key)

if __name__ == "__main__":
    main()

The config attributes {'block_out_channels': [128, 256, 512, 768, 768], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.



Processing decoder.up_blocks.3.resnets.0.norm1.weight
Old shape: torch.Size([256])
Target shape: torch.Size([512])
Interpolating tensor of shape torch.Size([256]) to target shape torch.Size([512])
Extending 1D tensor from 256 to 512

Processing decoder.up_blocks.3.resnets.0.norm1.bias
Old shape: torch.Size([256])
Target shape: torch.Size([512])
Interpolating tensor of shape torch.Size([256]) to target shape torch.Size([512])
Extending 1D tensor from 256 to 512

Processing decoder.up_blocks.3.resnets.0.conv1.weight
Old shape: torch.Size([128, 256, 3, 3])
Target shape: torch.Size([256, 512, 3, 3])
Interpolating tensor of shape torch.Size([128, 256, 3, 3]) to target shape torch.Size([256, 512, 3, 3])
Copying existing weights: min_out=128, min_in=256
Extending output channels from 128 to 256
Extending input channels from 256 to 512

Processing decoder.up_blocks.3.resnets.0.conv1.bias
Old shape: torch.Size([128])
Target shape: torch.Size([256])
Interpolating tensor of shape torch.Size([128

Перенос весов: 100%|██████████| 286/286 [00:00<00:00, 163407.02it/s]

? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.weight -> torch.Size([512])
? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.weight -> torch.Size([512])
? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.bias -> torch.Size([512])
? Ключ пропущен: encoder.mid_block.attentions.0.group_norm.bias -> torch.Size([512])
? Ключ пропущен: encoder.mid_block.attentions.0.to_q.weight -> torch.Size([512, 512])
? Ключ пропущен: encoder.mid_block.attentions.0.to_q.weight -> torch.Size([512, 512])
? Ключ пропущен: encoder.mid_block.attentions.0.to_q.bias -> torch.Size([512])
? Ключ пропущен: encoder.mid_block.attentions.0.to_q.bias -> torch.Size([512])
? Ключ пропущен: encoder.mid_block.attentions.0.to_k.weight -> torch.Size([512, 512])
? Ключ пропущен: encoder.mid_block.attentions.0.to_k.weight -> torch.Size([512, 512])
? Ключ пропущен: encoder.mid_block.attentions.0.to_k.bias -> torch.Size([512])
? Ключ пропущен: encoder.mid_block.attentions.0.to_k.bias -> torch.Size([51




RuntimeError: Error(s) in loading state_dict for AsymmetricAutoencoderKL:
	size mismatch for decoder.up_blocks.4.resnets.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up_blocks.4.resnets.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up_blocks.4.resnets.0.conv1.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3]).
	size mismatch for decoder.up_blocks.4.resnets.0.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.0.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for decoder.up_blocks.4.resnets.0.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.0.conv_shortcut.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).
	size mismatch for decoder.up_blocks.4.resnets.0.conv_shortcut.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.1.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.1.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for decoder.up_blocks.4.resnets.1.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.1.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.1.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.1.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for decoder.up_blocks.4.resnets.1.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.2.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.2.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.2.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for decoder.up_blocks.4.resnets.2.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.2.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.2.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up_blocks.4.resnets.2.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for decoder.up_blocks.4.resnets.2.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).

In [7]:
import torch
from diffusers import AsymmetricAutoencoderKL,AutoencoderKL
from tqdm import tqdm

def log(message):
    print(message)

def interpolate_tensor(tensor, target_shape):
    """Интерполяция тензора до целевой формы"""
    print(f"Interpolating tensor of shape {tensor.shape} to target shape {target_shape}")
    
    # Создаем новый тензор нужного размера
    result = torch.zeros(target_shape, device=tensor.device, dtype=tensor.dtype)
    
    if len(tensor.shape) == 1:  # Для 1D тензоров (bias)
        min_size = min(tensor.shape[0], target_shape[0])
        result[:min_size] = tensor[:min_size]
        if target_shape[0] > min_size:
            result[min_size:] = tensor[min_size-1]
    else:  # Для всех остальных тензоров
        # Просто копируем то, что можем, остальное оставляем нулями
        if len(tensor.shape) == len(target_shape):
            for i in range(len(tensor.shape)):
                if tensor.shape[i] > target_shape[i]:
                    tensor = tensor.narrow(i, 0, target_shape[i])
            result[tuple(slice(0, s) for s in tensor.shape)] = tensor
    
    return result

def main():
    checkpoint_path_old = "AiArtLab/sdxs"
    checkpoint_path_new = "simple_vae"
    device = "cuda"
    dtype = torch.float16

    print("Loading models...")
    old_unet = AutoencoderKL.from_pretrained(checkpoint_path_old,subfolder="vae",variant="fp16").to(device, dtype=dtype)
    new_unet = AsymmetricAutoencoderKL.from_pretrained(checkpoint_path_new).to(device, dtype=dtype)

    old_state_dict = old_unet.state_dict()
    new_state_dict = new_unet.state_dict()

    transferred_state_dict = {}
    transfer_stats = {"перенесено": 0, "несовпадение_размеров": 0, "пропущено": 0}

    print("\nProcessing weights...")
    for new_key in tqdm(new_state_dict.keys()):
        print(f"\nProcessing key: {new_key}")
        print(f"Target shape: {new_state_dict[new_key].shape}")

        # Для блоков 4 и 5 используем веса из блока 3
        if 'decoder.up_blocks.4.' in new_key or 'decoder.up_blocks.5.' in new_key:
            source_key = new_key.replace('decoder.up_blocks.4.', 'decoder.up_blocks.3.')
            source_key = source_key.replace('decoder.up_blocks.5.', 'decoder.up_blocks.3.')
            
            if source_key in old_state_dict:
                print(f"Found source key: {source_key}")
                source_tensor = old_state_dict[source_key]
                print(f"Source shape: {source_tensor.shape}")
                
                if source_tensor.shape != new_state_dict[new_key].shape:
                    print("Shapes don't match, interpolating...")
                    transferred_state_dict[new_key] = interpolate_tensor(source_tensor, new_state_dict[new_key].shape)
                else:
                    print("Shapes match, copying directly...")
                    transferred_state_dict[new_key] = source_tensor.clone()
                transfer_stats["перенесено"] += 1
                continue

        # Для остальных ключей пробуем прямой перенос
        if new_key in old_state_dict:
            if old_state_dict[new_key].shape == new_state_dict[new_key].shape:
                print("Direct copy...")
                transferred_state_dict[new_key] = old_state_dict[new_key].clone()
                transfer_stats["перенесено"] += 1
            else:
                print(f"Size mismatch: {old_state_dict[new_key].shape} vs {new_state_dict[new_key].shape}")
                transfer_stats["несовпадение_размеров"] += 1
        else:
            print("Key not found in source model")
            transfer_stats["пропущено"] += 1

    print("\nUpdating state dict...")
    new_state_dict.update(transferred_state_dict)

    print("\nLoading state dict...")
    new_unet.load_state_dict(new_state_dict)

    print("\nSaving model...")
    new_unet.save_pretrained("vae")

    print("\nTransfer statistics:", transfer_stats)

if __name__ == "__main__":
    main()

Loading models...


The config attributes {'block_out_channels': [128, 256, 512, 768, 768], 'force_upcast': False} were passed to AsymmetricAutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.



Processing weights...


100%|██████████| 286/286 [00:00<00:00, 106458.20it/s]



Processing key: encoder.conv_in.weight
Target shape: torch.Size([128, 3, 3, 3])
Direct copy...

Processing key: encoder.conv_in.bias
Target shape: torch.Size([128])
Direct copy...

Processing key: encoder.down_blocks.0.resnets.0.norm1.weight
Target shape: torch.Size([128])
Direct copy...

Processing key: encoder.down_blocks.0.resnets.0.norm1.bias
Target shape: torch.Size([128])
Direct copy...

Processing key: encoder.down_blocks.0.resnets.0.conv1.weight
Target shape: torch.Size([128, 128, 3, 3])
Direct copy...

Processing key: encoder.down_blocks.0.resnets.0.conv1.bias
Target shape: torch.Size([128])
Direct copy...

Processing key: encoder.down_blocks.0.resnets.0.norm2.weight
Target shape: torch.Size([128])
Direct copy...

Processing key: encoder.down_blocks.0.resnets.0.norm2.bias
Target shape: torch.Size([128])
Direct copy...

Processing key: encoder.down_blocks.0.resnets.0.conv2.weight
Target shape: torch.Size([128, 128, 3, 3])
Direct copy...

Processing key: encoder.down_blocks.0.r