| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import re |
|
|
| import torch |
|
|
| from ..utils import is_peft_version, logging |
|
|
|
|
# Module-level logger obtained from the package's `logging` utilities; the converters in this
# file use it for informational and warning messages.
logger = logging.get_logger(__name__)
|
|
|
|
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
    """
    Remap SGM-style (``input_blocks`` / ``middle_block`` / ``output_blocks``) LoRA keys to Diffusers block names.

    If no key contains one of the SGM block patterns, ``state_dict`` is returned unchanged. Otherwise every entry
    is popped from ``state_dict`` into a new dict:

    * keys containing ``"text"`` are copied unchanged;
    * U-Net keys keep their leading segments, but the ``<sgm block id><delimiter><inner id>`` pair is replaced by
      ``<block id><delimiter><resnets|attentions|upsamplers|downsamplers><delimiter><layer id>`` (middle-block
      keys get ``resnets_0`` / ``attentions_0`` / ``resnets_1``).

    Args:
        state_dict (`dict`): LoRA state dict. It is consumed (entries are popped) when remapping happens.
        unet_config: Object exposing ``layers_per_block``; read when input or output blocks are present.
        delimiter (`str`, optional): Separator between key segments. Defaults to ``"_"``.
        block_slice_pos (`int`, optional): Index, after splitting a key on ``delimiter``, of the inner-block id.
            The SGM block id is the segment just before it. Defaults to 5.

    Returns:
        `dict`: ``state_dict`` itself if nothing needed remapping, otherwise a new remapped dict.

    Raises:
        ValueError: If a non-text key contains no SGM pattern, if a middle-block id is not 0, 1 or 2, or if any
            entry is left unconverted.
    """
    all_keys = list(state_dict.keys())
    sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]

    # Leave the state dict untouched unless at least one key uses SGM block naming.
    if not any(p in key for key in all_keys for p in sgm_patterns):
        return state_dict

    new_state_dict = {}
    inner_block_map = ["resnets", "attentions", "upsamplers"]

    # Bucket U-Net keys by the SGM block id parsed from the key itself. A substring test such as
    # f"input_blocks{delimiter}{layer_id}" in key is wrong: "input_blocks_1" is also a substring of
    # "input_blocks_10" / "input_blocks_11", which mis-renamed those keys and then hit a KeyError.
    input_blocks, middle_blocks, output_blocks = {}, {}, {}
    for layer in all_keys:
        if "text" in layer:
            # Keys containing "text" (presumably text-encoder entries) are passed through unchanged.
            new_state_dict[layer] = state_dict.pop(layer)
            continue

        layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
        if sgm_patterns[0] in layer:
            input_blocks.setdefault(layer_id, []).append(layer)
        elif sgm_patterns[1] in layer:
            middle_blocks.setdefault(layer_id, []).append(layer)
        elif sgm_patterns[2] in layer:
            output_blocks.setdefault(layer_id, []).append(layer)
        else:
            raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")

    # Input blocks: SGM ids are offset by one relative to the first down-block layer, and each Diffusers
    # block presumably spans `layers_per_block` layers plus one downsampler (hence the `+ 1`).
    for i in sorted(input_blocks):
        block_id = (i - 1) // (unet_config.layers_per_block + 1)
        layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)

        for key in input_blocks[i]:
            parts = key.split(delimiter)
            inner_block_id = int(parts[block_slice_pos])
            # Keys containing "op" are mapped to `downsamplers.0`.
            is_downsampler = "op" in key
            inner_block_key = "downsamplers" if is_downsampler else inner_block_map[inner_block_id]
            inner_layers_in_block = "0" if is_downsampler else str(layer_in_block_id)
            new_key = delimiter.join(
                parts[: block_slice_pos - 1]
                + [str(block_id), inner_block_key, inner_layers_in_block]
                + parts[block_slice_pos + 1 :]
            )
            new_state_dict[new_key] = state_dict.pop(key)

    # Middle block: SGM ids 0/1/2 map to resnets_0 / attentions_0 / resnets_1.
    middle_key_parts = {
        0: [inner_block_map[0], "0"],
        1: [inner_block_map[1], "0"],
        2: [inner_block_map[0], "1"],
    }
    for i in sorted(middle_blocks):
        if i not in middle_key_parts:
            raise ValueError(f"Invalid middle block id {i}.")
        key_part = middle_key_parts[i]

        for key in middle_blocks[i]:
            parts = key.split(delimiter)
            new_key = delimiter.join(parts[: block_slice_pos - 1] + key_part + parts[block_slice_pos:])
            new_state_dict[new_key] = state_dict.pop(key)

    # Output blocks: no offset; upsamplers (inner id >= 2) always use layer index "0".
    for i in sorted(output_blocks):
        block_id = i // (unet_config.layers_per_block + 1)
        layer_in_block_id = i % (unet_config.layers_per_block + 1)

        for key in output_blocks[i]:
            parts = key.split(delimiter)
            inner_block_id = int(parts[block_slice_pos])
            inner_block_key = inner_block_map[inner_block_id]
            inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
            new_key = delimiter.join(
                parts[: block_slice_pos - 1]
                + [str(block_id), inner_block_key, inner_layers_in_block]
                + parts[block_slice_pos + 1 :]
            )
            new_state_dict[new_key] = state_dict.pop(key)

    # Safety net: every entry must have been moved into `new_state_dict`.
    if len(state_dict) > 0:
        raise ValueError("At this point all state dict entries have to be converted.")

    return new_state_dict
