| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
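"""Conversion script for the LDM checkpoints.

Loads a legacy UNet checkpoint, renames its state-dict keys (and optionally its
config parameters) to the current diffusers UNet naming, and saves the result.
"""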
| |
|
| | import argparse |
| | import json |
| | import os |
| |
|
| | import torch |
| | from transformers.file_utils import has_file |
| |
|
| | from diffusers import UNet2DConditionModel, UNet2DModel |
| |
|
| |
|
# Mode switches for the script body below. They are plain module constants
# (there are no command-line flags for them); edit them before running.
#   do_only_config   - drop legacy config keys, build the model, save its config.
#   do_only_weights  - load the checkpoint, rename legacy state-dict keys, load
#                      them into the model and save it.
#   do_only_renaming - rename legacy keys in a dict copy of the model config.
do_only_config, do_only_weights, do_only_renaming = False, True, False
| |
|
| |
|
def rename_state_dict_keys(state_dict, key_map):
    """Return a new state dict whose legacy top-level module names are renamed.

    Keys ending in ``.op.bias`` or ``.op.weight`` are dropped. For every other
    key, the part before the first ``.`` is looked up in ``key_map`` and replaced
    when found (``"mid.attn.bias"`` -> ``"mid_block.attn.bias"``; a key with no
    dot such as ``"mid"`` becomes ``"mid_block"``). Keys whose first component is
    not in ``key_map`` are copied unchanged. Values are passed through without
    inspection, input order is preserved, and ``state_dict`` is not modified.

    Args:
        state_dict: Mapping from parameter name to value.
        key_map: Mapping from legacy first-component name to its replacement.

    Returns:
        dict: The renamed state dict.
    """
    renamed = {}
    for param_key, param_value in state_dict.items():
        # NOTE(review): ".op." parameters are discarded rather than renamed;
        # presumably the target model has no matching parameters — confirm.
        if param_key.endswith((".op.bias", ".op.weight")):
            continue
        prefix, sep, rest = param_key.partition(".")
        renamed[key_map.get(prefix, prefix) + sep + rest] = param_value
    return renamed


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo_path",
        default=None,
        type=str,
        required=True,
        help="Path to the local model directory containing config.json (or a 'unet' subfolder with it).",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    # Legacy config parameter name -> current diffusers name.
    config_parameters_to_change = {
        "image_size": "sample_size",
        "num_res_blocks": "layers_per_block",
        "block_channels": "block_out_channels",
        "down_blocks": "down_block_types",
        "up_blocks": "up_block_types",
        "downscale_freq_shift": "freq_shift",
        "resnet_num_groups": "norm_num_groups",
        "resnet_act_fn": "act_fn",
        "resnet_eps": "norm_eps",
        "num_head_channels": "attention_head_dim",
    }

    # Legacy first component of a state-dict key -> current diffusers name.
    key_parameters_to_change = {
        "time_steps": "time_proj",
        "mid": "mid_block",
        "downsample_blocks": "down_blocks",
        "upsample_blocks": "up_blocks",
    }

    # A root-level config.json decides both where the UNet files live and which
    # model class is built, so check for it once.
    has_root_config = has_file(args.repo_path, "config.json")
    subfolder = "" if has_root_config else "unet"
    input_dir = os.path.join(args.repo_path, subfolder)
    # Write outputs under --dump_path (mirroring the optional "unet" subfolder)
    # instead of overwriting the input repository. Passing the same path for
    # --repo_path and --dump_path still converts in place.
    output_dir = os.path.join(args.dump_path, subfolder)

    with open(os.path.join(input_dir, "config.json"), "r", encoding="utf-8") as reader:
        config = json.load(reader)

    if do_only_config:
        # Drop legacy parameter names before the config reaches the model constructor.
        for key in config_parameters_to_change:
            config.pop(key, None)

    if has_root_config:
        model = UNet2DModel(**config)
    else:
        # Repositories whose path contains "ldm-text2im-large-256" get the
        # conditional UNet; every other subfolder layout gets the unconditional one.
        model_class = UNet2DConditionModel if "ldm-text2im-large-256" in args.repo_path else UNet2DModel
        model = model_class(**config)

    if do_only_config:
        model.save_config(output_dir)

    config = dict(model.config)

    if do_only_renaming:
        # Rename legacy parameters in a plain-dict copy of the model config and
        # remove "UNetRes" from the block type names.
        # NOTE(review): the renamed `config` is never saved or used afterwards, so
        # this branch has no observable effect — confirm whether it should be written out.
        for old_name, new_name in config_parameters_to_change.items():
            if old_name in config:
                config[new_name] = config.pop(old_name)

        config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]]
        config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]]

    if do_only_weights:
        # SECURITY: torch.load unpickles the file and can execute arbitrary code.
        # Only run this on trusted checkpoints (torch>=1.13 also offers weights_only=True).
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
        state_dict = torch.load(os.path.join(input_dir, "diffusion_pytorch_model.bin"), map_location="cpu")
        model.load_state_dict(rename_state_dict_keys(state_dict, key_parameters_to_change))
        model.save_pretrained(output_dir)
| |
|