| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
import copy
from typing import TYPE_CHECKING, Dict, List, Union

from ..utils import logging


if TYPE_CHECKING:
    from ..models import UNet2DConditionModel


logger = logging.get_logger(__name__)
| |
|
| |
|
| | def _translate_into_actual_layer_name(name): |
| | """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')""" |
| | if name == "mid": |
| | return "mid_block.attentions.0" |
| |
|
| | updown, block, attn = name.split(".") |
| |
|
| | updown = updown.replace("down", "down_blocks").replace("up", "up_blocks") |
| | block = block.replace("block_", "") |
| | attn = "attentions." + attn |
| |
|
| | return ".".join((updown, block, attn)) |
| |
|
| |
|
def _maybe_expand_lora_scales(
    unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
    """
    Expand each adapter's scale spec in `weight_scales` into a per-transformer-layer dict.

    Parameters:
        unet (`UNet2DConditionModel`):
            The UNet whose block layout and state dict determine the valid layer names.
        weight_scales (`List[Union[float, Dict]]`):
            One scale spec per adapter; scalars are passed through, dicts are expanded.
        default_scale (`float`, defaults to 1.0):
            Scale used for any block the user did not specify.

    Returns the list of expanded specs, one entry per adapter, in the same order.
    """
    # Only blocks that carry transformer layers (i.e. have `attentions`) can be scaled.
    blocks_with_transformer = {
        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
    }
    # Up blocks carry one extra transformer layer compared to down blocks.
    transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}

    # Hoisted out of the loop: `state_dict()` builds a fresh dict each call, and it is
    # identical for every adapter.
    unet_state_dict = unet.state_dict()

    expanded_weight_scales = [
        _maybe_expand_lora_scales_for_one_adapter(
            weight_for_adapter,
            blocks_with_transformer,
            transformer_per_block,
            unet_state_dict,
            default_scale=default_scale,
        )
        for weight_for_adapter in weight_scales
    ]

    return expanded_weight_scales
| |
|
| |
|
| | def _maybe_expand_lora_scales_for_one_adapter( |
| | scales: Union[float, Dict], |
| | blocks_with_transformer: Dict[str, int], |
| | transformer_per_block: Dict[str, int], |
| | state_dict: None, |
| | default_scale: float = 1.0, |
| | ): |
| | """ |
| | Expands the inputs into a more granular dictionary. See the example below for more details. |
| | |
| | Parameters: |
| | scales (`Union[float, Dict]`): |
| | Scales dict to expand. |
| | blocks_with_transformer (`Dict[str, int]`): |
| | Dict with keys 'up' and 'down', showing which blocks have transformer layers |
| | transformer_per_block (`Dict[str, int]`): |
| | Dict with keys 'up' and 'down', showing how many transformer layers each block has |
| | |
| | E.g. turns |
| | ```python |
| | scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}} |
| | blocks_with_transformer = {"down": [1, 2], "up": [0, 1]} |
| | transformer_per_block = {"down": 2, "up": 3} |
| | ``` |
| | into |
| | ```python |
| | { |
| | "down.block_1.0": 2, |
| | "down.block_1.1": 2, |
| | "down.block_2.0": 2, |
| | "down.block_2.1": 2, |
| | "mid": 3, |
| | "up.block_0.0": 4, |
| | "up.block_0.1": 4, |
| | "up.block_0.2": 4, |
| | "up.block_1.0": 5, |
| | "up.block_1.1": 6, |
| | "up.block_1.2": 7, |
| | } |
| | ``` |
| | """ |
| | if sorted(blocks_with_transformer.keys()) != ["down", "up"]: |
| | raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`") |
| |
|
| | if sorted(transformer_per_block.keys()) != ["down", "up"]: |
| | raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`") |
| |
|
| | if not isinstance(scales, dict): |
| | |
| | return scales |
| |
|
| | scales = copy.deepcopy(scales) |
| |
|
| | if "mid" not in scales: |
| | scales["mid"] = default_scale |
| | elif isinstance(scales["mid"], list): |
| | if len(scales["mid"]) == 1: |
| | scales["mid"] = scales["mid"][0] |
| | else: |
| | raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.") |
| |
|
| | for updown in ["up", "down"]: |
| | if updown not in scales: |
| | scales[updown] = default_scale |
| |
|
| | |
| | if not isinstance(scales[updown], dict): |
| | scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]} |
| |
|
| | |
| | for i in blocks_with_transformer[updown]: |
| | block = f"block_{i}" |
| | |
| | if block not in scales[updown]: |
| | scales[updown][block] = default_scale |
| | if not isinstance(scales[updown][block], list): |
| | scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])] |
| | elif len(scales[updown][block]) == 1: |
| | |
| | scales[updown][block] = scales[updown][block] * transformer_per_block[updown] |
| | elif len(scales[updown][block]) != transformer_per_block[updown]: |
| | raise ValueError( |
| | f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}." |
| | ) |
| |
|
| | |
| | for i in blocks_with_transformer[updown]: |
| | block = f"block_{i}" |
| | for tf_idx, value in enumerate(scales[updown][block]): |
| | scales[f"{updown}.{block}.{tf_idx}"] = value |
| |
|
| | del scales[updown] |
| |
|
| | for layer in scales.keys(): |
| | if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()): |
| | raise ValueError( |
| | f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions." |
| | ) |
| |
|
| | return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()} |
| |
|