|
|
|
|
def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
    """
    Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.

    Handles keys prefixed with ``lora_unet_`` (U-Net), ``lora_te_`` / ``lora_te1_`` (first text encoder) and
    ``lora_te2_`` (second text encoder). Each module is driven by its ``lora_down.weight`` key, and the matching
    ``lora_up.weight``, optional ``alpha`` and (when DoRA is detected) ``dora_scale`` entries are converted with
    it. Entries are popped from ``state_dict`` as they are converted, so the input dict is consumed.

    Args:
        state_dict (`dict`): The state dict to convert.
        unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
        text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
            "text_encoder".

    Returns:
        `tuple`: A tuple containing the converted state dict and a dictionary of alphas.

    Raises:
        ValueError: If DoRA weights are present and the installed `peft` is older than 0.9.0, or if any entry is
            left unconverted (for example a ``lora_down`` key with an unrecognized prefix).
    """
    unet_state_dict = {}
    te_state_dict = {}
    te2_state_dict = {}
    network_alphas = {}

    # Detect DoRA separately for the U-Net and for each text encoder, so that `dora_scale`
    # entries are only looked up for the sub-model that actually carries them.
    dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
    dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
    dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
    if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
        if is_peft_version("<", "0.9.0"):
            raise ValueError(
                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
            )

    # Iterate over a snapshot of the keys because entries are popped while converting.
    all_lora_keys = list(state_dict.keys())
    for key in all_lora_keys:
        if not key.endswith("lora_down.weight"):
            continue

        # Module name is everything before the first "."; the sibling entries share it.
        lora_name = key.split(".")[0]
        lora_name_up = lora_name + ".lora_up.weight"
        lora_name_alpha = lora_name + ".alpha"
        dora_scale_key = key.replace("lora_down.weight", "dora_scale")

        if lora_name.startswith("lora_unet_"):
            diffusers_name = _convert_unet_lora_key(key)

            unet_state_dict[diffusers_name] = state_dict.pop(key)
            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

            if dora_present_in_unet:
                dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
                unet_state_dict[
                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
                ] = state_dict.pop(dora_scale_key)

        elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
            diffusers_name = _convert_text_encoder_lora_key(key, lora_name)

            # `lora_te_` / `lora_te1_` target the first text encoder, `lora_te2_` the second.
            if lora_name.startswith(("lora_te_", "lora_te1_")):
                target_state_dict, dora_present = te_state_dict, dora_present_in_te
            else:
                target_state_dict, dora_present = te2_state_dict, dora_present_in_te2

            target_state_dict[diffusers_name] = state_dict.pop(key)
            target_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

            # Use this encoder's own DoRA flag: combining both flags made the lookup below raise a
            # KeyError when only one of the two text encoders carries `dora_scale` entries.
            if dora_present:
                dora_scale_key_to_replace_te = (
                    "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
                )
                target_state_dict[
                    diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
                ] = state_dict.pop(dora_scale_key)

        else:
            # Unrecognized prefix: leave the entries in `state_dict` so the check below reports them,
            # instead of reaching the alpha step with an unbound or stale `diffusers_name`.
            continue

        if lora_name_alpha in state_dict:
            alpha = state_dict.pop(lora_name_alpha).item()
            network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))

    if len(state_dict) > 0:
        raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")

    logger.info("Non-diffusers checkpoint detected.")

    # Prefix every module with the name of the Diffusers sub-model it belongs to.
    unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
    te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
    te2_state_dict = (
        {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
        if len(te2_state_dict) > 0
        else None
    )
    if te2_state_dict is not None:
        te_state_dict.update(te2_state_dict)

    new_state_dict = {**unet_state_dict, **te_state_dict}
    return new_state_dict, network_alphas
|
|
|
|
def _convert_unet_lora_key(key):
    """
    Converts a U-Net LoRA key to a Diffusers compatible key.

    The ``lora_unet_`` prefix is stripped and every underscore becomes a dot. An ordered series of substring
    rewrites then restores Diffusers module names (``down_blocks``, ``transformer_blocks``, ``to_q_lora`` ...).
    The rewrites depend on each other's output, so they are applied in a fixed order.

    Args:
        key (`str`): A ``lora_unet_...`` state-dict key.

    Returns:
        `str`: The converted key.
    """
    name = key.replace("lora_unet_", "").replace("_", ".")

    # Renames that always apply, in this exact order.
    unconditional_renames = (
        ("input.blocks", "down_blocks"),
        ("down.blocks", "down_blocks"),
        ("middle.block", "mid_block"),
        ("mid.block", "mid_block"),
        ("output.blocks", "up_blocks"),
        ("up.blocks", "up_blocks"),
        ("transformer.blocks", "transformer_blocks"),
        ("to.q.lora", "to_q_lora"),
        ("to.k.lora", "to_k_lora"),
        ("to.v.lora", "to_v_lora"),
        ("to.out.0.lora", "to_out_lora"),
        ("proj.in", "proj_in"),
        ("proj.out", "proj_out"),
        ("emb.layers", "time_emb_proj"),
    )
    for old, new in unconditional_renames:
        name = name.replace(old, new)

    # Embedding layers (other than ``time.emb.proj`` spellings): drop the last ``.<digits>`` group that is
    # followed only by non-digits, e.g. the ``.1`` in ``time_emb_proj.1.lora.down.weight``.
    if "emb" in name and "time.emb.proj" not in name:
        name = re.sub(r"\.\d+(?=\D*$)", "", name, count=1)

    # Renames attempted only when one of their guard substrings is present. Each guard is checked against
    # the name as rewritten by all previous steps.
    guarded_renames = (
        ((".in.",), "in.layers.2", "conv1"),
        ((".out.",), "out.layers.3", "conv2"),
        (("downsamplers", "upsamplers"), "op", "conv"),
        (("skip",), "skip.connection", "conv_shortcut"),
        (("time.emb.proj",), "time.emb.proj", "time_emb_proj"),
        (("conv.shortcut",), "conv.shortcut", "conv_shortcut"),
    )
    for guards, old, new in guarded_renames:
        if any(guard in name for guard in guards):
            name = name.replace(old, new)

    # Attention LoRA weights inside transformer blocks are moved under `attn1.processor` / `attn2.processor`.
    if "transformer_blocks" in name and ("attn1" in name or "attn2" in name):
        name = name.replace("attn1", "attn1.processor").replace("attn2", "attn2.processor")

    return name
|
|
|
|
def _convert_text_encoder_lora_key(key, lora_name):
    """
    Converts a text encoder LoRA key to a Diffusers compatible key.

    The encoder prefix (``lora_te_``, ``lora_te1_`` or ``lora_te2_``, chosen from ``lora_name``) is removed,
    underscores become dots, and known module names are restored. MLP layers additionally get the
    ``lora_linear_layer`` naming, unless the name also refers to ``self_attn`` or ``text_projection``.

    Args:
        key (`str`): The full state-dict key (e.g. ``...lora_down.weight``).
        lora_name (`str`): The module part of ``key`` (text before the first ``.``).

    Returns:
        `str`: The converted key.
    """
    if lora_name.startswith("lora_te_"):
        prefix = "lora_te_"
    elif lora_name.startswith("lora_te1_"):
        prefix = "lora_te1_"
    else:
        prefix = "lora_te2_"

    name = key.replace(prefix, "").replace("_", ".")
    renames = (
        ("text.model", "text_model"),
        ("self.attn", "self_attn"),
        ("q.proj.lora", "to_q_lora"),
        ("k.proj.lora", "to_k_lora"),
        ("v.proj.lora", "to_v_lora"),
        ("out.proj.lora", "to_out_lora"),
        ("text.projection", "text_projection"),
    )
    for old, new in renames:
        name = name.replace(old, new)

    # Attention and projection layers keep their names; MLP layers use the newer
    # `lora_linear_layer` naming, which other code paths may not use yet.
    is_attention_or_projection = "self_attn" in name or "text_projection" in name
    if not is_attention_or_projection and "mlp" in name:
        name = name.replace(".lora.", ".lora_linear_layer.")
    return name
|
|
|
|
def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
    """
    Gets the correct alpha name for the Diffusers model.

    The sub-model prefix is chosen from the original key:
    ``lora_unet_`` -> ``unet.``, ``lora_te_`` / ``lora_te1_`` -> ``text_encoder.``, anything else ->
    ``text_encoder_2.``. The module path is ``diffusers_name`` truncated at its first ``.lora.`` segment,
    or the whole name if it has none.

    Returns:
        `dict`: A single-entry mapping ``"<prefix><module path>.alpha" -> alpha``.
    """
    prefix_by_source = (
        (("lora_unet_",), "unet."),
        (("lora_te_", "lora_te1_"), "text_encoder."),
    )
    prefix = next(
        (target for sources, target in prefix_by_source if lora_name_alpha.startswith(sources)),
        "text_encoder_2.",
    )
    module_path, _, _ = diffusers_name.partition(".lora.")
    return {f"{prefix}{module_path}.alpha": alpha}
|
|
|
|
| |
| |
| |
def _convert_kohya_flux_lora_to_diffusers(state_dict):
    """
    Convert a kohya (sd-scripts) FLUX LoRA state dict to Diffusers (ai-toolkit style) keys.

    Each kohya module ``<name>.lora_down.weight`` / ``<name>.lora_up.weight`` / ``<name>.alpha`` becomes
    ``<diffusers module>.lora_A.weight`` / ``<diffusers module>.lora_B.weight``. The ``alpha / rank`` factor is
    folded into the weights, and no alpha entries are emitted. Fused kohya projections (``*_attn_qkv`` and the
    single-block ``linear1``) are split across several Diffusers modules.

    Args:
        state_dict (`dict`): kohya state dict of tensors. Converted entries are popped, so it is mutated in place.

    Returns:
        `dict`: The converted state dict. Entries that could not be mapped stay in ``state_dict`` and are reported
        via a logger warning (not an error).
    """

    def _pop_alpha_scales(sds_sd, sds_key, rank):
        # Pop `<sds_key>.alpha` and return `(scale_down, scale_up)`, whose product is `alpha / rank`.
        # `.item()` turns alpha into a Python float for both conversion paths, so scaling never changes
        # the weights' dtype: a non-scalar (e.g. shape [1]) float32 alpha tensor would otherwise promote
        # bf16/fp16 weights to float32.
        alpha = sds_sd.pop(sds_key + ".alpha").item()
        scale = alpha / rank

        # Move part of a small scale onto the up weight. Doubling/halving is exact in floating point,
        # so `scale_down * scale_up == scale` is preserved.
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2
        return scale_down, scale_up

    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
        # One kohya module -> one Diffusers module; no-op if the module is absent.
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
        scale_down, scale_up = _pop_alpha_scales(sds_sd, sds_key, down_weight.shape[0])

        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up

    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        # One fused kohya module -> len(ait_keys) Diffusers modules; `dims` are the row sizes of each
        # target in the fused up weight (equal parts when omitted). No-op if the module is absent.
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
        sd_lora_rank = down_weight.shape[0]

        scale_down, scale_up = _pop_alpha_scales(sds_sd, sds_key, sd_lora_rank)
        down_weight = down_weight * scale_down
        up_weight = up_weight * scale_up

        num_splits = len(ait_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        elif sum(dims) != up_weight.shape[0]:
            raise ValueError(
                f"Split sizes {dims} for {sds_key} do not add up to the up weight's first dimension "
                f"({up_weight.shape[0]})."
            )

        # The fused LoRA is "sparse" when, for every split j, its rows are zero outside rank columns
        # [j * ait_rank, (j + 1) * ait_rank): it is then `num_splits` independent rank-`ait_rank` LoRAs.
        is_sparse = False
        if sd_lora_rank % num_splits == 0:
            ait_rank = sd_lora_rank // num_splits
            is_sparse = True
            row = 0
            for j in range(len(dims)):
                for k in range(len(dims)):
                    if j == k:
                        continue
                    is_sparse = is_sparse and torch.all(
                        up_weight[row : row + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
                    )
                row += dims[j]
            if is_sparse:
                logger.info(f"weight is sparse: {sds_key}")

        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
        if not is_sparse:
            # Every target shares the whole down weight; the up weight is split row-wise.
            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
            ait_sd.update(zip(ait_up_keys, torch.split(up_weight, dims, dim=0)))
        else:
            # Each target gets its own rank chunk of the down weight and only its diagonal up block.
            ait_sd.update(zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0)))
            row = 0
            for j in range(len(dims)):
                ait_sd[ait_up_keys[j]] = up_weight[row : row + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
                row += dims[j]

    def _convert_sd_scripts_to_ai_toolkit(sds_sd):
        ait_sd = {}
        # Block counts (19 double-stream, 38 single-stream) and the linear1 split sizes are hard-coded,
        # presumably for the FLUX.1 layout. Modules missing from `sds_sd` are skipped by the converters.
        for i in range(19):
            sds_prefix = f"lora_unet_double_blocks_{i}"
            ait_prefix = f"transformer.transformer_blocks.{i}"
            _convert_to_ai_toolkit(sds_sd, ait_sd, f"{sds_prefix}_img_attn_proj", f"{ait_prefix}.attn.to_out.0")
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"{sds_prefix}_img_attn_qkv",
                [f"{ait_prefix}.attn.to_q", f"{ait_prefix}.attn.to_k", f"{ait_prefix}.attn.to_v"],
            )
            _convert_to_ai_toolkit(sds_sd, ait_sd, f"{sds_prefix}_img_mlp_0", f"{ait_prefix}.ff.net.0.proj")
            _convert_to_ai_toolkit(sds_sd, ait_sd, f"{sds_prefix}_img_mlp_2", f"{ait_prefix}.ff.net.2")
            _convert_to_ai_toolkit(sds_sd, ait_sd, f"{sds_prefix}_img_mod_lin", f"{ait_prefix}.norm1.linear")
            _convert_to_ai_toolkit(sds_sd, ait_sd, f"{sds_prefix}_txt_attn_proj", f"{ait_prefix}.attn.to_add_out")
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"{sds_prefix}_txt_attn_qkv",
                [f"{ait_prefix}.attn.add_q_proj", f"{ait_prefix}.attn.add_k_proj", f"{ait_prefix}.attn.add_v_proj"],
            )
            _convert_to_ai_toolkit(sds_sd, ait_sd, f"{sds_prefix}_txt_mlp_0", f"{ait_prefix}.ff_context.net.0.proj")
            _convert_to_ai_toolkit(sds_sd, ait_sd, f"{sds_prefix}_txt_mlp_2", f"{ait_prefix}.ff_context.net.2")
            _convert_to_ai_toolkit(sds_sd, ait_sd, f"{sds_prefix}_txt_mod_lin", f"{ait_prefix}.norm1_context.linear")

        for i in range(38):
            sds_prefix = f"lora_unet_single_blocks_{i}"
            ait_prefix = f"transformer.single_transformer_blocks.{i}"
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"{sds_prefix}_linear1",
                [
                    f"{ait_prefix}.attn.to_q",
                    f"{ait_prefix}.attn.to_k",
                    f"{ait_prefix}.attn.to_v",
                    f"{ait_prefix}.proj_mlp",
                ],
                dims=[3072, 3072, 3072, 12288],
            )
            _convert_to_ai_toolkit(sds_sd, ait_sd, f"{sds_prefix}_linear2", f"{ait_prefix}.proj_out")
            _convert_to_ai_toolkit(sds_sd, ait_sd, f"{sds_prefix}_modulation_lin", f"{ait_prefix}.norm.linear")

        if len(sds_sd) > 0:
            logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")

        return ait_sd

    return _convert_sd_scripts_to_ai_toolkit(state_dict)
