# Spaces:
# Configuration error
| from typing import Any, Dict | |
| import torch | |
| import argparse | |
| from diffusers.loaders.lora_base import LoraBaseMixin | |
| from diffusers.models.modeling_utils import load_state_dict | |
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap the parameter dict from a loaded training checkpoint.

    Training wrappers may nest the parameters under one or more of the keys
    ``"model"``, ``"module"`` and ``"state_dict"``. These wrapper keys are
    peeled off in that order, each one tested against the dict produced by the
    previous step, so nested layouts such as ``{"model": {"module": {...}}}``
    are fully unwrapped.

    Args:
        saved_dict: Object returned by ``torch.load`` on a checkpoint file.

    Returns:
        The innermost value after removing any wrapper keys; ``saved_dict``
        itself when it contains none of them.
    """
    state_dict = saved_dict
    for wrapper_key in ("model", "module", "state_dict"):
        # Test the current (already unwrapped) level rather than the top-level
        # dict: otherwise nested wrappers are missed, and a wrapper key sitting
        # beside the previously unwrapped one triggers a KeyError.
        if not isinstance(state_dict, dict):
            break
        if wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
    return state_dict
# Suffix rename table: SAT LoRA parameter suffix -> diffusers LoRA suffix.
# SAT stores the fused QKV adapter factors as matrix_{A,B}.{0,1,2}, one index
# per q/k/v projection; diffusers keeps separate to_q/to_k/to_v adapters.
LORA_KEYS_RENAME = {
    f"attention.query_key_value.matrix_{factor}.{idx}": f"attn1.to_{proj}.lora_{factor}.weight"
    for factor in ("A", "B")
    for idx, proj in enumerate(("q", "k", "v"))
}
# The attention output projection has a single (index 0) adapter pair.
LORA_KEYS_RENAME.update(
    {
        "attention.dense.matrix_A.0": "attn1.to_out.0.lora_A.weight",
        "attention.dense.matrix_B.0": "attn1.to_out.0.lora_B.weight",
    }
)

# Prefix SAT puts in front of every diffusion-model parameter name.
PREFIX_KEY = "model.diffusion_model."
# Name of the transformer-block container in SAT ...
SAT_UNIT_KEY = "layers"
# ... and the name it is mapped to in the diffusers LoRA keys.
LORA_PREFIX_KEY = "transformer_blocks"
def _convert_sat_lora_key(key):
    """Translate one SAT LoRA parameter name into the diffusers LoRA naming scheme."""
    # Strip the SAT prefix only when it is actually present; slicing
    # unconditionally would silently chop characters off unrelated keys.
    new_key = key[len(PREFIX_KEY):] if key.startswith(PREFIX_KEY) else key
    for sat_suffix, diffusers_suffix in LORA_KEYS_RENAME.items():
        if new_key.endswith(sat_suffix):
            # Rewrite only the matched suffix (str.replace would also touch any
            # earlier occurrence of the same substring).
            new_key = new_key[: -len(sat_suffix)] + diffusers_suffix
            break
    # Rename whole path segments equal to SAT_UNIT_KEY so that names which
    # merely contain that substring are left alone.
    return ".".join(
        LORA_PREFIX_KEY if part == SAT_UNIT_KEY else part
        for part in new_key.split(".")
    )


def export_lora_weight(ckpt_path, lora_save_directory, expected_num_keys=240):
    """Convert a SAT LoRA checkpoint into a diffusers LoRA weight file.

    Args:
        ckpt_path: Path to the SAT checkpoint; it is loaded on CPU with mmap.
        lora_save_directory: Directory passed to
            ``LoraBaseMixin.write_lora_layers`` (safetensors serialization).
        expected_num_keys: Number of entries the converted state dict must
            contain. Defaults to 240, the value this script originally
            hard-coded; pass ``None`` to skip the check.

    Raises:
        ValueError: If the converted state dict does not have
            ``expected_num_keys`` entries.
    """
    # NOTE(review): torch.load unpickles the file (weights_only is not set), so
    # only run this on checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location="cpu", mmap=True)
    merge_original_state_dict = get_state_dict(checkpoint)

    lora_state_dict = {
        _convert_sat_lora_key(key): tensor
        for key, tensor in merge_original_state_dict.items()
    }

    # A count mismatch means keys were missing, unexpected, or collided after renaming.
    if expected_num_keys is not None and len(lora_state_dict) != expected_num_keys:
        raise ValueError(
            f"lora_state_dict length is not {expected_num_keys} "
            f"(got {len(lora_state_dict)})"
        )

    LoraBaseMixin.write_lora_layers(
        state_dict=lora_state_dict,
        save_directory=lora_save_directory,
        is_main_process=True,
        weight_name=None,
        save_function=None,
        safe_serialization=True,
    )
def get_args():
    """Parse the command-line options for the SAT-to-diffusers LoRA export.

    Returns:
        argparse.Namespace carrying the required string options
        ``sat_pt_path`` and ``lora_save_directory``.
    """
    cli = argparse.ArgumentParser()
    required_options = (
        ("--sat_pt_path", "Path to original sat transformer checkpoint"),
        ("--lora_save_directory", "Path where converted lora should be saved"),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: read CLI options, then run the conversion.
    cli_args = get_args()
    export_lora_weight(
        ckpt_path=cli_args.sat_pt_path,
        lora_save_directory=cli_args.lora_save_directory,
    )