build-tools / diffusers /utils /state_dict_utils.py
salmankhanpm's picture
Add files using upload-large-folder tool
69e1a8d verified
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
State dict utilities: helpers for converting LoRA state dicts between the diffusers, PEFT and Kohya key formats.
"""
import enum
import json
from .import_utils import is_torch_available
from .logging import get_logger
if is_torch_available():
import torch
# Module-level logger for this file.
logger = get_logger(__name__)
class StateDictType(enum.Enum):
    """
    Key-naming convention of a state dict; selects which conversion mapping is applied.

    Members:
        DIFFUSERS_OLD: legacy diffusers LoRA keys (e.g. `to_q_lora.up`).
        KOHYA_SS: Kohya-ss keys (e.g. `lora_down` / `lora_up`).
        PEFT: PEFT keys (e.g. `lora_A` / `lora_B`).
        DIFFUSERS: current diffusers LoRA keys (e.g. `lora_linear_layer.up`).
    """

    # Member order is kept as-is: it defines iteration order of the enum.
    DIFFUSERS_OLD = "diffusers_old"
    KOHYA_SS = "kohya_ss"
    PEFT = "peft"
    DIFFUSERS = "diffusers"
# We need to define a proper mapping for Unet since it uses different output keys than text encoder
# e.g. to_q_lora -> q_proj / to_q
# Substring pattern -> replacement. `convert_state_dict` applies only the first pattern (in insertion order)
# found in a key, so entry order matters wherever patterns could overlap. The outputs use PEFT-style
# `lora_A` (down) / `lora_B` (up) names, and `to_out` is addressed as `to_out.0`.
UNET_TO_DIFFUSERS = {
    ".to_out_lora.up": ".to_out.0.lora_B",
    ".to_out_lora.down": ".to_out.0.lora_A",
    ".to_q_lora.down": ".to_q.lora_A",
    ".to_q_lora.up": ".to_q.lora_B",
    ".to_k_lora.down": ".to_k.lora_A",
    ".to_k_lora.up": ".to_k.lora_B",
    ".to_v_lora.down": ".to_v.lora_A",
    ".to_v_lora.up": ".to_v.lora_B",
    # Generic fallbacks for other layers.
    ".lora.up": ".lora_B",
    ".lora.down": ".lora_A",
    # DoRA magnitude vector of the output projection.
    ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
}
# Maps the `.down` / `.up` suffixes of control-LoRA keys (after they have been renamed to diffusers ControlNet
# module names) to PEFT-style `lora_A.weight` / `lora_B.weight`. The comprehensions expand to `conv1`/`conv2`
# and `linear_1`/`linear_2` (range(1, 3)). Note that the `conv_in` and `time_emb_proj` patterns have no
# leading dot, so they match as plain substrings.
CONTROL_LORA_TO_DIFFUSERS = {
    ".to_q.down": ".to_q.lora_A.weight",
    ".to_q.up": ".to_q.lora_B.weight",
    ".to_k.down": ".to_k.lora_A.weight",
    ".to_k.up": ".to_k.lora_B.weight",
    ".to_v.down": ".to_v.lora_A.weight",
    ".to_v.up": ".to_v.lora_B.weight",
    ".to_out.0.down": ".to_out.0.lora_A.weight",
    ".to_out.0.up": ".to_out.0.lora_B.weight",
    ".ff.net.0.proj.down": ".ff.net.0.proj.lora_A.weight",
    ".ff.net.0.proj.up": ".ff.net.0.proj.lora_B.weight",
    ".ff.net.2.down": ".ff.net.2.lora_A.weight",
    ".ff.net.2.up": ".ff.net.2.lora_B.weight",
    ".proj_in.down": ".proj_in.lora_A.weight",
    ".proj_in.up": ".proj_in.lora_B.weight",
    ".proj_out.down": ".proj_out.lora_A.weight",
    ".proj_out.up": ".proj_out.lora_B.weight",
    ".conv.down": ".conv.lora_A.weight",
    ".conv.up": ".conv.lora_B.weight",
    **{f".conv{i}.down": f".conv{i}.lora_A.weight" for i in range(1, 3)},
    **{f".conv{i}.up": f".conv{i}.lora_B.weight" for i in range(1, 3)},
    "conv_in.down": "conv_in.lora_A.weight",
    "conv_in.up": "conv_in.lora_B.weight",
    ".conv_shortcut.down": ".conv_shortcut.lora_A.weight",
    ".conv_shortcut.up": ".conv_shortcut.lora_B.weight",
    **{f".linear_{i}.down": f".linear_{i}.lora_A.weight" for i in range(1, 3)},
    **{f".linear_{i}.up": f".linear_{i}.lora_B.weight" for i in range(1, 3)},
    "time_emb_proj.down": "time_emb_proj.lora_A.weight",
    "time_emb_proj.up": "time_emb_proj.lora_B.weight",
}
# New diffusers format (`lora_linear_layer.up/down`) -> PEFT (`lora_B` / `lora_A`).
# Projection-specific entries come first, followed by generic `.lora_linear_layer.*` fallbacks and the
# `text_projection` keys, which use a `.lora.down/up.weight` layout.
DIFFUSERS_TO_PEFT = {
    ".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
    ".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
    ".k_proj.lora_linear_layer.up": ".k_proj.lora_B",
    ".k_proj.lora_linear_layer.down": ".k_proj.lora_A",
    ".v_proj.lora_linear_layer.up": ".v_proj.lora_B",
    ".v_proj.lora_linear_layer.down": ".v_proj.lora_A",
    ".out_proj.lora_linear_layer.up": ".out_proj.lora_B",
    ".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
    ".lora_linear_layer.up": ".lora_B",
    ".lora_linear_layer.down": ".lora_A",
    "text_projection.lora.down.weight": "text_projection.lora_A.weight",
    "text_projection.lora.up.weight": "text_projection.lora_B.weight",
}
# Legacy diffusers format (`to_q_lora.up`, ...) -> PEFT keys. Attention projections are renamed to the
# `q_proj` / `k_proj` / `v_proj` / `out_proj` names (presumably the text-encoder naming; see the note above
# `UNET_TO_DIFFUSERS`). Generic `.lora_linear_layer.*` keys are also handled.
DIFFUSERS_OLD_TO_PEFT = {
    ".to_q_lora.up": ".q_proj.lora_B",
    ".to_q_lora.down": ".q_proj.lora_A",
    ".to_k_lora.up": ".k_proj.lora_B",
    ".to_k_lora.down": ".k_proj.lora_A",
    ".to_v_lora.up": ".v_proj.lora_B",
    ".to_v_lora.down": ".v_proj.lora_A",
    ".to_out_lora.up": ".out_proj.lora_B",
    ".to_out_lora.down": ".out_proj.lora_A",
    ".lora_linear_layer.up": ".lora_B",
    ".lora_linear_layer.down": ".lora_A",
}
# PEFT keys -> new diffusers keys. `*_proj` projections map to `lora_linear_layer.up/down`, while the
# attention projections `to_q` / `to_k` / `to_v` / `to_out.0` map to `lora.up/down`. The `to_*` patterns
# have no leading dot. Keys matching no pattern are left unchanged by `convert_state_dict`.
PEFT_TO_DIFFUSERS = {
    ".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
    ".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
    ".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
    ".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
    ".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
    ".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
    ".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
    ".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
    "to_k.lora_A": "to_k.lora.down",
    "to_k.lora_B": "to_k.lora.up",
    "to_q.lora_A": "to_q.lora.down",
    "to_q.lora_B": "to_q.lora.up",
    "to_v.lora_A": "to_v.lora.down",
    "to_v.lora_B": "to_v.lora.up",
    "to_out.0.lora_A": "to_out.0.lora.down",
    "to_out.0.lora_B": "to_out.0.lora.up",
}
# Legacy diffusers format -> new diffusers format. Attention LoRA layers are renamed to
# `*_proj.lora_linear_layer.up/down`, and DoRA magnitude vectors follow the same `*_proj` renaming.
DIFFUSERS_OLD_TO_DIFFUSERS = {
    ".to_q_lora.up": ".q_proj.lora_linear_layer.up",
    ".to_q_lora.down": ".q_proj.lora_linear_layer.down",
    ".to_k_lora.up": ".k_proj.lora_linear_layer.up",
    ".to_k_lora.down": ".k_proj.lora_linear_layer.down",
    ".to_v_lora.up": ".v_proj.lora_linear_layer.up",
    ".to_v_lora.down": ".v_proj.lora_linear_layer.down",
    ".to_out_lora.up": ".out_proj.lora_linear_layer.up",
    ".to_out_lora.down": ".out_proj.lora_linear_layer.down",
    ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
    ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
    ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
    ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
}
# PEFT -> Kohya LoRA component names (`lora_A` is the down projection, `lora_B` the up projection).
PEFT_TO_KOHYA_SS = {
    "lora_A": "lora_down",
    "lora_B": "lora_up",
    # This is not a comprehensive dict as kohya format requires replacing `.` with `_` in keys,
    # adding prefixes and adding alpha values
    # Check `convert_state_dict_to_kohya` for more
}
# Registries used by the `convert_state_dict_to_*` functions: source `StateDictType` -> mapping to apply.
# A source type missing from a registry is reported as "not supported" by the corresponding converter.

# Target: PEFT.
PEFT_STATE_DICT_MAPPINGS = {
    StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT,
    StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT,
}
# Target: new diffusers format.
DIFFUSERS_STATE_DICT_MAPPINGS = {
    StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
    StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}
# Target: Kohya (only PEFT input is supported).
KOHYA_STATE_DICT_MAPPINGS = {StateDictType.PEFT: PEFT_TO_KOHYA_SS}
# Replacements applied by `convert_state_dict` to every key before the format-specific mapping. Unlike the
# mapping, every matching pattern here is applied (not just the first). Collapses `.processor.` to `.`.
KEYS_TO_ALWAYS_REPLACE = {
    ".processor.": ".",
}
def convert_state_dict(state_dict, mapping):
    r"""
    Rename the keys of a state dict by substring substitution.

    Each key first goes through the module-level `KEYS_TO_ALWAYS_REPLACE` table, where every matching pattern is
    applied. Then the patterns of `mapping` are tried in insertion order, and only the first one contained in the
    key is substituted. Values are carried over untouched.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        mapping (`dict[str, str]`):
            Pattern table: each key is a substring to look for, its value the text that replaces it.

    Returns:
        converted_state_dict (`dict`)
            A new dict with renamed keys. If two keys end up with the same name, the later entry wins.
    """

    def _rename(name):
        # Global clean-ups: apply all of them.
        for old, new in KEYS_TO_ALWAYS_REPLACE.items():
            if old in name:
                name = name.replace(old, new)
        # Format-specific mapping: stop at the first hit.
        for old, new in mapping.items():
            if old in name:
                return name.replace(old, new)
        return name

    return {_rename(name): tensor for name, tensor in state_dict.items()}
def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs):
    r"""
    Convert a state dict to the PEFT format.

    Only the legacy diffusers format (`StateDictType.DIFFUSERS_OLD`) and the new diffusers format
    (`StateDictType.DIFFUSERS`) are supported as sources for now.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The format of `state_dict`. When omitted, it is inferred from the keys. Any key containing
            `to_out_lora` means `DIFFUSERS_OLD`; otherwise, any key containing `lora_linear_layer` means `DIFFUSERS`.
        kwargs:
            Accepted for signature compatibility with the other converters; currently unused.

    Returns:
        `dict`: The converted state dict.

    Raises:
        ValueError: If the format cannot be inferred or is not supported.
    """
    if original_type is None:
        # Checked in order; the first marker found decides the format.
        markers = (
            ("to_out_lora", StateDictType.DIFFUSERS_OLD),
            ("lora_linear_layer", StateDictType.DIFFUSERS),
        )
        for marker, candidate in markers:
            if any(marker in key for key in state_dict):
                original_type = candidate
                break
        else:
            raise ValueError("Could not automatically infer state dict type")

    if original_type not in PEFT_STATE_DICT_MAPPINGS:
        raise ValueError(f"Original type {original_type} is not supported")

    return convert_state_dict(state_dict, PEFT_STATE_DICT_MAPPINGS[original_type])
def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
    r"""
    Converts a state dict to new diffusers format.

    Supported sources are the previous diffusers format (`StateDictType.DIFFUSERS_OLD`) and the PEFT format
    (`StateDictType.PEFT`). A state dict that is already in the new diffusers format (`StateDictType.DIFFUSERS`) is
    returned as is, whether that format was inferred or passed explicitly.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict. If not provided, it is inferred from the keys, checked in order:
            a key containing `to_out_lora` means `DIFFUSERS_OLD`; a key containing
            `.lora_A[.<adapter_name>].weight` means `PEFT`; a key containing `lora_linear_layer` means `DIFFUSERS`.
        kwargs (`dict`, *optional*):
            Additional arguments to pass to the method.
                - **adapter_name**: For example, in case of PEFT, some keys may contain the adapter name
                  (`lora_A.<adapter_name>.weight`), therefore needs a special handling. Here it is only used to
                  infer the PEFT format. By default PEFT also takes care of that in `get_peft_model_state_dict`
                  method:
                  https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
                  but we add it here in case we don't want to rely on that method.

    Returns:
        `dict`: The converted state dict, or `state_dict` itself when it is already in the new diffusers format.

    Raises:
        ValueError: If the format cannot be inferred or is not supported.
    """
    peft_adapter_name = kwargs.pop("adapter_name", None)
    peft_adapter_name = "" if peft_adapter_name is None else "." + peft_adapter_name

    if original_type is None:
        if any("to_out_lora" in k for k in state_dict.keys()):
            original_type = StateDictType.DIFFUSERS_OLD
        elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
            original_type = StateDictType.PEFT
        elif any("lora_linear_layer" in k for k in state_dict.keys()):
            original_type = StateDictType.DIFFUSERS
        else:
            raise ValueError("Could not automatically infer state dict type")

    # Already in the target format, whether inferred or passed explicitly: nothing to convert (as documented).
    if original_type == StateDictType.DIFFUSERS:
        return state_dict

    if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS:
        raise ValueError(f"Original type {original_type} is not supported")

    mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
    return convert_state_dict(state_dict, mapping)
def convert_unet_state_dict_to_peft(state_dict):
    r"""
    Rename UNet LoRA keys (legacy `to_*_lora.up/down` and generic `.lora.up/down`) to PEFT-style `lora_A` / `lora_B`
    keys using the `UNET_TO_DIFFUSERS` table. Keys are renamed, not removed.
    """
    return convert_state_dict(state_dict, UNET_TO_DIFFUSERS)
def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
    r"""
    Convert a SAI control-LoRA state dict to diffusers ControlNet module names with PEFT-style LoRA keys.

    Works in two passes:
        1. `_convert_controlnet_to_diffusers` renames the original-layout keys (`input_blocks`, `middle_block`,
           `zero_convs`, `input_hint_block`, `time_embed`, `label_emb`, ...) to diffusers ControlNet names.
        2. `CONTROL_LORA_TO_DIFFUSERS` rewrites the `.down` / `.up` suffixes to `lora_A.weight` / `lora_B.weight`.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The control-LoRA state dict.

    Returns:
        `dict`: The converted state dict.
    """

    def _convert_controlnet_to_diffusers(state_dict):
        # Renames original-layout ControlNet keys to diffusers ControlNet names. Lookups use `state_dict.get`, so
        # an expected key that is absent (e.g. `zero_convs.{i}.0.weight`) is stored with a `None` value.
        # `is_sdxl` is only used for the log message below.
        is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict
        logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})")
        # Retrieves the keys for the input blocks only
        # NOTE: the per-block lists use substring matching, so e.g. `input_blocks.1` also collects
        # `input_blocks.10.*` / `input_blocks.11.*` keys; the `input_blocks.{i}.0` / `input_blocks.{i}.1`
        # filters further down appear to be what keeps such blocks apart.
        num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer})
        input_blocks = {
            layer_id: [key for key in state_dict if f"input_blocks.{layer_id}" in key]
            for layer_id in range(num_input_blocks)
        }
        # Resnets per down block. Input blocks 1.. are grouped in runs of `layers_per_block + 1`; the extra slot
        # presumably holds the downsampler (`input_blocks.{i}.0.op`), handled separately below.
        layers_per_block = 2
        # op blocks
        op_blocks = [key for key in state_dict if "0.op" in key]
        converted_state_dict = {}
        # Conv in layers
        for key in input_blocks[0]:
            diffusers_key = key.replace("input_blocks.0.0", "conv_in")
            converted_state_dict[diffusers_key] = state_dict.get(key)
        # controlnet time embedding blocks
        time_embedding_blocks = [key for key in state_dict if "time_embed" in key]
        for key in time_embedding_blocks:
            diffusers_key = key.replace("time_embed.0", "time_embedding.linear_1").replace(
                "time_embed.2", "time_embedding.linear_2"
            )
            converted_state_dict[diffusers_key] = state_dict.get(key)
        # controlnet label embedding blocks
        label_embedding_blocks = [key for key in state_dict if "label_emb" in key]
        for key in label_embedding_blocks:
            diffusers_key = key.replace("label_emb.0.0", "add_embedding.linear_1").replace(
                "label_emb.0.2", "add_embedding.linear_2"
            )
            converted_state_dict[diffusers_key] = state_dict.get(key)
        # Down blocks
        for i in range(1, num_input_blocks):
            # Position of input block `i` in the diffusers layout: which down block, and which slot inside it.
            block_id = (i - 1) // (layers_per_block + 1)
            layer_in_block_id = (i - 1) % (layers_per_block + 1)
            # Sub-module 0 of the input block, excluding the downsampler conv (`.0.op`).
            resnets = [
                key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
            ]
            for key in resnets:
                diffusers_key = (
                    key.replace("in_layers.0", "norm1")
                    .replace("in_layers.2", "conv1")
                    .replace("out_layers.0", "norm2")
                    .replace("out_layers.3", "conv2")
                    .replace("emb_layers.1", "time_emb_proj")
                    .replace("skip_connection", "conv_shortcut")
                )
                diffusers_key = diffusers_key.replace(
                    f"input_blocks.{i}.0", f"down_blocks.{block_id}.resnets.{layer_in_block_id}"
                )
                converted_state_dict[diffusers_key] = state_dict.get(key)
            # The downsampler is only mapped when its bias key is present.
            if f"input_blocks.{i}.0.op.bias" in state_dict:
                for key in [key for key in op_blocks if f"input_blocks.{i}.0.op" in key]:
                    diffusers_key = key.replace(
                        f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv"
                    )
                    converted_state_dict[diffusers_key] = state_dict.get(key)
            # Sub-module 1 of the input block maps to an attention module.
            attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
            if attentions:
                for key in attentions:
                    diffusers_key = key.replace(
                        f"input_blocks.{i}.1", f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
                    )
                    converted_state_dict[diffusers_key] = state_dict.get(key)
        # controlnet down blocks
        for i in range(num_input_blocks):
            converted_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.get(f"zero_convs.{i}.0.weight")
            converted_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.get(f"zero_convs.{i}.0.bias")
        # Retrieves the keys for the middle blocks only
        # NOTE: `middle_block_out.*` keys also contain "middle_block", so they add one more prefix to this count;
        # the corresponding extra id presumably matches no `middle_block.{id}` keys.
        num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer})
        middle_blocks = {
            layer_id: [key for key in state_dict if f"middle_block.{layer_id}" in key]
            for layer_id in range(num_middle_blocks)
        }
        # Mid blocks
        # Even ids are resnets, odd ids are attentions; ids 0, 1, 2 map to resnets.0, attentions.0, resnets.1.
        for key in middle_blocks.keys():
            diffusers_key = max(key - 1, 0)
            if key % 2 == 0:
                for k in middle_blocks[key]:
                    diffusers_key_hf = (
                        k.replace("in_layers.0", "norm1")
                        .replace("in_layers.2", "conv1")
                        .replace("out_layers.0", "norm2")
                        .replace("out_layers.3", "conv2")
                        .replace("emb_layers.1", "time_emb_proj")
                        .replace("skip_connection", "conv_shortcut")
                    )
                    diffusers_key_hf = diffusers_key_hf.replace(
                        f"middle_block.{key}", f"mid_block.resnets.{diffusers_key}"
                    )
                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)
            else:
                for k in middle_blocks[key]:
                    diffusers_key_hf = k.replace(f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}")
                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)
        # mid block
        converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight")
        converted_state_dict["controlnet_mid_block.bias"] = state_dict.get("middle_block_out.0.bias")
        # controlnet cond embedding blocks
        # Intermediate hint-block layers, excluding the first (`.0`, conv_in) and last (`.14`, conv_out).
        cond_embedding_blocks = {
            ".".join(layer.split(".")[:2])
            for layer in state_dict
            if "input_hint_block" in layer
            and ("input_hint_block.0" not in layer)
            and ("input_hint_block.14" not in layer)
        }
        num_cond_embedding_blocks = len(cond_embedding_blocks)
        # Assumes the intermediate convs sit at even indices 2, 4, ... (TODO confirm); only their exact
        # `.weight` / `.bias` keys are copied.
        for idx in range(1, num_cond_embedding_blocks + 1):
            diffusers_idx = idx - 1
            cond_block_id = 2 * idx
            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = state_dict.get(
                f"input_hint_block.{cond_block_id}.weight"
            )
            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = state_dict.get(
                f"input_hint_block.{cond_block_id}.bias"
            )
        for key in [key for key in state_dict if "input_hint_block.0" in key]:
            diffusers_key = key.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in")
            converted_state_dict[diffusers_key] = state_dict.get(key)
        for key in [key for key in state_dict if "input_hint_block.14" in key]:
            diffusers_key = key.replace("input_hint_block.14", "controlnet_cond_embedding.conv_out")
            converted_state_dict[diffusers_key] = state_dict.get(key)
        return converted_state_dict

    state_dict = _convert_controlnet_to_diffusers(state_dict)
    # Second pass: `.down` / `.up` LoRA suffixes -> `lora_A.weight` / `lora_B.weight`.
    mapping = CONTROL_LORA_TO_DIFFUSERS
    return convert_state_dict(state_dict, mapping)
def convert_all_state_dict_to_peft(state_dict):
    r"""
    Convert a LoRA state dict to PEFT format, falling back to the UNet-only mapping.

    `convert_state_dict_to_peft` is tried first. If it cannot infer the format (no `to_out_lora` /
    `lora_linear_layer` keys), `convert_unet_state_dict_to_peft` is used instead. Any other error is re-raised.

    Raises:
        ValueError: If the resulting keys contain neither `lora_A` nor `lora_B`, i.e. nothing was converted.
    """
    try:
        peft_dict = convert_state_dict_to_peft(state_dict)
    except Exception as err:
        # Only the "cannot infer" failure triggers the UNet fallback.
        if str(err) != "Could not automatically infer state dict type":
            raise
        peft_dict = convert_unet_state_dict_to_peft(state_dict)

    has_peft_keys = any("lora_A" in key or "lora_B" in key for key in peft_dict.keys())
    if not has_peft_keys:
        raise ValueError("Your LoRA was not converted to PEFT")
    return peft_dict
def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
    r"""
    Converts a `PEFT` state dict to `Kohya` format that can be used in AUTOMATIC1111, ComfyUI, SD.Next, InvokeAI, etc.
    The method only supports the conversion from PEFT to Kohya for now.

    Each key is transformed as follows:
        1. `lora_A` / `lora_B` -> `lora_down` / `lora_up` (via `convert_state_dict`).
        2. Model prefix: `text_encoder_2.` -> `lora_te2.`, `text_encoder.` -> `lora_te1.`, `unet` -> `lora_unet`.
        3. DoRA magnitude vectors: `lora_magnitude_vector` -> `dora_scale`, regardless of the prefix.
        4. The adapter name, if given, is removed (Kohya keys carry no adapter name).
        5. Dots in the module path become underscores. Only the dots of the trailing suffix are kept: two for
           `.lora_down.weight` / `.lora_up.weight`-style keys, one for keys ending in `.dora_scale`.
    For every `lora_down` entry an extra `<module>.alpha` entry is added, equal to `len(weight)` (the size of the
    weight's first dimension).

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The state dict to convert.
        original_type (`StateDictType`, *optional*):
            The original type of the state dict, if not provided, the method will try to infer it automatically.
        kwargs (`dict`, *optional*):
            Additional arguments to pass to the method.
                - **adapter_name**: For example, in case of PEFT, some keys will contain the adapter name
                  (`lora_A.<adapter_name>.weight`), therefore needs a special handling. By default PEFT also takes
                  care of that in `get_peft_model_state_dict` method:
                  https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
                  but we add it here in case we don't want to rely on that method.

    Returns:
        `dict`: The Kohya-format state dict.

    Raises:
        ImportError: If torch is not installed.
        ValueError: If the format cannot be inferred or is not supported.
    """
    try:
        import torch
    except ImportError:
        logger.error("Converting PEFT state dicts to Kohya requires torch to be installed.")
        raise

    peft_adapter_name = kwargs.pop("adapter_name", None)
    peft_adapter_name = "" if peft_adapter_name is None else "." + peft_adapter_name

    if original_type is None:
        if any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
            original_type = StateDictType.PEFT
        else:
            # Same message as the sibling converters use when inference fails.
            raise ValueError("Could not automatically infer state dict type")

    if original_type not in KOHYA_STATE_DICT_MAPPINGS:
        raise ValueError(f"Original type {original_type} is not supported")

    # Use the convert_state_dict function with the appropriate mapping
    kohya_ss_partial_state_dict = convert_state_dict(state_dict, KOHYA_STATE_DICT_MAPPINGS[original_type])

    kohya_ss_state_dict = {}
    for kohya_key, weight in kohya_ss_partial_state_dict.items():
        # Model prefix (at most one applies).
        if "text_encoder_2." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder_2.", "lora_te2.")
        elif "text_encoder." in kohya_key:
            kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
        elif "unet" in kohya_key:
            kohya_key = kohya_key.replace("unet", "lora_unet")

        # Checked independently of the prefix branches so prefixed DoRA keys are renamed as well.
        if "lora_magnitude_vector" in kohya_key:
            kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")

        # Kohya doesn't take names. The adapter name must be removed *before* flattening dots; otherwise its dot
        # is counted as part of the suffix and `x.lora_down.<name>.weight` would end up as `x_lora_down.weight`.
        if peft_adapter_name:
            kohya_key = kohya_key.replace(peft_adapter_name, "")

        # Flatten the module path. `str.replace` treats a negative count as "replace all", so clamp at 0.
        num_suffix_dots = 1 if kohya_key.endswith(".dora_scale") else 2
        kohya_key = kohya_key.replace(".", "_", max(kohya_key.count(".") - num_suffix_dots, 0))

        kohya_ss_state_dict[kohya_key] = weight
        if "lora_down" in kohya_key:
            alpha_key = f"{kohya_key.split('.')[0]}.alpha"
            kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))

    return kohya_ss_state_dict
def state_dict_all_zero(state_dict, filter_str=None):
    """
    Check whether every tensor in `state_dict` is entirely zero.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            The tensors to inspect.
        filter_str (`str` or list of `str`, *optional*):
            If given, only entries whose key contains at least one of these substrings are inspected.

    Returns:
        `bool`: `True` when all inspected tensors are zero, including when there is nothing to inspect.
    """
    patterns = None
    if filter_str is not None:
        patterns = [filter_str] if isinstance(filter_str, str) else filter_str

    for name, tensor in state_dict.items():
        if patterns is not None and not any(p in name for p in patterns):
            continue
        if not torch.all(tensor == 0).item():
            return False
    return True
def _load_sft_state_dict_metadata(model_file: str):
    """
    Read the LoRA adapter metadata stored in a safetensors file header.

    Args:
        model_file (`str`): Path of the safetensors file to open (on CPU).

    Returns:
        The JSON-decoded value stored under `LORA_ADAPTER_METADATA_KEY`. Returns `None` when the header holds no
        metadata besides `format`, or when that key is missing or empty.
    """
    import safetensors.torch

    from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY

    with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
        header_metadata = f.metadata() or {}

    # The generic `format` entry is not adapter metadata.
    header_metadata.pop("format", None)
    if not header_metadata:
        return None

    raw_metadata = header_metadata.get(LORA_ADAPTER_METADATA_KEY)
    if not raw_metadata:
        return None
    return json.loads(raw_metadata)