|
|
|
|
| |
| |
| |
def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
    """
    Convert an XLabs FLUX LoRA state dict to Diffusers ``transformer.*`` keys.

    Keys under ``double_blocks`` / ``single_blocks`` (optionally prefixed with ``diffusion_model.``) are mapped
    onto ``transformer.transformer_blocks.<n>`` / ``transformer.single_transformer_blocks.<n>``. Keys containing
    ``down`` become ``lora_A`` and keys containing ``up`` become ``lora_B``. Fused ``qkv`` LoRAs are split:
    the down weight is shared and the up weight is split row-wise. Other keys are copied unchanged.

    Args:
        old_state_dict (`dict`): XLabs state dict of tensors. It is consumed (entries are popped).

    Returns:
        `dict`: The converted state dict.

    Raises:
        ValueError: If any entry of ``old_state_dict`` is left unconverted.
    """
    new_state_dict = {}
    # Snapshot of the keys: entries are popped from `old_state_dict` while iterating.
    orig_keys = list(old_state_dict.keys())

    def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        # Pop the fused down weight at `sds_key` together with its `.up.weight` partner and spread them over
        # `ait_keys`: every target shares the down weight (lora_A), and the up weight is split row-wise into
        # `dims` (equal parts when omitted) for lora_B.
        down_weight = sds_sd.pop(sds_key)
        up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))

        num_splits = len(ait_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        elif sum(dims) != up_weight.shape[0]:
            raise ValueError(
                f"Split sizes {dims} for {sds_key} do not add up to the up weight's first dimension "
                f"({up_weight.shape[0]})."
            )

        ait_sd.update(dict.fromkeys([k + ".lora_A.weight" for k in ait_keys], down_weight))
        ait_sd.update(zip([k + ".lora_B.weight" for k in ait_keys], torch.split(up_weight, dims, dim=0)))

    for old_key in orig_keys:
        if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
            block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
            new_key = f"transformer.transformer_blocks.{block_num}"

            if "processor.proj_lora1" in old_key:
                new_key += ".attn.to_out.0"
            elif "processor.proj_lora2" in old_key:
                new_key += ".attn.to_add_out"
            # qkv_lora2 -> add_{q,k,v}_proj; handled once, from the down key (which also pops the up key).
            elif "processor.qkv_lora2" in old_key and "up" not in old_key:
                handle_qkv(
                    old_state_dict,
                    new_state_dict,
                    old_key,
                    [
                        f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
                        f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
                        f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
                    ],
                )
            # qkv_lora1 -> to_{q,k,v}; handled once, from the down key.
            elif "processor.qkv_lora1" in old_key and "up" not in old_key:
                handle_qkv(
                    old_state_dict,
                    new_state_dict,
                    old_key,
                    [
                        f"transformer.transformer_blocks.{block_num}.attn.to_q",
                        f"transformer.transformer_blocks.{block_num}.attn.to_k",
                        f"transformer.transformer_blocks.{block_num}.attn.to_v",
                    ],
                )

            if "down" in old_key:
                new_key += ".lora_A.weight"
            elif "up" in old_key:
                new_key += ".lora_B.weight"

        # `str.startswith` takes a tuple of prefixes; passing two strings made the second one the integer
        # `start` argument and raised a TypeError for every key reaching this branch.
        elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
            block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
            new_key = f"transformer.single_transformer_blocks.{block_num}"

            if "proj_lora" in old_key:
                new_key += ".proj_out"
            # Single-block qkv -> norm.linear. It goes through `handle_qkv` because the generic store below
            # skips every "qkv" key; before, these entries were never moved and conversion always failed.
            elif "qkv_lora" in old_key and "up" not in old_key:
                handle_qkv(old_state_dict, new_state_dict, old_key, [f"{new_key}.norm.linear"])

            if "down" in old_key:
                new_key += ".lora_A.weight"
            elif "up" in old_key:
                new_key += ".lora_B.weight"

        else:
            # Keys outside double/single blocks are kept as-is.
            new_key = old_key

        # QKV entries were already moved by `handle_qkv`.
        if "qkv" not in old_key:
            new_state_dict[new_key] = old_state_dict.pop(old_key)

    if len(old_state_dict) > 0:
        raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")

    return new_state_dict
|
|