Spaces:
Running on Zero
Running on Zero
| # Copyright 2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| State dict utilities: utility methods for converting state dicts easily | |
| """ | |
| import enum | |
| import json | |
| from .import_utils import is_torch_available | |
| from .logging import get_logger | |
| if is_torch_available(): | |
| import torch | |
| logger = get_logger(__name__) | |
class StateDictType(enum.Enum):
    """
    The mode to use when converting state dicts.

    The members name the LoRA key-naming conventions that the converters in this module translate between. They are
    used as keys of the dispatch tables `PEFT_STATE_DICT_MAPPINGS`, `DIFFUSERS_STATE_DICT_MAPPINGS` and
    `KOHYA_STATE_DICT_MAPPINGS`.
    """

    # Legacy diffusers naming; auto-detected through the `to_out_lora` substring in the keys.
    DIFFUSERS_OLD = "diffusers_old"
    # Kohya naming (`lora_down` / `lora_up`). In this module it only appears as a conversion target.
    KOHYA_SS = "kohya_ss"
    # PEFT naming (`lora_A` / `lora_B`).
    PEFT = "peft"
    # Current diffusers naming; auto-detected through the `lora_linear_layer` substring in the keys.
    DIFFUSERS = "diffusers"
# We need to define a proper mapping for Unet since it uses different output keys than text encoder
# e.g. to_q_lora -> q_proj / to_q
# Used by `convert_unet_state_dict_to_peft`. `convert_state_dict` applies only the FIRST pattern (in insertion
# order) that occurs in a key, so the order of the entries below is part of the behavior.
UNET_TO_DIFFUSERS = {
    ".to_out_lora.up": ".to_out.0.lora_B",
    ".to_out_lora.down": ".to_out.0.lora_A",
    ".to_q_lora.down": ".to_q.lora_A",
    ".to_q_lora.up": ".to_q.lora_B",
    ".to_k_lora.down": ".to_k.lora_A",
    ".to_k_lora.up": ".to_k.lora_B",
    ".to_v_lora.down": ".to_v.lora_A",
    ".to_v_lora.up": ".to_v.lora_B",
    # Generic fallbacks for any other `.lora.up` / `.lora.down` key.
    ".lora.up": ".lora_B",
    ".lora.down": ".lora_A",
    # The magnitude vector of `to_out` is moved under the `.0` sub-module, like the LoRA weights above.
    ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
}
# Second renaming pass of `convert_sai_sd_control_lora_state_dict_to_peft`: every `<module>.down` / `<module>.up`
# key becomes `<module>.lora_A.weight` / `<module>.lora_B.weight`.
# Note that the `conv_in` and `time_emb_proj` patterns have no leading dot, unlike the other entries.
CONTROL_LORA_TO_DIFFUSERS = {
    # Attention projections
    ".to_q.down": ".to_q.lora_A.weight",
    ".to_q.up": ".to_q.lora_B.weight",
    ".to_k.down": ".to_k.lora_A.weight",
    ".to_k.up": ".to_k.lora_B.weight",
    ".to_v.down": ".to_v.lora_A.weight",
    ".to_v.up": ".to_v.lora_B.weight",
    ".to_out.0.down": ".to_out.0.lora_A.weight",
    ".to_out.0.up": ".to_out.0.lora_B.weight",
    # Feed-forward layers
    ".ff.net.0.proj.down": ".ff.net.0.proj.lora_A.weight",
    ".ff.net.0.proj.up": ".ff.net.0.proj.lora_B.weight",
    ".ff.net.2.down": ".ff.net.2.lora_A.weight",
    ".ff.net.2.up": ".ff.net.2.lora_B.weight",
    # Input / output projections
    ".proj_in.down": ".proj_in.lora_A.weight",
    ".proj_in.up": ".proj_in.lora_B.weight",
    ".proj_out.down": ".proj_out.lora_A.weight",
    ".proj_out.up": ".proj_out.lora_B.weight",
    # Convolutions; the two comprehensions expand to `.conv1` and `.conv2`
    ".conv.down": ".conv.lora_A.weight",
    ".conv.up": ".conv.lora_B.weight",
    **{f".conv{i}.down": f".conv{i}.lora_A.weight" for i in range(1, 3)},
    **{f".conv{i}.up": f".conv{i}.lora_B.weight" for i in range(1, 3)},
    "conv_in.down": "conv_in.lora_A.weight",
    "conv_in.up": "conv_in.lora_B.weight",
    ".conv_shortcut.down": ".conv_shortcut.lora_A.weight",
    ".conv_shortcut.up": ".conv_shortcut.lora_B.weight",
    # Embedding linears; the two comprehensions expand to `.linear_1` and `.linear_2`
    **{f".linear_{i}.down": f".linear_{i}.lora_A.weight" for i in range(1, 3)},
    **{f".linear_{i}.up": f".linear_{i}.lora_B.weight" for i in range(1, 3)},
    "time_emb_proj.down": "time_emb_proj.lora_A.weight",
    "time_emb_proj.up": "time_emb_proj.lora_B.weight",
}
# Current diffusers naming (`lora_linear_layer.up` / `.down`) -> PEFT naming (`lora_B` / `lora_A`).
# Selected through `PEFT_STATE_DICT_MAPPINGS[StateDictType.DIFFUSERS]`.
DIFFUSERS_TO_PEFT = {
    ".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
    ".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
    ".k_proj.lora_linear_layer.up": ".k_proj.lora_B",
    ".k_proj.lora_linear_layer.down": ".k_proj.lora_A",
    ".v_proj.lora_linear_layer.up": ".v_proj.lora_B",
    ".v_proj.lora_linear_layer.down": ".v_proj.lora_A",
    ".out_proj.lora_linear_layer.up": ".out_proj.lora_B",
    ".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
    # Generic fallbacks for any other module carrying a `lora_linear_layer`.
    ".lora_linear_layer.up": ".lora_B",
    ".lora_linear_layer.down": ".lora_A",
    # `text_projection` keys use `lora.down` / `lora.up` and include the `.weight` suffix in the pattern.
    "text_projection.lora.down.weight": "text_projection.lora_A.weight",
    "text_projection.lora.up.weight": "text_projection.lora_B.weight",
}
# Legacy diffusers naming (`to_q_lora`, `to_k_lora`, ...) -> PEFT naming on `*_proj` modules.
# Selected through `PEFT_STATE_DICT_MAPPINGS[StateDictType.DIFFUSERS_OLD]`.
DIFFUSERS_OLD_TO_PEFT = {
    ".to_q_lora.up": ".q_proj.lora_B",
    ".to_q_lora.down": ".q_proj.lora_A",
    ".to_k_lora.up": ".k_proj.lora_B",
    ".to_k_lora.down": ".k_proj.lora_A",
    ".to_v_lora.up": ".v_proj.lora_B",
    ".to_v_lora.down": ".v_proj.lora_A",
    ".to_out_lora.up": ".out_proj.lora_B",
    ".to_out_lora.down": ".out_proj.lora_A",
    # Generic fallbacks for keys that already use `lora_linear_layer`.
    ".lora_linear_layer.up": ".lora_B",
    ".lora_linear_layer.down": ".lora_A",
}
# PEFT naming (`lora_A` / `lora_B`) -> diffusers naming.
# Selected through `DIFFUSERS_STATE_DICT_MAPPINGS[StateDictType.PEFT]`.
PEFT_TO_DIFFUSERS = {
    # `*_proj` modules map to `lora_linear_layer.up` / `.down`.
    ".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
    ".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
    ".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
    ".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
    ".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
    ".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
    ".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
    ".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
    # `to_*` modules map to `lora.down` / `lora.up`; these patterns have no leading dot.
    "to_k.lora_A": "to_k.lora.down",
    "to_k.lora_B": "to_k.lora.up",
    "to_q.lora_A": "to_q.lora.down",
    "to_q.lora_B": "to_q.lora.up",
    "to_v.lora_A": "to_v.lora.down",
    "to_v.lora_B": "to_v.lora.up",
    "to_out.0.lora_A": "to_out.0.lora.down",
    "to_out.0.lora_B": "to_out.0.lora.up",
}
# Legacy diffusers naming (`to_q_lora`, ...) -> current diffusers naming (`q_proj.lora_linear_layer`, ...).
# Selected through `DIFFUSERS_STATE_DICT_MAPPINGS[StateDictType.DIFFUSERS_OLD]`.
DIFFUSERS_OLD_TO_DIFFUSERS = {
    ".to_q_lora.up": ".q_proj.lora_linear_layer.up",
    ".to_q_lora.down": ".q_proj.lora_linear_layer.down",
    ".to_k_lora.up": ".k_proj.lora_linear_layer.up",
    ".to_k_lora.down": ".k_proj.lora_linear_layer.down",
    ".to_v_lora.up": ".v_proj.lora_linear_layer.up",
    ".to_v_lora.down": ".v_proj.lora_linear_layer.down",
    ".to_out_lora.up": ".out_proj.lora_linear_layer.up",
    ".to_out_lora.down": ".out_proj.lora_linear_layer.down",
    # Magnitude vectors follow the same `to_*` -> `*_proj` module renaming.
    ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
    ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
    ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
    ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
}
# PEFT parameter names -> Kohya parameter names. The patterns are bare substrings (no dots), so they match
# anywhere in a key.
PEFT_TO_KOHYA_SS = {
    "lora_A": "lora_down",
    "lora_B": "lora_up",
    # This is not a comprehensive dict as kohya format requires replacing `.` with `_` in keys,
    # adding prefixes and adding alpha values
    # Check `convert_state_dict_to_kohya` for more
}
# Dispatch tables: for each conversion target, the renaming table to use for a given source format.
# A source format that is missing from a table is rejected by the corresponding converter with a `ValueError`.

# Target: PEFT (see `convert_state_dict_to_peft`).
PEFT_STATE_DICT_MAPPINGS = {
    StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
    StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
}
# Target: current diffusers format (see `convert_state_dict_to_diffusers`).
DIFFUSERS_STATE_DICT_MAPPINGS = {
    StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
    StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}
# Target: Kohya (see `convert_state_dict_to_kohya`); PEFT is the only accepted source.
KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS}
# Substitutions applied to every key before the format-specific mapping is consulted.
KEYS_TO_ALWAYS_REPLACE = {
    ".processor.": ".",
}


def convert_state_dict(state_dict, mapping):
    r"""
    Renames the keys of `state_dict` according to `mapping` and returns the result as a new dict.

    Each key is processed in two steps:

    1. Every pattern of `KEYS_TO_ALWAYS_REPLACE` found in the key is substituted.
    2. The first pattern of `mapping` (in insertion order) found in the key is substituted; later patterns are
       not tried, so at most one `mapping` entry is applied per key.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert. It is not modified.
        mapping (`dict[str, str]`):
            The mapping to use for conversion, the mapping should be a dictionary with the following structure:
                - key: the pattern to replace
                - value: the pattern to replace with

    Returns:
        converted_state_dict (`dict`)
            The converted state dict. Values are the same objects as in `state_dict`; if two keys are renamed to the
            same name, the later one wins.
    """
    renamed = {}
    for name, value in state_dict.items():
        # Unconditional clean-ups first.
        for always_old, always_new in KEYS_TO_ALWAYS_REPLACE.items():
            if always_old in name:
                name = name.replace(always_old, always_new)

        # Then the first matching format-specific pattern only.
        first_hit = next((old for old in mapping if old in name), None)
        if first_hit is not None:
            name = name.replace(first_hit, mapping[first_hit])

        renamed[name] = value
    return renamed
def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs):
    r"""
    Converts a state dict to the PEFT format. The source can be the previous diffusers format
    (`StateDictType.DIFFUSERS_OLD`) or the new diffusers format (`StateDictType.DIFFUSERS`); these are the only
    conversions supported for now.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically:
            keys containing `to_out_lora` mean the old diffusers format, keys containing `lora_linear_layer` mean
            the new diffusers format.
        kwargs (`dict`, *optional*):
            Accepted and ignored by this function.

    Returns:
        `dict`: The state dict with its keys renamed to the PEFT convention.

    Raises:
        `ValueError`: If the type cannot be inferred, or if `original_type` has no mapping to PEFT.
    """
    if original_type is None:
        names = state_dict.keys()
        # The legacy marker is checked first, so it wins when both markers are present.
        if any("to_out_lora" in name for name in names):
            original_type = StateDictType.DIFFUSERS_OLD
        elif any("lora_linear_layer" in name for name in names):
            original_type = StateDictType.DIFFUSERS
        else:
            raise ValueError("Could not automatically infer state dict type")

    if original_type not in PEFT_STATE_DICT_MAPPINGS:
        raise ValueError(f"Original type {original_type} is not supported")

    return convert_state_dict(state_dict, PEFT_STATE_DICT_MAPPINGS[original_type])
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
    r"""
    Converts a state dict to the new diffusers format. The source can be the previous diffusers format
    (`StateDictType.DIFFUSERS_OLD`) or the PEFT format (`StateDictType.PEFT`). When the type is inferred and the
    state dict already looks like the new diffusers format, it is returned as is.

    Note that an explicit `original_type=StateDictType.DIFFUSERS` is not accepted: only the types listed in
    `DIFFUSERS_STATE_DICT_MAPPINGS` are.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *args*):
            Additional arguments to pass to the method.

            - **adapter_name**: For example, in case of PEFT, some keys will be prepended
                with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
                `get_peft_model_state_dict` method:
                https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
                but we add it here in case we don't want to rely on that method. In this function the adapter name
                is only used to recognise PEFT keys during type inference.

    Returns:
        `dict`: The converted state dict, or `state_dict` itself when it was inferred to already be in the new
        diffusers format.

    Raises:
        `ValueError`: If the type cannot be inferred, or if `original_type` has no mapping to diffusers.
    """
    adapter_name = kwargs.pop("adapter_name", None)
    adapter_suffix = "" if adapter_name is None else "." + adapter_name

    if original_type is None:
        names = state_dict.keys()
        peft_marker = f".lora_A{adapter_suffix}.weight"
        # Markers are tried in this order: legacy diffusers, PEFT, then current diffusers.
        if any("to_out_lora" in name for name in names):
            original_type = StateDictType.DIFFUSERS_OLD
        elif any(peft_marker in name for name in names):
            original_type = StateDictType.PEFT
        elif any("lora_linear_layer" in name for name in names):
            # Already in the target format: nothing to do.
            return state_dict
        else:
            raise ValueError("Could not automatically infer state dict type")

    if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS:
        raise ValueError(f"Original type {original_type} is not supported")

    return convert_state_dict(state_dict, DIFFUSERS_STATE_DICT_MAPPINGS[original_type])
def convert_unet_state_dict_to_peft(state_dict):
    r"""
    Renames the keys of a UNet LoRA state dict with the `UNET_TO_DIFFUSERS` table, e.g. `.to_q_lora.down` becomes
    `.to_q.lora_A` and `.to_out_lora.up` becomes `.to_out.0.lora_B`.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.

    Returns:
        `dict`: A new dict with the renamed keys, as produced by `convert_state_dict`.
    """
    return convert_state_dict(state_dict, UNET_TO_DIFFUSERS)
def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
    r"""
    Converts a ControlNet LoRA state dict in two passes:

    1. The module paths (`input_blocks.*`, `time_embed.*`, `label_emb.*`, `zero_convs.*`, `middle_block.*`,
       `middle_block_out.*`, `input_hint_block.*`) are renamed to the diffusers ControlNet layout (`conv_in`,
       `time_embedding`, `add_embedding`, `down_blocks`, `controlnet_down_blocks`, `mid_block`,
       `controlnet_mid_block`, `controlnet_cond_embedding`).
    2. The result is passed through `convert_state_dict` with `CONTROL_LORA_TO_DIFFUSERS`, which turns the
       `.down` / `.up` suffixes into `.lora_A.weight` / `.lora_B.weight`.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert. It is not modified.

    Returns:
        `dict`: The converted state dict. The `controlnet_down_blocks.*`, `controlnet_mid_block.*` and
        `controlnet_cond_embedding.blocks.*` entries are always created; their value is `None` when the
        corresponding source key is absent.
    """
    source = state_dict
    # Each down block is laid out as `layers_per_block` layers followed by one downsampler slot.
    layers_per_block = 2
    slots_per_block = layers_per_block + 1

    resnet_renames = (
        ("in_layers.0", "norm1"),
        ("in_layers.2", "conv1"),
        ("out_layers.0", "norm2"),
        ("out_layers.3", "conv2"),
        ("emb_layers.1", "time_emb_proj"),
        ("skip_connection", "conv_shortcut"),
    )

    def _rename_resnet_layers(name):
        # Translate the ResNet sub-layer names, applying the substitutions in a fixed order.
        for old, new in resnet_renames:
            name = name.replace(old, new)
        return name

    def _keys_with(fragment):
        # All source keys containing `fragment`, in source order.
        return [name for name in source if fragment in name]

    def _count_prefixes(marker):
        # Number of distinct "<first>.<second>" prefixes among the keys containing `marker`.
        return len({".".join(name.split(".")[:2]) for name in source if marker in name})

    # The variant is only reported in the log; it does not influence the conversion.
    variant = "SD15" if "input_blocks.11.0.in_layers.0.weight" in source else "SDXL"
    logger.info(f"Using ControlNet lora ({variant})")

    # Retrieves the keys for the input blocks only
    num_input_blocks = _count_prefixes("input_blocks")
    input_block_keys = {idx: _keys_with(f"input_blocks.{idx}") for idx in range(num_input_blocks)}

    # Downsampler ("op") keys
    downsampler_keys = _keys_with("0.op")

    converted = {}

    # Conv in layers
    for name in input_block_keys[0]:
        converted[name.replace("input_blocks.0.0", "conv_in")] = source.get(name)

    # controlnet time embedding blocks
    for name in _keys_with("time_embed"):
        target = name.replace("time_embed.0", "time_embedding.linear_1")
        target = target.replace("time_embed.2", "time_embedding.linear_2")
        converted[target] = source.get(name)

    # controlnet label embedding blocks
    for name in _keys_with("label_emb"):
        target = name.replace("label_emb.0.0", "add_embedding.linear_1")
        target = target.replace("label_emb.0.2", "add_embedding.linear_2")
        converted[target] = source.get(name)

    # Down blocks
    for block_idx in range(1, num_input_blocks):
        down_id, slot_id = divmod(block_idx - 1, slots_per_block)
        resnet_prefix = f"input_blocks.{block_idx}.0"
        op_prefix = f"input_blocks.{block_idx}.0.op"
        attn_prefix = f"input_blocks.{block_idx}.1"

        # ResNet layers: everything under `.0` except the downsampler op.
        for name in input_block_keys[block_idx]:
            if resnet_prefix in name and op_prefix not in name:
                target = _rename_resnet_layers(name).replace(
                    resnet_prefix, f"down_blocks.{down_id}.resnets.{slot_id}"
                )
                converted[target] = source.get(name)

        # Downsampler, only when this block carries one.
        if f"{op_prefix}.bias" in source:
            for name in downsampler_keys:
                if op_prefix in name:
                    target = name.replace(op_prefix, f"down_blocks.{down_id}.downsamplers.0.conv")
                    converted[target] = source.get(name)

        # Attention layers under `.1`.
        for name in input_block_keys[block_idx]:
            if attn_prefix in name:
                target = name.replace(attn_prefix, f"down_blocks.{down_id}.attentions.{slot_id}")
                converted[target] = source.get(name)

    # controlnet down blocks
    for idx in range(num_input_blocks):
        for param in ("weight", "bias"):
            converted[f"controlnet_down_blocks.{idx}.{param}"] = source.get(f"zero_convs.{idx}.0.{param}")

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = _count_prefixes("middle_block")
    middle_block_keys = {idx: _keys_with(f"middle_block.{idx}") for idx in range(num_middle_blocks)}

    # Mid blocks: even indices are ResNets, odd indices are attentions.
    for mid_idx, mid_names in middle_block_keys.items():
        target_idx = max(mid_idx - 1, 0)
        old_prefix = f"middle_block.{mid_idx}"
        if mid_idx % 2 == 0:
            new_prefix = f"mid_block.resnets.{target_idx}"
            for name in mid_names:
                converted[_rename_resnet_layers(name).replace(old_prefix, new_prefix)] = source.get(name)
        else:
            new_prefix = f"mid_block.attentions.{target_idx}"
            for name in mid_names:
                converted[name.replace(old_prefix, new_prefix)] = source.get(name)

    # mid block
    for param in ("weight", "bias"):
        converted[f"controlnet_mid_block.{param}"] = source.get(f"middle_block_out.0.{param}")

    # controlnet cond embedding blocks: every hint block except the first (`.0`) and the last (`.14`)
    hint_prefixes = {
        ".".join(name.split(".")[:2])
        for name in source
        if "input_hint_block" in name
        and "input_hint_block.0" not in name
        and "input_hint_block.14" not in name
    }
    for diffusers_idx in range(len(hint_prefixes)):
        # The source index advances by two per embedding block.
        hint_idx = 2 * (diffusers_idx + 1)
        for param in ("weight", "bias"):
            converted[f"controlnet_cond_embedding.blocks.{diffusers_idx}.{param}"] = source.get(
                f"input_hint_block.{hint_idx}.{param}"
            )

    for name in _keys_with("input_hint_block.0"):
        converted[name.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in")] = source.get(name)

    for name in _keys_with("input_hint_block.14"):
        converted[name.replace("input_hint_block.14", "controlnet_cond_embedding.conv_out")] = source.get(name)

    # Second pass: `.down` / `.up` -> `.lora_A.weight` / `.lora_B.weight`.
    return convert_state_dict(converted, CONTROL_LORA_TO_DIFFUSERS)
def convert_all_state_dict_to_peft(state_dict):
    r"""
    Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
    `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.

    Returns:
        `dict`: The converted state dict, in which at least one key contains `lora_A` or `lora_B`.

    Raises:
        `ValueError`: If the conversion fails for a reason other than the type not being inferable, or if no key of
        the result contains `lora_A` or `lora_B`.
    """
    try:
        peft_dict = convert_state_dict_to_peft(state_dict)
    except ValueError as e:
        # `convert_state_dict_to_peft` signals an un-inferable format with a `ValueError` carrying this exact
        # message. Only that case triggers the UNet fallback; any other failure is a genuine error.
        if str(e) != "Could not automatically infer state dict type":
            raise
        peft_dict = convert_unet_state_dict_to_peft(state_dict)

    # Sanity check: a successful conversion must have produced PEFT-style parameter names.
    if not any("lora_A" in key or "lora_B" in key for key in peft_dict):
        raise ValueError("Your LoRA was not converted to PEFT")

    return peft_dict
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
    r"""
    Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
    The method only supports the conversion from PEFT to Kohya for now.

    Besides renaming `lora_A` / `lora_B` to `lora_down` / `lora_up`, the conversion prefixes the keys
    (`lora_te1.`, `lora_te2.`, `lora_unet`), replaces all but the last two `.` of each key with `_`, and adds one
    `<module>.alpha` entry per `lora_down` weight, set to `len(weight)`.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *args*):
            Additional arguments to pass to the method.

            - **adapter_name**: For example, in case of PEFT, some keys will be prepended
                with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
                `get_peft_model_state_dict` method:
                https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
                but we add it here in case we don't want to rely on that method.

    Returns:
        `dict`: The state dict in Kohya format, including the generated `alpha` entries.

    Raises:
        `ImportError`: If torch is not installed.
        `ValueError`: If the state dict is not recognised as PEFT, or `original_type` is not supported.
    """
    try:
        import torch
    except ImportError:
        logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
        raise

    peft_adapter_name = kwargs.pop("adapter_name", None)
    if peft_adapter_name is not None:
        peft_adapter_name = "." + peft_adapter_name
    else:
        peft_adapter_name = ""

    if original_type is None:
        if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
            original_type = StateDictType.PEFT

    if original_type not in KOHYA_STATE_DICT_MAPPINGS.keys():
        raise ValueError(f"Original type {original_type} is not supported")

    # Use the convert_state_dict function with the appropriate mapping
    kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[StateDictType.PEFT])
    kohya_ss_state_dict = {}

    # Additional logic for replacing header, alpha parameters `.` with `_` in all keys
    for kohya_key, weight in kohya_ss_partial_state_dict.items():
        # NOTE(review): this is an if/elif chain, so `lora_magnitude_vector` is only renamed to `dora_scale` for
        # keys that match none of the prefixes above it - confirm this is intended.
        if "text_encoder_2." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
        elif "text_encoder." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
        elif "unet" in kohya_key:
            kohya_key = kohya_key.replace("unet", "lora_unet")
        elif "lora_magnitude_vector" in kohya_key:
            kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")

        # Drop the adapter component BEFORE flattening the dots. In a key such as
        # `<module>.lora_down.<adapter>.weight` the adapter sits right before the last component; if it were still
        # present, the flattening below (which keeps the last two dots) would preserve `.<adapter>.weight` and
        # turn the dot in front of `lora_down` into `_`, yielding `<module>_lora_down.weight` and a wrong alpha key.
        if peft_adapter_name:
            head, dot, last_component = kohya_key.rpartition(".")
            if head.endswith(peft_adapter_name):
                kohya_key = head[: -len(peft_adapter_name)] + dot + last_component

        # Flatten the module path: every `.` except the last two becomes `_`.
        kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
        # Kept for keys with a different layout; a no-op once the adapter component was removed above.
        kohya_key = kohya_key.replace(peft_adapter_name, "")  # Kohya doesn't take names
        kohya_ss_state_dict[kohya_key] = weight
        if "lora_down" in kohya_key:
            # One alpha per module, named after the flattened module path.
            alpha_key = f"{kohya_key.split('.')[0]}.alpha"
            kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))

    return kohya_ss_state_dict
def state_dict_all_zero(state_dict, filter_str=None):
    r"""
    Tells whether every tensor of `state_dict` contains only zeros.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to inspect.
        filter_str (`str` or iterable of `str`, *optional*):
            When given, only the entries whose key contains at least one of these substrings are inspected. A single
            string is treated as a one-element list.

    Returns:
        `bool`: `True` if all inspected tensors are entirely zero (also when there is nothing to inspect), `False`
        otherwise.
    """
    if filter_str is None:
        candidates = state_dict.values()
    else:
        needles = [filter_str] if isinstance(filter_str, str) else filter_str
        candidates = [
            param for name, param in state_dict.items() if any(needle in name for needle in needles)
        ]

    for param in candidates:
        if not torch.all(param == 0).item():
            return False
    return True
def _load_sft_state_dict_metadata(model_file: str):
    r"""
    Reads the LoRA adapter metadata stored in the header of a safetensors file.

    Args:
        model_file (`str`):
            Path of the safetensors file, opened with `safetensors.torch.safe_open`.

    Returns:
        The JSON-decoded value stored under `LORA_ADAPTER_METADATA_KEY`, or `None` when the file has no metadata
        besides the `format` entry, or when that key is missing or empty.
    """
    import safetensors.torch

    from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY

    with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
        metadata = f.metadata() or {}

    # The `format` entry is discarded before deciding whether any metadata is left.
    metadata.pop("format", None)
    if not metadata:
        return None

    raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
    if not raw:
        return None
    return json.loads(raw)