import torch
from diffusers import AsymmetricAutoencoderKL
from safetensors.torch import save_file
def _build_key_map():
    """Return the Diffusers -> A1111 module-prefix table for the VAE.

    Entries are inserted in model order: encoder stem, encoder down levels,
    encoder mid block, then decoder stem, decoder mid block, decoder up levels.
    Keys and values are module prefixes (no parameter suffix such as
    ``.weight``).
    """
    table = {}

    def add_stem(side):
        # conv_in / conv_out keep their names; only the output norm is renamed.
        table[f"{side}.conv_in"] = f"{side}.conv_in"
        table[f"{side}.conv_norm_out"] = f"{side}.norm_out"
        table[f"{side}.conv_out"] = f"{side}.conv_out"

    def add_mid(side):
        # resnet / attention / resnet of the mid block.
        table[f"{side}.mid_block.resnets.0"] = f"{side}.mid.block_1"
        table[f"{side}.mid_block.attentions.0"] = f"{side}.mid.attn_1"
        table[f"{side}.mid_block.resnets.1"] = f"{side}.mid.block_2"

    add_stem("encoder")
    # Encoder: 4 levels x 2 resnets; every level except the last one also
    # has a downsampler.
    for level in range(4):
        for idx in range(2):
            table[f"encoder.down_blocks.{level}.resnets.{idx}"] = (
                f"encoder.down.{level}.block.{idx}"
            )
        if level < 3:
            table[f"encoder.down_blocks.{level}.downsamplers.0"] = (
                f"encoder.down.{level}.downsample"
            )
    add_mid("encoder")

    add_stem("decoder")
    add_mid("decoder")
    # Decoder: 4 levels x 4 resnets; every level except the last one also
    # has an upsampler. Diffusers and A1111 number the up levels in opposite
    # order, hence the mirrored index.
    for level in range(4):
        mirrored = 3 - level
        for idx in range(4):
            table[f"decoder.up_blocks.{level}.resnets.{idx}"] = (
                f"decoder.up.{mirrored}.block.{idx}"
            )
        if level < 3:
            table[f"decoder.up_blocks.{level}.upsamplers.0"] = (
                f"decoder.up.{mirrored}.upsample"
            )
    return table


# Diffusers module prefix -> A1111 (original LDM) module prefix for the VAE.
KEY_MAP = _build_key_map()
# Sub-module renames applied inside a converted key (Diffusers name -> A1111
# name), in this order.
LAYER_RENAMES = dict(
    [
        ("conv_shortcut", "nin_shortcut"),
        # Names used inside the mid-block attention (A1111 "attn_1").
        ("group_norm", "norm"),
        ("to_q", "q"),
        ("to_k", "k"),
        ("to_v", "v"),
        ("to_out.0", "proj_out"),
    ]
)
def convert_key(key, key_map=None, layer_renames=None):
    """Convert a key from the Diffusers format to the A1111 format.

    Args:
        key: State-dict key of the Diffusers VAE, e.g.
            ``"encoder.down_blocks.0.resnets.1.conv_shortcut.weight"``.
        key_map: Diffusers module prefix -> A1111 module prefix. Defaults to
            the module-level ``KEY_MAP``.
        layer_renames: Diffusers sub-module name -> A1111 name, applied only
            to the part of the key after the matched prefix. Defaults to the
            module-level ``LAYER_RENAMES``.

    Returns:
        ``None`` when the key belongs to a ``condition_encoder`` (the caller
        drops it); the converted key when a prefix in ``key_map`` matches;
        otherwise ``key`` unchanged.
    """
    if "condition_encoder" in key:
        return None

    if key_map is None:
        key_map = KEY_MAP
    if layer_renames is None:
        layer_renames = LAYER_RENAMES

    for diffusers_prefix, a1111_prefix in key_map.items():
        # Require a whole dotted component so that "resnets.1" does not also
        # match "resnets.10...".
        if key != diffusers_prefix and not key.startswith(diffusers_prefix + "."):
            continue

        # Rename whole dotted components of the remainder only: the new A1111
        # prefix is never rewritten, and e.g. "to_q" cannot hit "to_qkv".
        # The sentinel trailing "." lets the final component match as well.
        rest = key[len(diffusers_prefix):] + "."
        for old, new in layer_renames.items():
            rest = rest.replace(f".{old}.", f".{new}.")
        return a1111_prefix + rest[:-1]

    return key
# Input Diffusers checkpoint directory and output file for the A1111 VAE.
VAE_DIR = "./asymmetric_vae"
OUTPUT_PATH = "sdxl_vae_asymm_a1111.safetensors"

# Endings of the attention projection weights that get reshaped below.
ATTN_PROJ_WEIGHT_SUFFIXES = (".q.weight", ".k.weight", ".v.weight", ".proj_out.weight")

vae = AsymmetricAutoencoderKL.from_pretrained(VAE_DIR)
state_dict = vae.state_dict()

converted_state_dict = {}
skipped_keys = []
# converted key -> Diffusers key it came from. Used to detect two source keys
# collapsing onto one target, which would otherwise silently drop a tensor.
origin_of = {}

for key, value in state_dict.items():
    new_key = convert_key(key)

    if new_key is None:
        skipped_keys.append(key)
        continue

    if new_key in origin_of:
        raise ValueError(
            f"Both {origin_of[new_key]!r} and {key!r} convert to {new_key!r}; "
            "check KEY_MAP / LAYER_RENAMES"
        )
    origin_of[new_key] = key

    # 2-D attention projection weights (presumably nn.Linear (out, in)) get
    # two trailing singleton dims, i.e. the (out, in, 1, 1) shape of a 1x1
    # conv kernel, which is what the A1111 VAE is assumed to expect.
    if "attn_1" in new_key and new_key.endswith(ATTN_PROJ_WEIGHT_SUFFIXES):
        if value.dim() == 2:
            value = value.unsqueeze(-1).unsqueeze(-1)

    converted_state_dict[new_key] = value

save_file(converted_state_dict, OUTPUT_PATH)
print(f"Конвертировано {len(converted_state_dict)} ключей")
print(f"Пропущено {len(skipped_keys)} ключей (condition_encoder и др.)")

# Upper bounds on how much detail is printed.
SKIPPED_PREVIEW = 10
EXAMPLE_COUNT = 5

if skipped_keys:
    print("\nПропущенные ключи:")
    for key in skipped_keys[:SKIPPED_PREVIEW]:
        print(f" - {key}")
    if len(skipped_keys) > SKIPPED_PREVIEW:
        # Make the truncation visible instead of implying the list is complete.
        print(f" ... и ещё {len(skipped_keys) - SKIPPED_PREVIEW}")

# Pair each original key with its own converted name. Zipping the first keys
# of state_dict against the first keys of converted_state_dict misaligns the
# pairs as soon as one key has been skipped.
print("\nПримеры конвертированных ключей:")
shown = 0
for old_key in state_dict:
    if shown >= EXAMPLE_COUNT:
        break
    new_key = convert_key(old_key)
    if new_key is None:
        continue
    print(f"{old_key} -> {new_key}")
    shown += 1

print("\nAttention веса после конвертации:")
for key, value in converted_state_dict.items():
    if "attn_1" in key and "weight" in key:
        print(f"{key}: {value.shape}")