| | import argparse |
| |
|
| | import safetensors.torch |
| |
|
| | from diffusers import AutoencoderTiny |
| |
|
| |
|
| | """ |
| | Example - From the diffusers root directory: |
| | |
| | Download the weights: |
| | ```sh |
| | $ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_encoder.safetensors |
| | $ wget -q https://huggingface.co/madebyollin/taesd/resolve/main/taesd_decoder.safetensors |
| | ``` |
| | |
| | Convert the model: |
| | ```sh |
| | $ python scripts/convert_tiny_autoencoder_to_diffusers.py \ |
| | --encoder_ckpt_path taesd_encoder.safetensors \ |
| | --decoder_ckpt_path taesd_decoder.safetensors \ |
| | --dump_path taesd-diffusers |
| | ``` |
| | """ |
| |
|
| | if __name__ == "__main__": |
| | parser = argparse.ArgumentParser() |
| |
|
| | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") |
| | parser.add_argument( |
| | "--encoder_ckpt_path", |
| | default=None, |
| | type=str, |
| | required=True, |
| | help="Path to the encoder ckpt.", |
| | ) |
| | parser.add_argument( |
| | "--decoder_ckpt_path", |
| | default=None, |
| | type=str, |
| | required=True, |
| | help="Path to the decoder ckpt.", |
| | ) |
| | parser.add_argument( |
| | "--use_safetensors", action="store_true", help="Whether to serialize in the safetensors format." |
| | ) |
| | args = parser.parse_args() |
| |
|
| | print("Loading the original state_dicts of the encoder and the decoder...") |
| | encoder_state_dict = safetensors.torch.load_file(args.encoder_ckpt_path) |
| | decoder_state_dict = safetensors.torch.load_file(args.decoder_ckpt_path) |
| |
|
| | print("Populating the state_dicts in the diffusers format...") |
| | tiny_autoencoder = AutoencoderTiny() |
| | new_state_dict = {} |
| |
|
| | |
| | for k in encoder_state_dict: |
| | new_state_dict.update({f"encoder.layers.{k}": encoder_state_dict[k]}) |
| |
|
| | |
| | for k in decoder_state_dict: |
| | layer_id = int(k.split(".")[0]) - 1 |
| | new_k = str(layer_id) + "." + ".".join(k.split(".")[1:]) |
| | new_state_dict.update({f"decoder.layers.{new_k}": decoder_state_dict[k]}) |
| |
|
| | |
| | |
| | tiny_autoencoder.load_state_dict(new_state_dict) |
| | print("Population successful, serializing...") |
| | tiny_autoencoder.save_pretrained(args.dump_path, safe_serialization=args.use_safetensors) |
| |
|