Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files- sd-scripts/config_util.py +721 -0
- sd-scripts/model_util.py +1355 -0
- sd-scripts/train_util.py +0 -0
- sd-scripts/utils.py +287 -0
sd-scripts/config_util.py
ADDED
|
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from dataclasses import (
|
| 3 |
+
asdict,
|
| 4 |
+
dataclass,
|
| 5 |
+
)
|
| 6 |
+
import functools
|
| 7 |
+
import random
|
| 8 |
+
from textwrap import dedent, indent
|
| 9 |
+
import json
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# from toolz import curry
|
| 13 |
+
from typing import (
|
| 14 |
+
List,
|
| 15 |
+
Optional,
|
| 16 |
+
Sequence,
|
| 17 |
+
Tuple,
|
| 18 |
+
Union,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
import toml
|
| 22 |
+
import voluptuous
|
| 23 |
+
from voluptuous import (
|
| 24 |
+
Any,
|
| 25 |
+
ExactSequence,
|
| 26 |
+
MultipleInvalid,
|
| 27 |
+
Object,
|
| 28 |
+
Required,
|
| 29 |
+
Schema,
|
| 30 |
+
)
|
| 31 |
+
from transformers import CLIPTokenizer
|
| 32 |
+
|
| 33 |
+
from . import train_util
|
| 34 |
+
from .train_util import (
|
| 35 |
+
DreamBoothSubset,
|
| 36 |
+
FineTuningSubset,
|
| 37 |
+
ControlNetSubset,
|
| 38 |
+
DreamBoothDataset,
|
| 39 |
+
FineTuningDataset,
|
| 40 |
+
ControlNetDataset,
|
| 41 |
+
DatasetGroup,
|
| 42 |
+
)
|
| 43 |
+
from .utils import setup_logging
|
| 44 |
+
|
| 45 |
+
setup_logging()
|
| 46 |
+
import logging
|
| 47 |
+
|
| 48 |
+
logger = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def add_config_arguments(parser: argparse.ArgumentParser):
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# TODO: inherit Params class in Subset, Dataset
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
|
| 61 |
+
class BaseSubsetParams:
|
| 62 |
+
image_dir: Optional[str] = None
|
| 63 |
+
num_repeats: int = 1
|
| 64 |
+
shuffle_caption: bool = False
|
| 65 |
+
caption_separator: str = (",",)
|
| 66 |
+
keep_tokens: int = 0
|
| 67 |
+
keep_tokens_separator: str = (None,)
|
| 68 |
+
secondary_separator: Optional[str] = None
|
| 69 |
+
enable_wildcard: bool = False
|
| 70 |
+
color_aug: bool = False
|
| 71 |
+
flip_aug: bool = False
|
| 72 |
+
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
| 73 |
+
random_crop: bool = False
|
| 74 |
+
caption_prefix: Optional[str] = None
|
| 75 |
+
caption_suffix: Optional[str] = None
|
| 76 |
+
caption_dropout_rate: float = 0.0
|
| 77 |
+
caption_dropout_every_n_epochs: int = 0
|
| 78 |
+
caption_tag_dropout_rate: float = 0.0
|
| 79 |
+
token_warmup_min: int = 1
|
| 80 |
+
token_warmup_step: float = 0
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@dataclass
|
| 84 |
+
class DreamBoothSubsetParams(BaseSubsetParams):
|
| 85 |
+
is_reg: bool = False
|
| 86 |
+
class_tokens: Optional[str] = None
|
| 87 |
+
caption_extension: str = ".caption"
|
| 88 |
+
cache_info: bool = False
|
| 89 |
+
alpha_mask: bool = False
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclass
|
| 93 |
+
class FineTuningSubsetParams(BaseSubsetParams):
|
| 94 |
+
metadata_file: Optional[str] = None
|
| 95 |
+
alpha_mask: bool = False
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dataclass
|
| 99 |
+
class ControlNetSubsetParams(BaseSubsetParams):
|
| 100 |
+
conditioning_data_dir: str = None
|
| 101 |
+
caption_extension: str = ".caption"
|
| 102 |
+
cache_info: bool = False
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@dataclass
|
| 106 |
+
class BaseDatasetParams:
|
| 107 |
+
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
|
| 108 |
+
max_token_length: int = None
|
| 109 |
+
resolution: Optional[Tuple[int, int]] = None
|
| 110 |
+
network_multiplier: float = 1.0
|
| 111 |
+
debug_dataset: bool = False
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@dataclass
|
| 115 |
+
class DreamBoothDatasetParams(BaseDatasetParams):
|
| 116 |
+
batch_size: int = 1
|
| 117 |
+
enable_bucket: bool = False
|
| 118 |
+
min_bucket_reso: int = 256
|
| 119 |
+
max_bucket_reso: int = 1024
|
| 120 |
+
bucket_reso_steps: int = 64
|
| 121 |
+
bucket_no_upscale: bool = False
|
| 122 |
+
prior_loss_weight: float = 1.0
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@dataclass
|
| 126 |
+
class FineTuningDatasetParams(BaseDatasetParams):
|
| 127 |
+
batch_size: int = 1
|
| 128 |
+
enable_bucket: bool = False
|
| 129 |
+
min_bucket_reso: int = 256
|
| 130 |
+
max_bucket_reso: int = 1024
|
| 131 |
+
bucket_reso_steps: int = 64
|
| 132 |
+
bucket_no_upscale: bool = False
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@dataclass
|
| 136 |
+
class ControlNetDatasetParams(BaseDatasetParams):
|
| 137 |
+
batch_size: int = 1
|
| 138 |
+
enable_bucket: bool = False
|
| 139 |
+
min_bucket_reso: int = 256
|
| 140 |
+
max_bucket_reso: int = 1024
|
| 141 |
+
bucket_reso_steps: int = 64
|
| 142 |
+
bucket_no_upscale: bool = False
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@dataclass
|
| 146 |
+
class SubsetBlueprint:
|
| 147 |
+
params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@dataclass
|
| 151 |
+
class DatasetBlueprint:
|
| 152 |
+
is_dreambooth: bool
|
| 153 |
+
is_controlnet: bool
|
| 154 |
+
params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
|
| 155 |
+
subsets: Sequence[SubsetBlueprint]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@dataclass
|
| 159 |
+
class DatasetGroupBlueprint:
|
| 160 |
+
datasets: Sequence[DatasetBlueprint]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@dataclass
|
| 164 |
+
class Blueprint:
|
| 165 |
+
dataset_group: DatasetGroupBlueprint
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ConfigSanitizer:
|
| 169 |
+
# @curry
|
| 170 |
+
@staticmethod
|
| 171 |
+
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
|
| 172 |
+
Schema(ExactSequence([klass, klass]))(value)
|
| 173 |
+
return tuple(value)
|
| 174 |
+
|
| 175 |
+
# @curry
|
| 176 |
+
@staticmethod
|
| 177 |
+
def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
|
| 178 |
+
Schema(Any(klass, ExactSequence([klass, klass])))(value)
|
| 179 |
+
try:
|
| 180 |
+
Schema(klass)(value)
|
| 181 |
+
return (value, value)
|
| 182 |
+
except:
|
| 183 |
+
return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
|
| 184 |
+
|
| 185 |
+
# subset schema
|
| 186 |
+
SUBSET_ASCENDABLE_SCHEMA = {
|
| 187 |
+
"color_aug": bool,
|
| 188 |
+
"face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
|
| 189 |
+
"flip_aug": bool,
|
| 190 |
+
"num_repeats": int,
|
| 191 |
+
"random_crop": bool,
|
| 192 |
+
"shuffle_caption": bool,
|
| 193 |
+
"keep_tokens": int,
|
| 194 |
+
"keep_tokens_separator": str,
|
| 195 |
+
"secondary_separator": str,
|
| 196 |
+
"caption_separator": str,
|
| 197 |
+
"enable_wildcard": bool,
|
| 198 |
+
"token_warmup_min": int,
|
| 199 |
+
"token_warmup_step": Any(float, int),
|
| 200 |
+
"caption_prefix": str,
|
| 201 |
+
"caption_suffix": str,
|
| 202 |
+
}
|
| 203 |
+
# DO means DropOut
|
| 204 |
+
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
| 205 |
+
"caption_dropout_every_n_epochs": int,
|
| 206 |
+
"caption_dropout_rate": Any(float, int),
|
| 207 |
+
"caption_tag_dropout_rate": Any(float, int),
|
| 208 |
+
}
|
| 209 |
+
# DB means DreamBooth
|
| 210 |
+
DB_SUBSET_ASCENDABLE_SCHEMA = {
|
| 211 |
+
"caption_extension": str,
|
| 212 |
+
"class_tokens": str,
|
| 213 |
+
"cache_info": bool,
|
| 214 |
+
}
|
| 215 |
+
DB_SUBSET_DISTINCT_SCHEMA = {
|
| 216 |
+
Required("image_dir"): str,
|
| 217 |
+
"is_reg": bool,
|
| 218 |
+
"alpha_mask": bool,
|
| 219 |
+
}
|
| 220 |
+
# FT means FineTuning
|
| 221 |
+
FT_SUBSET_DISTINCT_SCHEMA = {
|
| 222 |
+
Required("metadata_file"): str,
|
| 223 |
+
"image_dir": str,
|
| 224 |
+
"alpha_mask": bool,
|
| 225 |
+
}
|
| 226 |
+
CN_SUBSET_ASCENDABLE_SCHEMA = {
|
| 227 |
+
"caption_extension": str,
|
| 228 |
+
"cache_info": bool,
|
| 229 |
+
}
|
| 230 |
+
CN_SUBSET_DISTINCT_SCHEMA = {
|
| 231 |
+
Required("image_dir"): str,
|
| 232 |
+
Required("conditioning_data_dir"): str,
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
# datasets schema
|
| 236 |
+
DATASET_ASCENDABLE_SCHEMA = {
|
| 237 |
+
"batch_size": int,
|
| 238 |
+
"bucket_no_upscale": bool,
|
| 239 |
+
"bucket_reso_steps": int,
|
| 240 |
+
"enable_bucket": bool,
|
| 241 |
+
"max_bucket_reso": int,
|
| 242 |
+
"min_bucket_reso": int,
|
| 243 |
+
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
| 244 |
+
"network_multiplier": float,
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
# options handled by argparse but not handled by user config
|
| 248 |
+
ARGPARSE_SPECIFIC_SCHEMA = {
|
| 249 |
+
"debug_dataset": bool,
|
| 250 |
+
"max_token_length": Any(None, int),
|
| 251 |
+
"prior_loss_weight": Any(float, int),
|
| 252 |
+
}
|
| 253 |
+
# for handling default None value of argparse
|
| 254 |
+
ARGPARSE_NULLABLE_OPTNAMES = [
|
| 255 |
+
"face_crop_aug_range",
|
| 256 |
+
"resolution",
|
| 257 |
+
]
|
| 258 |
+
# prepare map because option name may differ among argparse and user config
|
| 259 |
+
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
|
| 260 |
+
"train_batch_size": "batch_size",
|
| 261 |
+
"dataset_repeats": "num_repeats",
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
|
| 265 |
+
assert support_dreambooth or support_finetuning or support_controlnet, (
|
| 266 |
+
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
|
| 267 |
+
+ " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
self.db_subset_schema = self.__merge_dict(
|
| 271 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 272 |
+
self.DB_SUBSET_DISTINCT_SCHEMA,
|
| 273 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
| 274 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
self.ft_subset_schema = self.__merge_dict(
|
| 278 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 279 |
+
self.FT_SUBSET_DISTINCT_SCHEMA,
|
| 280 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
self.cn_subset_schema = self.__merge_dict(
|
| 284 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 285 |
+
self.CN_SUBSET_DISTINCT_SCHEMA,
|
| 286 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
| 287 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
self.db_dataset_schema = self.__merge_dict(
|
| 291 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
| 292 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 293 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
| 294 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 295 |
+
{"subsets": [self.db_subset_schema]},
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
self.ft_dataset_schema = self.__merge_dict(
|
| 299 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
| 300 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 301 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 302 |
+
{"subsets": [self.ft_subset_schema]},
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
self.cn_dataset_schema = self.__merge_dict(
|
| 306 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
| 307 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 308 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
| 309 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 310 |
+
{"subsets": [self.cn_subset_schema]},
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
if support_dreambooth and support_finetuning:
|
| 314 |
+
|
| 315 |
+
def validate_flex_dataset(dataset_config: dict):
|
| 316 |
+
subsets_config = dataset_config.get("subsets", [])
|
| 317 |
+
|
| 318 |
+
if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
|
| 319 |
+
return Schema(self.cn_dataset_schema)(dataset_config)
|
| 320 |
+
# check dataset meets FT style
|
| 321 |
+
# NOTE: all FT subsets should have "metadata_file"
|
| 322 |
+
elif all(["metadata_file" in subset for subset in subsets_config]):
|
| 323 |
+
return Schema(self.ft_dataset_schema)(dataset_config)
|
| 324 |
+
# check dataset meets DB style
|
| 325 |
+
# NOTE: all DB subsets should have no "metadata_file"
|
| 326 |
+
elif all(["metadata_file" not in subset for subset in subsets_config]):
|
| 327 |
+
return Schema(self.db_dataset_schema)(dataset_config)
|
| 328 |
+
else:
|
| 329 |
+
raise voluptuous.Invalid(
|
| 330 |
+
"DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
self.dataset_schema = validate_flex_dataset
|
| 334 |
+
elif support_dreambooth:
|
| 335 |
+
if support_controlnet:
|
| 336 |
+
self.dataset_schema = self.cn_dataset_schema
|
| 337 |
+
else:
|
| 338 |
+
self.dataset_schema = self.db_dataset_schema
|
| 339 |
+
elif support_finetuning:
|
| 340 |
+
self.dataset_schema = self.ft_dataset_schema
|
| 341 |
+
elif support_controlnet:
|
| 342 |
+
self.dataset_schema = self.cn_dataset_schema
|
| 343 |
+
|
| 344 |
+
self.general_schema = self.__merge_dict(
|
| 345 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
| 346 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 347 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
|
| 348 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
|
| 349 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
self.user_config_validator = Schema(
|
| 353 |
+
{
|
| 354 |
+
"general": self.general_schema,
|
| 355 |
+
"datasets": [self.dataset_schema],
|
| 356 |
+
}
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
self.argparse_schema = self.__merge_dict(
|
| 360 |
+
self.general_schema,
|
| 361 |
+
self.ARGPARSE_SPECIFIC_SCHEMA,
|
| 362 |
+
{optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
|
| 363 |
+
{a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
|
| 367 |
+
|
| 368 |
+
def sanitize_user_config(self, user_config: dict) -> dict:
|
| 369 |
+
try:
|
| 370 |
+
return self.user_config_validator(user_config)
|
| 371 |
+
except MultipleInvalid:
|
| 372 |
+
# TODO: エラー発生時のメッセージをわかりやすくする
|
| 373 |
+
logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
|
| 374 |
+
raise
|
| 375 |
+
|
| 376 |
+
# NOTE: In nature, argument parser result is not needed to be sanitize
|
| 377 |
+
# However this will help us to detect program bug
|
| 378 |
+
def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
|
| 379 |
+
try:
|
| 380 |
+
return self.argparse_config_validator(argparse_namespace)
|
| 381 |
+
except MultipleInvalid:
|
| 382 |
+
# XXX: this should be a bug
|
| 383 |
+
logger.error(
|
| 384 |
+
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
|
| 385 |
+
)
|
| 386 |
+
raise
|
| 387 |
+
|
| 388 |
+
# NOTE: value would be overwritten by latter dict if there is already the same key
|
| 389 |
+
@staticmethod
|
| 390 |
+
def __merge_dict(*dict_list: dict) -> dict:
|
| 391 |
+
merged = {}
|
| 392 |
+
for schema in dict_list:
|
| 393 |
+
# merged |= schema
|
| 394 |
+
for k, v in schema.items():
|
| 395 |
+
merged[k] = v
|
| 396 |
+
return merged
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class BlueprintGenerator:
|
| 400 |
+
BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
|
| 401 |
+
|
| 402 |
+
def __init__(self, sanitizer: ConfigSanitizer):
|
| 403 |
+
self.sanitizer = sanitizer
|
| 404 |
+
|
| 405 |
+
# runtime_params is for parameters which is only configurable on runtime, such as tokenizer
|
| 406 |
+
def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
|
| 407 |
+
sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
|
| 408 |
+
sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
|
| 409 |
+
|
| 410 |
+
# convert argparse namespace to dict like config
|
| 411 |
+
# NOTE: it is ok to have extra entries in dict
|
| 412 |
+
optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
|
| 413 |
+
argparse_config = {
|
| 414 |
+
optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
general_config = sanitized_user_config.get("general", {})
|
| 418 |
+
|
| 419 |
+
dataset_blueprints = []
|
| 420 |
+
for dataset_config in sanitized_user_config.get("datasets", []):
|
| 421 |
+
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
|
| 422 |
+
subsets = dataset_config.get("subsets", [])
|
| 423 |
+
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
|
| 424 |
+
is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
|
| 425 |
+
if is_controlnet:
|
| 426 |
+
subset_params_klass = ControlNetSubsetParams
|
| 427 |
+
dataset_params_klass = ControlNetDatasetParams
|
| 428 |
+
elif is_dreambooth:
|
| 429 |
+
subset_params_klass = DreamBoothSubsetParams
|
| 430 |
+
dataset_params_klass = DreamBoothDatasetParams
|
| 431 |
+
else:
|
| 432 |
+
subset_params_klass = FineTuningSubsetParams
|
| 433 |
+
dataset_params_klass = FineTuningDatasetParams
|
| 434 |
+
|
| 435 |
+
subset_blueprints = []
|
| 436 |
+
for subset_config in subsets:
|
| 437 |
+
params = self.generate_params_by_fallbacks(
|
| 438 |
+
subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
|
| 439 |
+
)
|
| 440 |
+
subset_blueprints.append(SubsetBlueprint(params))
|
| 441 |
+
|
| 442 |
+
params = self.generate_params_by_fallbacks(
|
| 443 |
+
dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
|
| 444 |
+
)
|
| 445 |
+
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
|
| 446 |
+
|
| 447 |
+
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
|
| 448 |
+
|
| 449 |
+
return Blueprint(dataset_group_blueprint)
|
| 450 |
+
|
| 451 |
+
@staticmethod
|
| 452 |
+
def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
|
| 453 |
+
name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
|
| 454 |
+
search_value = BlueprintGenerator.search_value
|
| 455 |
+
default_params = asdict(param_klass())
|
| 456 |
+
param_names = default_params.keys()
|
| 457 |
+
|
| 458 |
+
params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
|
| 459 |
+
|
| 460 |
+
return param_klass(**params)
|
| 461 |
+
|
| 462 |
+
@staticmethod
|
| 463 |
+
def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
|
| 464 |
+
for cand in fallbacks:
|
| 465 |
+
value = cand.get(key)
|
| 466 |
+
if value is not None:
|
| 467 |
+
return value
|
| 468 |
+
|
| 469 |
+
return default_value
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
|
| 473 |
+
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
| 474 |
+
|
| 475 |
+
for dataset_blueprint in dataset_group_blueprint.datasets:
|
| 476 |
+
if dataset_blueprint.is_controlnet:
|
| 477 |
+
subset_klass = ControlNetSubset
|
| 478 |
+
dataset_klass = ControlNetDataset
|
| 479 |
+
elif dataset_blueprint.is_dreambooth:
|
| 480 |
+
subset_klass = DreamBoothSubset
|
| 481 |
+
dataset_klass = DreamBoothDataset
|
| 482 |
+
else:
|
| 483 |
+
subset_klass = FineTuningSubset
|
| 484 |
+
dataset_klass = FineTuningDataset
|
| 485 |
+
|
| 486 |
+
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
| 487 |
+
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
| 488 |
+
datasets.append(dataset)
|
| 489 |
+
|
| 490 |
+
# print info
|
| 491 |
+
info = ""
|
| 492 |
+
for i, dataset in enumerate(datasets):
|
| 493 |
+
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
| 494 |
+
is_controlnet = isinstance(dataset, ControlNetDataset)
|
| 495 |
+
info += dedent(
|
| 496 |
+
f"""\
|
| 497 |
+
[Dataset {i}]
|
| 498 |
+
batch_size: {dataset.batch_size}
|
| 499 |
+
resolution: {(dataset.width, dataset.height)}
|
| 500 |
+
enable_bucket: {dataset.enable_bucket}
|
| 501 |
+
network_multiplier: {dataset.network_multiplier}
|
| 502 |
+
"""
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
if dataset.enable_bucket:
|
| 506 |
+
info += indent(
|
| 507 |
+
dedent(
|
| 508 |
+
f"""\
|
| 509 |
+
min_bucket_reso: {dataset.min_bucket_reso}
|
| 510 |
+
max_bucket_reso: {dataset.max_bucket_reso}
|
| 511 |
+
bucket_reso_steps: {dataset.bucket_reso_steps}
|
| 512 |
+
bucket_no_upscale: {dataset.bucket_no_upscale}
|
| 513 |
+
\n"""
|
| 514 |
+
),
|
| 515 |
+
" ",
|
| 516 |
+
)
|
| 517 |
+
else:
|
| 518 |
+
info += "\n"
|
| 519 |
+
|
| 520 |
+
for j, subset in enumerate(dataset.subsets):
|
| 521 |
+
info += indent(
|
| 522 |
+
dedent(
|
| 523 |
+
f"""\
|
| 524 |
+
[Subset {j} of Dataset {i}]
|
| 525 |
+
image_dir: "{subset.image_dir}"
|
| 526 |
+
image_count: {subset.img_count}
|
| 527 |
+
num_repeats: {subset.num_repeats}
|
| 528 |
+
shuffle_caption: {subset.shuffle_caption}
|
| 529 |
+
keep_tokens: {subset.keep_tokens}
|
| 530 |
+
keep_tokens_separator: {subset.keep_tokens_separator}
|
| 531 |
+
caption_separator: {subset.caption_separator}
|
| 532 |
+
secondary_separator: {subset.secondary_separator}
|
| 533 |
+
enable_wildcard: {subset.enable_wildcard}
|
| 534 |
+
caption_dropout_rate: {subset.caption_dropout_rate}
|
| 535 |
+
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
| 536 |
+
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
| 537 |
+
caption_prefix: {subset.caption_prefix}
|
| 538 |
+
caption_suffix: {subset.caption_suffix}
|
| 539 |
+
color_aug: {subset.color_aug}
|
| 540 |
+
flip_aug: {subset.flip_aug}
|
| 541 |
+
face_crop_aug_range: {subset.face_crop_aug_range}
|
| 542 |
+
random_crop: {subset.random_crop}
|
| 543 |
+
token_warmup_min: {subset.token_warmup_min},
|
| 544 |
+
token_warmup_step: {subset.token_warmup_step},
|
| 545 |
+
alpha_mask: {subset.alpha_mask},
|
| 546 |
+
"""
|
| 547 |
+
),
|
| 548 |
+
" ",
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
if is_dreambooth:
|
| 552 |
+
info += indent(
|
| 553 |
+
dedent(
|
| 554 |
+
f"""\
|
| 555 |
+
is_reg: {subset.is_reg}
|
| 556 |
+
class_tokens: {subset.class_tokens}
|
| 557 |
+
caption_extension: {subset.caption_extension}
|
| 558 |
+
\n"""
|
| 559 |
+
),
|
| 560 |
+
" ",
|
| 561 |
+
)
|
| 562 |
+
elif not is_controlnet:
|
| 563 |
+
info += indent(
|
| 564 |
+
dedent(
|
| 565 |
+
f"""\
|
| 566 |
+
metadata_file: {subset.metadata_file}
|
| 567 |
+
\n"""
|
| 568 |
+
),
|
| 569 |
+
" ",
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
logger.info(f"{info}")
|
| 573 |
+
|
| 574 |
+
# make buckets first because it determines the length of dataset
|
| 575 |
+
# and set the same seed for all datasets
|
| 576 |
+
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
| 577 |
+
for i, dataset in enumerate(datasets):
|
| 578 |
+
logger.info(f"[Dataset {i}]")
|
| 579 |
+
dataset.make_buckets()
|
| 580 |
+
dataset.set_seed(seed)
|
| 581 |
+
|
| 582 |
+
return DatasetGroup(datasets)
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
|
| 586 |
+
def extract_dreambooth_params(name: str) -> Tuple[int, str]:
|
| 587 |
+
tokens = name.split("_")
|
| 588 |
+
try:
|
| 589 |
+
n_repeats = int(tokens[0])
|
| 590 |
+
except ValueError as e:
|
| 591 |
+
logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
|
| 592 |
+
return 0, ""
|
| 593 |
+
caption_by_folder = "_".join(tokens[1:])
|
| 594 |
+
return n_repeats, caption_by_folder
|
| 595 |
+
|
| 596 |
+
def generate(base_dir: Optional[str], is_reg: bool):
|
| 597 |
+
if base_dir is None:
|
| 598 |
+
return []
|
| 599 |
+
|
| 600 |
+
base_dir: Path = Path(base_dir)
|
| 601 |
+
if not base_dir.is_dir():
|
| 602 |
+
return []
|
| 603 |
+
|
| 604 |
+
subsets_config = []
|
| 605 |
+
for subdir in base_dir.iterdir():
|
| 606 |
+
if not subdir.is_dir():
|
| 607 |
+
continue
|
| 608 |
+
|
| 609 |
+
num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
|
| 610 |
+
if num_repeats < 1:
|
| 611 |
+
continue
|
| 612 |
+
|
| 613 |
+
subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
|
| 614 |
+
subsets_config.append(subset_config)
|
| 615 |
+
|
| 616 |
+
return subsets_config
|
| 617 |
+
|
| 618 |
+
subsets_config = []
|
| 619 |
+
subsets_config += generate(train_data_dir, False)
|
| 620 |
+
subsets_config += generate(reg_data_dir, True)
|
| 621 |
+
|
| 622 |
+
return subsets_config
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def generate_controlnet_subsets_config_by_subdirs(
|
| 626 |
+
train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
|
| 627 |
+
):
|
| 628 |
+
def generate(base_dir: Optional[str]):
|
| 629 |
+
if base_dir is None:
|
| 630 |
+
return []
|
| 631 |
+
|
| 632 |
+
base_dir: Path = Path(base_dir)
|
| 633 |
+
if not base_dir.is_dir():
|
| 634 |
+
return []
|
| 635 |
+
|
| 636 |
+
subsets_config = []
|
| 637 |
+
subset_config = {
|
| 638 |
+
"image_dir": train_data_dir,
|
| 639 |
+
"conditioning_data_dir": conditioning_data_dir,
|
| 640 |
+
"caption_extension": caption_extension,
|
| 641 |
+
"num_repeats": 1,
|
| 642 |
+
}
|
| 643 |
+
subsets_config.append(subset_config)
|
| 644 |
+
|
| 645 |
+
return subsets_config
|
| 646 |
+
|
| 647 |
+
subsets_config = []
|
| 648 |
+
subsets_config += generate(train_data_dir)
|
| 649 |
+
|
| 650 |
+
return subsets_config
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
def load_user_config(file: str) -> dict:
|
| 654 |
+
file: Path = Path(file)
|
| 655 |
+
if not file.is_file():
|
| 656 |
+
raise ValueError(f"file not found / ファイルが見つかりません: {file}")
|
| 657 |
+
|
| 658 |
+
if file.name.lower().endswith(".json"):
|
| 659 |
+
try:
|
| 660 |
+
with open(file, "r") as f:
|
| 661 |
+
config = json.load(f)
|
| 662 |
+
except Exception:
|
| 663 |
+
logger.error(
|
| 664 |
+
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
| 665 |
+
)
|
| 666 |
+
raise
|
| 667 |
+
elif file.name.lower().endswith(".toml"):
|
| 668 |
+
try:
|
| 669 |
+
config = toml.load(file)
|
| 670 |
+
except Exception:
|
| 671 |
+
logger.error(
|
| 672 |
+
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
| 673 |
+
)
|
| 674 |
+
raise
|
| 675 |
+
else:
|
| 676 |
+
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
| 677 |
+
|
| 678 |
+
return config
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
# for config test
|
| 682 |
+
if __name__ == "__main__":
|
| 683 |
+
parser = argparse.ArgumentParser()
|
| 684 |
+
parser.add_argument("--support_dreambooth", action="store_true")
|
| 685 |
+
parser.add_argument("--support_finetuning", action="store_true")
|
| 686 |
+
parser.add_argument("--support_controlnet", action="store_true")
|
| 687 |
+
parser.add_argument("--support_dropout", action="store_true")
|
| 688 |
+
parser.add_argument("dataset_config")
|
| 689 |
+
config_args, remain = parser.parse_known_args()
|
| 690 |
+
|
| 691 |
+
parser = argparse.ArgumentParser()
|
| 692 |
+
train_util.add_dataset_arguments(
|
| 693 |
+
parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
|
| 694 |
+
)
|
| 695 |
+
train_util.add_training_arguments(parser, config_args.support_dreambooth)
|
| 696 |
+
argparse_namespace = parser.parse_args(remain)
|
| 697 |
+
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
| 698 |
+
|
| 699 |
+
logger.info("[argparse_namespace]")
|
| 700 |
+
logger.info(f"{vars(argparse_namespace)}")
|
| 701 |
+
|
| 702 |
+
user_config = load_user_config(config_args.dataset_config)
|
| 703 |
+
|
| 704 |
+
logger.info("")
|
| 705 |
+
logger.info("[user_config]")
|
| 706 |
+
logger.info(f"{user_config}")
|
| 707 |
+
|
| 708 |
+
sanitizer = ConfigSanitizer(
|
| 709 |
+
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
|
| 710 |
+
)
|
| 711 |
+
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
| 712 |
+
|
| 713 |
+
logger.info("")
|
| 714 |
+
logger.info("[sanitized_user_config]")
|
| 715 |
+
logger.info(f"{sanitized_user_config}")
|
| 716 |
+
|
| 717 |
+
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
| 718 |
+
|
| 719 |
+
logger.info("")
|
| 720 |
+
logger.info("[blueprint]")
|
| 721 |
+
logger.info(f"{blueprint}")
|
sd-scripts/model_util.py
ADDED
|
@@ -0,0 +1,1355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# v1: split from train_db_fixed.py.
|
| 2 |
+
# v2: support safetensors
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from library.device_utils import init_ipex
|
| 9 |
+
init_ipex()
|
| 10 |
+
|
| 11 |
+
import diffusers
|
| 12 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
| 13 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
| 14 |
+
from safetensors.torch import load_file, save_file
|
| 15 |
+
from library.original_unet import UNet2DConditionModel
|
| 16 |
+
from library.utils import setup_logging
|
| 17 |
+
setup_logging()
|
| 18 |
+
import logging
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
| 22 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
| 23 |
+
BETA_START = 0.00085
|
| 24 |
+
BETA_END = 0.0120
|
| 25 |
+
|
| 26 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
| 27 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
| 28 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
| 29 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
| 30 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
| 31 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
| 32 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
| 33 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
| 34 |
+
UNET_PARAMS_NUM_HEADS = 8
|
| 35 |
+
# UNET_PARAMS_USE_LINEAR_PROJECTION = False
|
| 36 |
+
|
| 37 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
| 38 |
+
VAE_PARAMS_RESOLUTION = 256
|
| 39 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
| 40 |
+
VAE_PARAMS_OUT_CH = 3
|
| 41 |
+
VAE_PARAMS_CH = 128
|
| 42 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
| 43 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
| 44 |
+
|
| 45 |
+
# V2
|
| 46 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
| 47 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
| 48 |
+
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
|
| 49 |
+
|
| 50 |
+
# Diffusersの設定を読み込むための参照モデル
|
| 51 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
| 52 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# region StableDiffusion->Diffusersの変換コード
|
| 56 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
| 60 |
+
"""
|
| 61 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
| 62 |
+
"""
|
| 63 |
+
if n_shave_prefix_segments >= 0:
|
| 64 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
| 65 |
+
else:
|
| 66 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 70 |
+
"""
|
| 71 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 72 |
+
"""
|
| 73 |
+
mapping = []
|
| 74 |
+
for old_item in old_list:
|
| 75 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
| 76 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
| 77 |
+
|
| 78 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
| 79 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
| 80 |
+
|
| 81 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
| 82 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
| 83 |
+
|
| 84 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 85 |
+
|
| 86 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 87 |
+
|
| 88 |
+
return mapping
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 92 |
+
"""
|
| 93 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 94 |
+
"""
|
| 95 |
+
mapping = []
|
| 96 |
+
for old_item in old_list:
|
| 97 |
+
new_item = old_item
|
| 98 |
+
|
| 99 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
| 100 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 101 |
+
|
| 102 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 103 |
+
|
| 104 |
+
return mapping
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
| 108 |
+
"""
|
| 109 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 110 |
+
"""
|
| 111 |
+
mapping = []
|
| 112 |
+
for old_item in old_list:
|
| 113 |
+
new_item = old_item
|
| 114 |
+
|
| 115 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
| 116 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
| 117 |
+
|
| 118 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
| 119 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
| 120 |
+
|
| 121 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 122 |
+
|
| 123 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 124 |
+
|
| 125 |
+
return mapping
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
| 129 |
+
"""
|
| 130 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 131 |
+
"""
|
| 132 |
+
mapping = []
|
| 133 |
+
for old_item in old_list:
|
| 134 |
+
new_item = old_item
|
| 135 |
+
|
| 136 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
| 137 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
| 138 |
+
|
| 139 |
+
if diffusers.__version__ < "0.17.0":
|
| 140 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
| 141 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
| 142 |
+
|
| 143 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
| 144 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
| 145 |
+
|
| 146 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
| 147 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
| 148 |
+
|
| 149 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
| 150 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
| 151 |
+
else:
|
| 152 |
+
new_item = new_item.replace("q.weight", "to_q.weight")
|
| 153 |
+
new_item = new_item.replace("q.bias", "to_q.bias")
|
| 154 |
+
|
| 155 |
+
new_item = new_item.replace("k.weight", "to_k.weight")
|
| 156 |
+
new_item = new_item.replace("k.bias", "to_k.bias")
|
| 157 |
+
|
| 158 |
+
new_item = new_item.replace("v.weight", "to_v.weight")
|
| 159 |
+
new_item = new_item.replace("v.bias", "to_v.bias")
|
| 160 |
+
|
| 161 |
+
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
| 162 |
+
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
| 163 |
+
|
| 164 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 165 |
+
|
| 166 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 167 |
+
|
| 168 |
+
return mapping
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def assign_to_checkpoint(
|
| 172 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
| 173 |
+
):
|
| 174 |
+
"""
|
| 175 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
| 176 |
+
to them. It splits attention layers, and takes into account additional replacements
|
| 177 |
+
that may arise.
|
| 178 |
+
|
| 179 |
+
Assigns the weights to the new checkpoint.
|
| 180 |
+
"""
|
| 181 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
| 182 |
+
|
| 183 |
+
# Splits the attention layers into three variables.
|
| 184 |
+
if attention_paths_to_split is not None:
|
| 185 |
+
for path, path_map in attention_paths_to_split.items():
|
| 186 |
+
old_tensor = old_checkpoint[path]
|
| 187 |
+
channels = old_tensor.shape[0] // 3
|
| 188 |
+
|
| 189 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
| 190 |
+
|
| 191 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
| 192 |
+
|
| 193 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
| 194 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
| 195 |
+
|
| 196 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
| 197 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
| 198 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
| 199 |
+
|
| 200 |
+
for path in paths:
|
| 201 |
+
new_path = path["new"]
|
| 202 |
+
|
| 203 |
+
# These have already been assigned
|
| 204 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
| 205 |
+
continue
|
| 206 |
+
|
| 207 |
+
# Global renaming happens here
|
| 208 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
| 209 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
| 210 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
| 211 |
+
|
| 212 |
+
if additional_replacements is not None:
|
| 213 |
+
for replacement in additional_replacements:
|
| 214 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
| 215 |
+
|
| 216 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
| 217 |
+
reshaping = False
|
| 218 |
+
if diffusers.__version__ < "0.17.0":
|
| 219 |
+
if "proj_attn.weight" in new_path:
|
| 220 |
+
reshaping = True
|
| 221 |
+
else:
|
| 222 |
+
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
|
| 223 |
+
reshaping = True
|
| 224 |
+
|
| 225 |
+
if reshaping:
|
| 226 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
| 227 |
+
else:
|
| 228 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def conv_attn_to_linear(checkpoint):
|
| 232 |
+
keys = list(checkpoint.keys())
|
| 233 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
| 234 |
+
for key in keys:
|
| 235 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
| 236 |
+
if checkpoint[key].ndim > 2:
|
| 237 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
| 238 |
+
elif "proj_attn.weight" in key:
|
| 239 |
+
if checkpoint[key].ndim > 2:
|
| 240 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def linear_transformer_to_conv(checkpoint):
|
| 244 |
+
keys = list(checkpoint.keys())
|
| 245 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
| 246 |
+
for key in keys:
|
| 247 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
| 248 |
+
if checkpoint[key].ndim == 2:
|
| 249 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
| 253 |
+
"""
|
| 254 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
# extract state_dict for UNet
|
| 258 |
+
unet_state_dict = {}
|
| 259 |
+
unet_key = "model.diffusion_model."
|
| 260 |
+
keys = list(checkpoint.keys())
|
| 261 |
+
for key in keys:
|
| 262 |
+
if key.startswith(unet_key):
|
| 263 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
| 264 |
+
|
| 265 |
+
new_checkpoint = {}
|
| 266 |
+
|
| 267 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
| 268 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
| 269 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
| 270 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
| 271 |
+
|
| 272 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
| 273 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
| 274 |
+
|
| 275 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
| 276 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
| 277 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
| 278 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
| 279 |
+
|
| 280 |
+
# Retrieves the keys for the input blocks only
|
| 281 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
| 282 |
+
input_blocks = {
|
| 283 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
# Retrieves the keys for the middle blocks only
|
| 287 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
| 288 |
+
middle_blocks = {
|
| 289 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
# Retrieves the keys for the output blocks only
|
| 293 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
| 294 |
+
output_blocks = {
|
| 295 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
for i in range(1, num_input_blocks):
|
| 299 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
| 300 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
| 301 |
+
|
| 302 |
+
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
|
| 303 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
| 304 |
+
|
| 305 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
| 306 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
| 307 |
+
f"input_blocks.{i}.0.op.weight"
|
| 308 |
+
)
|
| 309 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
|
| 310 |
+
|
| 311 |
+
paths = renew_resnet_paths(resnets)
|
| 312 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 313 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
| 314 |
+
|
| 315 |
+
if len(attentions):
|
| 316 |
+
paths = renew_attention_paths(attentions)
|
| 317 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
| 318 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
| 319 |
+
|
| 320 |
+
resnet_0 = middle_blocks[0]
|
| 321 |
+
attentions = middle_blocks[1]
|
| 322 |
+
resnet_1 = middle_blocks[2]
|
| 323 |
+
|
| 324 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
| 325 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
| 326 |
+
|
| 327 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
| 328 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
| 329 |
+
|
| 330 |
+
attentions_paths = renew_attention_paths(attentions)
|
| 331 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
| 332 |
+
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
| 333 |
+
|
| 334 |
+
for i in range(num_output_blocks):
|
| 335 |
+
block_id = i // (config["layers_per_block"] + 1)
|
| 336 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
| 337 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
| 338 |
+
output_block_list = {}
|
| 339 |
+
|
| 340 |
+
for layer in output_block_layers:
|
| 341 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
| 342 |
+
if layer_id in output_block_list:
|
| 343 |
+
output_block_list[layer_id].append(layer_name)
|
| 344 |
+
else:
|
| 345 |
+
output_block_list[layer_id] = [layer_name]
|
| 346 |
+
|
| 347 |
+
if len(output_block_list) > 1:
|
| 348 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
| 349 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
| 350 |
+
|
| 351 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
| 352 |
+
paths = renew_resnet_paths(resnets)
|
| 353 |
+
|
| 354 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 355 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
| 356 |
+
|
| 357 |
+
# オリジナル:
|
| 358 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
| 359 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
| 360 |
+
|
| 361 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
| 362 |
+
for l in output_block_list.values():
|
| 363 |
+
l.sort()
|
| 364 |
+
|
| 365 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
| 366 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
| 367 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
| 368 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
| 369 |
+
]
|
| 370 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
| 371 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
| 372 |
+
]
|
| 373 |
+
|
| 374 |
+
# Clear attentions as they have been attributed above.
|
| 375 |
+
if len(attentions) == 2:
|
| 376 |
+
attentions = []
|
| 377 |
+
|
| 378 |
+
if len(attentions):
|
| 379 |
+
paths = renew_attention_paths(attentions)
|
| 380 |
+
meta_path = {
|
| 381 |
+
"old": f"output_blocks.{i}.1",
|
| 382 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
| 383 |
+
}
|
| 384 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
| 385 |
+
else:
|
| 386 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
| 387 |
+
for path in resnet_0_paths:
|
| 388 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
| 389 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
| 390 |
+
|
| 391 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
| 392 |
+
|
| 393 |
+
# SDのv2では1*1のconv2dがlinearに変わっている
|
| 394 |
+
# 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
|
| 395 |
+
if v2 and not config.get("use_linear_projection", False):
|
| 396 |
+
linear_transformer_to_conv(new_checkpoint)
|
| 397 |
+
|
| 398 |
+
return new_checkpoint
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
| 402 |
+
# extract state dict for VAE
|
| 403 |
+
vae_state_dict = {}
|
| 404 |
+
vae_key = "first_stage_model."
|
| 405 |
+
keys = list(checkpoint.keys())
|
| 406 |
+
for key in keys:
|
| 407 |
+
if key.startswith(vae_key):
|
| 408 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
| 409 |
+
# if len(vae_state_dict) == 0:
|
| 410 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
| 411 |
+
# vae_state_dict = checkpoint
|
| 412 |
+
|
| 413 |
+
new_checkpoint = {}
|
| 414 |
+
|
| 415 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
| 416 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
| 417 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
| 418 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
| 419 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
| 420 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
| 421 |
+
|
| 422 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
| 423 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
| 424 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
| 425 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
| 426 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
| 427 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
| 428 |
+
|
| 429 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
| 430 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
| 431 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
| 432 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
| 433 |
+
|
| 434 |
+
# Retrieves the keys for the encoder down blocks only
|
| 435 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
| 436 |
+
down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
|
| 437 |
+
|
| 438 |
+
# Retrieves the keys for the decoder up blocks only
|
| 439 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
| 440 |
+
up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
| 441 |
+
|
| 442 |
+
for i in range(num_down_blocks):
|
| 443 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
| 444 |
+
|
| 445 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
| 446 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
| 447 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
| 448 |
+
)
|
| 449 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
| 450 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 454 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
| 455 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 456 |
+
|
| 457 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
| 458 |
+
num_mid_res_blocks = 2
|
| 459 |
+
for i in range(1, num_mid_res_blocks + 1):
|
| 460 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
| 461 |
+
|
| 462 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 463 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
| 464 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 465 |
+
|
| 466 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
| 467 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
| 468 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
| 469 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 470 |
+
conv_attn_to_linear(new_checkpoint)
|
| 471 |
+
|
| 472 |
+
for i in range(num_up_blocks):
|
| 473 |
+
block_id = num_up_blocks - 1 - i
|
| 474 |
+
resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
|
| 475 |
+
|
| 476 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
| 477 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
| 478 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
| 479 |
+
]
|
| 480 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
| 481 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
| 482 |
+
]
|
| 483 |
+
|
| 484 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 485 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
| 486 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 487 |
+
|
| 488 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
| 489 |
+
num_mid_res_blocks = 2
|
| 490 |
+
for i in range(1, num_mid_res_blocks + 1):
|
| 491 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
| 492 |
+
|
| 493 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 494 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
| 495 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 496 |
+
|
| 497 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
| 498 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
| 499 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
| 500 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 501 |
+
conv_attn_to_linear(new_checkpoint)
|
| 502 |
+
return new_checkpoint
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
|
| 506 |
+
"""
|
| 507 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
| 508 |
+
"""
|
| 509 |
+
# unet_params = original_config.model.params.unet_config.params
|
| 510 |
+
|
| 511 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
| 512 |
+
|
| 513 |
+
down_block_types = []
|
| 514 |
+
resolution = 1
|
| 515 |
+
for i in range(len(block_out_channels)):
|
| 516 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
| 517 |
+
down_block_types.append(block_type)
|
| 518 |
+
if i != len(block_out_channels) - 1:
|
| 519 |
+
resolution *= 2
|
| 520 |
+
|
| 521 |
+
up_block_types = []
|
| 522 |
+
for i in range(len(block_out_channels)):
|
| 523 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
| 524 |
+
up_block_types.append(block_type)
|
| 525 |
+
resolution //= 2
|
| 526 |
+
|
| 527 |
+
config = dict(
|
| 528 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
| 529 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
| 530 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
| 531 |
+
down_block_types=tuple(down_block_types),
|
| 532 |
+
up_block_types=tuple(up_block_types),
|
| 533 |
+
block_out_channels=tuple(block_out_channels),
|
| 534 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
| 535 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
| 536 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
| 537 |
+
# use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
|
| 538 |
+
)
|
| 539 |
+
if v2 and use_linear_projection_in_v2:
|
| 540 |
+
config["use_linear_projection"] = True
|
| 541 |
+
|
| 542 |
+
return config
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def create_vae_diffusers_config():
|
| 546 |
+
"""
|
| 547 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
| 548 |
+
"""
|
| 549 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
| 550 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
| 551 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
| 552 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
| 553 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
| 554 |
+
|
| 555 |
+
config = dict(
|
| 556 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
| 557 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
| 558 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
| 559 |
+
down_block_types=tuple(down_block_types),
|
| 560 |
+
up_block_types=tuple(up_block_types),
|
| 561 |
+
block_out_channels=tuple(block_out_channels),
|
| 562 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
| 563 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
| 564 |
+
)
|
| 565 |
+
return config
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
| 569 |
+
keys = list(checkpoint.keys())
|
| 570 |
+
text_model_dict = {}
|
| 571 |
+
for key in keys:
|
| 572 |
+
if key.startswith("cond_stage_model.transformer"):
|
| 573 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
| 574 |
+
|
| 575 |
+
# remove position_ids for newer transformer, which causes error :(
|
| 576 |
+
if "text_model.embeddings.position_ids" in text_model_dict:
|
| 577 |
+
text_model_dict.pop("text_model.embeddings.position_ids")
|
| 578 |
+
|
| 579 |
+
return text_model_dict
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
| 583 |
+
# 嫌になるくらい違うぞ!
|
| 584 |
+
def convert_key(key):
|
| 585 |
+
if not key.startswith("cond_stage_model"):
|
| 586 |
+
return None
|
| 587 |
+
|
| 588 |
+
# common conversion
|
| 589 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
| 590 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
| 591 |
+
|
| 592 |
+
if "resblocks" in key:
|
| 593 |
+
# resblocks conversion
|
| 594 |
+
key = key.replace(".resblocks.", ".layers.")
|
| 595 |
+
if ".ln_" in key:
|
| 596 |
+
key = key.replace(".ln_", ".layer_norm")
|
| 597 |
+
elif ".mlp." in key:
|
| 598 |
+
key = key.replace(".c_fc.", ".fc1.")
|
| 599 |
+
key = key.replace(".c_proj.", ".fc2.")
|
| 600 |
+
elif ".attn.out_proj" in key:
|
| 601 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
| 602 |
+
elif ".attn.in_proj" in key:
|
| 603 |
+
key = None # 特殊なので後で処理する
|
| 604 |
+
else:
|
| 605 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
| 606 |
+
elif ".positional_embedding" in key:
|
| 607 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
| 608 |
+
elif ".text_projection" in key:
|
| 609 |
+
key = None # 使われない???
|
| 610 |
+
elif ".logit_scale" in key:
|
| 611 |
+
key = None # 使われない???
|
| 612 |
+
elif ".token_embedding" in key:
|
| 613 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
| 614 |
+
elif ".ln_final" in key:
|
| 615 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
| 616 |
+
return key
|
| 617 |
+
|
| 618 |
+
keys = list(checkpoint.keys())
|
| 619 |
+
new_sd = {}
|
| 620 |
+
for key in keys:
|
| 621 |
+
# remove resblocks 23
|
| 622 |
+
if ".resblocks.23." in key:
|
| 623 |
+
continue
|
| 624 |
+
new_key = convert_key(key)
|
| 625 |
+
if new_key is None:
|
| 626 |
+
continue
|
| 627 |
+
new_sd[new_key] = checkpoint[key]
|
| 628 |
+
|
| 629 |
+
# attnの変換
|
| 630 |
+
for key in keys:
|
| 631 |
+
if ".resblocks.23." in key:
|
| 632 |
+
continue
|
| 633 |
+
if ".resblocks" in key and ".attn.in_proj_" in key:
|
| 634 |
+
# 三つに分割
|
| 635 |
+
values = torch.chunk(checkpoint[key], 3)
|
| 636 |
+
|
| 637 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
| 638 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
| 639 |
+
key_pfx = key_pfx.replace("_weight", "")
|
| 640 |
+
key_pfx = key_pfx.replace("_bias", "")
|
| 641 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
| 642 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
| 643 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
| 644 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
| 645 |
+
|
| 646 |
+
# remove position_ids for newer transformer, which causes error :(
|
| 647 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
| 648 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
| 649 |
+
# waifu diffusion v1.4
|
| 650 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
| 651 |
+
|
| 652 |
+
if "text_model.embeddings.position_ids" in new_sd:
|
| 653 |
+
del new_sd["text_model.embeddings.position_ids"]
|
| 654 |
+
|
| 655 |
+
return new_sd
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
# endregion
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
# region Diffusers->StableDiffusion の変換コード
|
| 662 |
+
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def conv_transformer_to_linear(checkpoint):
|
| 666 |
+
keys = list(checkpoint.keys())
|
| 667 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
| 668 |
+
for key in keys:
|
| 669 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
| 670 |
+
if checkpoint[key].ndim > 2:
|
| 671 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
| 675 |
+
unet_conversion_map = [
|
| 676 |
+
# (stable-diffusion, HF Diffusers)
|
| 677 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
| 678 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
| 679 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
| 680 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
| 681 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
| 682 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
| 683 |
+
("out.0.weight", "conv_norm_out.weight"),
|
| 684 |
+
("out.0.bias", "conv_norm_out.bias"),
|
| 685 |
+
("out.2.weight", "conv_out.weight"),
|
| 686 |
+
("out.2.bias", "conv_out.bias"),
|
| 687 |
+
]
|
| 688 |
+
|
| 689 |
+
unet_conversion_map_resnet = [
|
| 690 |
+
# (stable-diffusion, HF Diffusers)
|
| 691 |
+
("in_layers.0", "norm1"),
|
| 692 |
+
("in_layers.2", "conv1"),
|
| 693 |
+
("out_layers.0", "norm2"),
|
| 694 |
+
("out_layers.3", "conv2"),
|
| 695 |
+
("emb_layers.1", "time_emb_proj"),
|
| 696 |
+
("skip_connection", "conv_shortcut"),
|
| 697 |
+
]
|
| 698 |
+
|
| 699 |
+
unet_conversion_map_layer = []
|
| 700 |
+
for i in range(4):
|
| 701 |
+
# loop over downblocks/upblocks
|
| 702 |
+
|
| 703 |
+
for j in range(2):
|
| 704 |
+
# loop over resnets/attentions for downblocks
|
| 705 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
| 706 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
| 707 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
| 708 |
+
|
| 709 |
+
if i < 3:
|
| 710 |
+
# no attention layers in down_blocks.3
|
| 711 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
| 712 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
| 713 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
| 714 |
+
|
| 715 |
+
for j in range(3):
|
| 716 |
+
# loop over resnets/attentions for upblocks
|
| 717 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
| 718 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
| 719 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
| 720 |
+
|
| 721 |
+
if i > 0:
|
| 722 |
+
# no attention layers in up_blocks.0
|
| 723 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
| 724 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
| 725 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
| 726 |
+
|
| 727 |
+
if i < 3:
|
| 728 |
+
# no downsample in down_blocks.3
|
| 729 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
| 730 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
| 731 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 732 |
+
|
| 733 |
+
# no upsample in up_blocks.3
|
| 734 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
| 735 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
| 736 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
| 737 |
+
|
| 738 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
| 739 |
+
sd_mid_atn_prefix = "middle_block.1."
|
| 740 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
| 741 |
+
|
| 742 |
+
for j in range(2):
|
| 743 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
| 744 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
| 745 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 746 |
+
|
| 747 |
+
# buyer beware: this is a *brittle* function,
|
| 748 |
+
# and correct output requires that all of these pieces interact in
|
| 749 |
+
# the exact order in which I have arranged them.
|
| 750 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
| 751 |
+
for sd_name, hf_name in unet_conversion_map:
|
| 752 |
+
mapping[hf_name] = sd_name
|
| 753 |
+
for k, v in mapping.items():
|
| 754 |
+
if "resnets" in k:
|
| 755 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
| 756 |
+
v = v.replace(hf_part, sd_part)
|
| 757 |
+
mapping[k] = v
|
| 758 |
+
for k, v in mapping.items():
|
| 759 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
| 760 |
+
v = v.replace(hf_part, sd_part)
|
| 761 |
+
mapping[k] = v
|
| 762 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
| 763 |
+
|
| 764 |
+
if v2:
|
| 765 |
+
conv_transformer_to_linear(new_state_dict)
|
| 766 |
+
|
| 767 |
+
return new_state_dict
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
def controlnet_conversion_map():
|
| 771 |
+
unet_conversion_map = [
|
| 772 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
| 773 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
| 774 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
| 775 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
| 776 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
| 777 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
| 778 |
+
("middle_block_out.0.weight", "controlnet_mid_block.weight"),
|
| 779 |
+
("middle_block_out.0.bias", "controlnet_mid_block.bias"),
|
| 780 |
+
]
|
| 781 |
+
|
| 782 |
+
unet_conversion_map_resnet = [
|
| 783 |
+
("in_layers.0", "norm1"),
|
| 784 |
+
("in_layers.2", "conv1"),
|
| 785 |
+
("out_layers.0", "norm2"),
|
| 786 |
+
("out_layers.3", "conv2"),
|
| 787 |
+
("emb_layers.1", "time_emb_proj"),
|
| 788 |
+
("skip_connection", "conv_shortcut"),
|
| 789 |
+
]
|
| 790 |
+
|
| 791 |
+
unet_conversion_map_layer = []
|
| 792 |
+
for i in range(4):
|
| 793 |
+
for j in range(2):
|
| 794 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
| 795 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
| 796 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
| 797 |
+
|
| 798 |
+
if i < 3:
|
| 799 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
| 800 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
| 801 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
| 802 |
+
|
| 803 |
+
if i < 3:
|
| 804 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
| 805 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
| 806 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 807 |
+
|
| 808 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
| 809 |
+
sd_mid_atn_prefix = "middle_block.1."
|
| 810 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
| 811 |
+
|
| 812 |
+
for j in range(2):
|
| 813 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
| 814 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
| 815 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 816 |
+
|
| 817 |
+
controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
|
| 818 |
+
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
|
| 819 |
+
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
|
| 820 |
+
sd_prefix = f"input_hint_block.{i*2}."
|
| 821 |
+
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
| 822 |
+
|
| 823 |
+
for i in range(12):
|
| 824 |
+
hf_prefix = f"controlnet_down_blocks.{i}."
|
| 825 |
+
sd_prefix = f"zero_convs.{i}.0."
|
| 826 |
+
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
| 827 |
+
|
| 828 |
+
return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
| 832 |
+
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
| 833 |
+
|
| 834 |
+
mapping = {k: k for k in controlnet_state_dict.keys()}
|
| 835 |
+
for sd_name, diffusers_name in unet_conversion_map:
|
| 836 |
+
mapping[diffusers_name] = sd_name
|
| 837 |
+
for k, v in mapping.items():
|
| 838 |
+
if "resnets" in k:
|
| 839 |
+
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
| 840 |
+
v = v.replace(diffusers_part, sd_part)
|
| 841 |
+
mapping[k] = v
|
| 842 |
+
for k, v in mapping.items():
|
| 843 |
+
for sd_part, diffusers_part in unet_conversion_map_layer:
|
| 844 |
+
v = v.replace(diffusers_part, sd_part)
|
| 845 |
+
mapping[k] = v
|
| 846 |
+
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
| 847 |
+
return new_state_dict
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
|
| 851 |
+
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
| 852 |
+
|
| 853 |
+
mapping = {k: k for k in controlnet_state_dict.keys()}
|
| 854 |
+
for sd_name, diffusers_name in unet_conversion_map:
|
| 855 |
+
mapping[sd_name] = diffusers_name
|
| 856 |
+
for k, v in mapping.items():
|
| 857 |
+
for sd_part, diffusers_part in unet_conversion_map_layer:
|
| 858 |
+
v = v.replace(sd_part, diffusers_part)
|
| 859 |
+
mapping[k] = v
|
| 860 |
+
for k, v in mapping.items():
|
| 861 |
+
if "resnets" in v:
|
| 862 |
+
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
| 863 |
+
v = v.replace(sd_part, diffusers_part)
|
| 864 |
+
mapping[k] = v
|
| 865 |
+
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
| 866 |
+
return new_state_dict
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
# ================#
|
| 870 |
+
# VAE Conversion #
|
| 871 |
+
# ================#
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
def reshape_weight_for_sd(w):
|
| 875 |
+
# convert HF linear weights to SD conv2d weights
|
| 876 |
+
return w.reshape(*w.shape, 1, 1)
|
| 877 |
+
|
| 878 |
+
|
| 879 |
+
def convert_vae_state_dict(vae_state_dict):
|
| 880 |
+
vae_conversion_map = [
|
| 881 |
+
# (stable-diffusion, HF Diffusers)
|
| 882 |
+
("nin_shortcut", "conv_shortcut"),
|
| 883 |
+
("norm_out", "conv_norm_out"),
|
| 884 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
| 885 |
+
]
|
| 886 |
+
|
| 887 |
+
for i in range(4):
|
| 888 |
+
# down_blocks have two resnets
|
| 889 |
+
for j in range(2):
|
| 890 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
| 891 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
| 892 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
| 893 |
+
|
| 894 |
+
if i < 3:
|
| 895 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
| 896 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
| 897 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 898 |
+
|
| 899 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
| 900 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
| 901 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
| 902 |
+
|
| 903 |
+
# up_blocks have three resnets
|
| 904 |
+
# also, up blocks in hf are numbered in reverse from sd
|
| 905 |
+
for j in range(3):
|
| 906 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
| 907 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
| 908 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
| 909 |
+
|
| 910 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
| 911 |
+
for i in range(2):
|
| 912 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
| 913 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
| 914 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 915 |
+
|
| 916 |
+
if diffusers.__version__ < "0.17.0":
|
| 917 |
+
vae_conversion_map_attn = [
|
| 918 |
+
# (stable-diffusion, HF Diffusers)
|
| 919 |
+
("norm.", "group_norm."),
|
| 920 |
+
("q.", "query."),
|
| 921 |
+
("k.", "key."),
|
| 922 |
+
("v.", "value."),
|
| 923 |
+
("proj_out.", "proj_attn."),
|
| 924 |
+
]
|
| 925 |
+
else:
|
| 926 |
+
vae_conversion_map_attn = [
|
| 927 |
+
# (stable-diffusion, HF Diffusers)
|
| 928 |
+
("norm.", "group_norm."),
|
| 929 |
+
("q.", "to_q."),
|
| 930 |
+
("k.", "to_k."),
|
| 931 |
+
("v.", "to_v."),
|
| 932 |
+
("proj_out.", "to_out.0."),
|
| 933 |
+
]
|
| 934 |
+
|
| 935 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
| 936 |
+
for k, v in mapping.items():
|
| 937 |
+
for sd_part, hf_part in vae_conversion_map:
|
| 938 |
+
v = v.replace(hf_part, sd_part)
|
| 939 |
+
mapping[k] = v
|
| 940 |
+
for k, v in mapping.items():
|
| 941 |
+
if "attentions" in k:
|
| 942 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
| 943 |
+
v = v.replace(hf_part, sd_part)
|
| 944 |
+
mapping[k] = v
|
| 945 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
| 946 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
| 947 |
+
for k, v in new_state_dict.items():
|
| 948 |
+
for weight_name in weights_to_convert:
|
| 949 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
| 950 |
+
# logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
| 951 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
| 952 |
+
|
| 953 |
+
return new_state_dict
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
# endregion
|
| 957 |
+
|
| 958 |
+
# region 自作のモデル読み書きなど
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
def is_safetensors(path):
|
| 962 |
+
return os.path.splitext(path)[1].lower() == ".safetensors"
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
| 966 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
| 967 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
| 968 |
+
("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
|
| 969 |
+
("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
|
| 970 |
+
("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
|
| 971 |
+
]
|
| 972 |
+
|
| 973 |
+
if is_safetensors(ckpt_path):
|
| 974 |
+
checkpoint = None
|
| 975 |
+
state_dict = load_file(ckpt_path) # , device) # may causes error
|
| 976 |
+
else:
|
| 977 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
| 978 |
+
if "state_dict" in checkpoint:
|
| 979 |
+
state_dict = checkpoint["state_dict"]
|
| 980 |
+
else:
|
| 981 |
+
state_dict = checkpoint
|
| 982 |
+
checkpoint = None
|
| 983 |
+
|
| 984 |
+
key_reps = []
|
| 985 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
| 986 |
+
for key in state_dict.keys():
|
| 987 |
+
if key.startswith(rep_from):
|
| 988 |
+
new_key = rep_to + key[len(rep_from) :]
|
| 989 |
+
key_reps.append((key, new_key))
|
| 990 |
+
|
| 991 |
+
for key, new_key in key_reps:
|
| 992 |
+
state_dict[new_key] = state_dict[key]
|
| 993 |
+
del state_dict[key]
|
| 994 |
+
|
| 995 |
+
return checkpoint, state_dict
|
| 996 |
+
|
| 997 |
+
|
| 998 |
+
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
| 999 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
|
| 1000 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
| 1001 |
+
|
| 1002 |
+
# Convert the UNet2DConditionModel model.
|
| 1003 |
+
unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
|
| 1004 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
| 1005 |
+
|
| 1006 |
+
unet = UNet2DConditionModel(**unet_config).to(device)
|
| 1007 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
| 1008 |
+
logger.info(f"loading u-net: {info}")
|
| 1009 |
+
|
| 1010 |
+
# Convert the VAE model.
|
| 1011 |
+
vae_config = create_vae_diffusers_config()
|
| 1012 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
| 1013 |
+
|
| 1014 |
+
vae = AutoencoderKL(**vae_config).to(device)
|
| 1015 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
| 1016 |
+
logger.info(f"loading vae: {info}")
|
| 1017 |
+
|
| 1018 |
+
# convert text_model
|
| 1019 |
+
if v2:
|
| 1020 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
| 1021 |
+
cfg = CLIPTextConfig(
|
| 1022 |
+
vocab_size=49408,
|
| 1023 |
+
hidden_size=1024,
|
| 1024 |
+
intermediate_size=4096,
|
| 1025 |
+
num_hidden_layers=23,
|
| 1026 |
+
num_attention_heads=16,
|
| 1027 |
+
max_position_embeddings=77,
|
| 1028 |
+
hidden_act="gelu",
|
| 1029 |
+
layer_norm_eps=1e-05,
|
| 1030 |
+
dropout=0.0,
|
| 1031 |
+
attention_dropout=0.0,
|
| 1032 |
+
initializer_range=0.02,
|
| 1033 |
+
initializer_factor=1.0,
|
| 1034 |
+
pad_token_id=1,
|
| 1035 |
+
bos_token_id=0,
|
| 1036 |
+
eos_token_id=2,
|
| 1037 |
+
model_type="clip_text_model",
|
| 1038 |
+
projection_dim=512,
|
| 1039 |
+
torch_dtype="float32",
|
| 1040 |
+
transformers_version="4.25.0.dev0",
|
| 1041 |
+
)
|
| 1042 |
+
text_model = CLIPTextModel._from_config(cfg)
|
| 1043 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
| 1044 |
+
else:
|
| 1045 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
| 1046 |
+
|
| 1047 |
+
# logging.set_verbosity_error() # don't show annoying warning
|
| 1048 |
+
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
| 1049 |
+
# logging.set_verbosity_warning()
|
| 1050 |
+
# logger.info(f"config: {text_model.config}")
|
| 1051 |
+
cfg = CLIPTextConfig(
|
| 1052 |
+
vocab_size=49408,
|
| 1053 |
+
hidden_size=768,
|
| 1054 |
+
intermediate_size=3072,
|
| 1055 |
+
num_hidden_layers=12,
|
| 1056 |
+
num_attention_heads=12,
|
| 1057 |
+
max_position_embeddings=77,
|
| 1058 |
+
hidden_act="quick_gelu",
|
| 1059 |
+
layer_norm_eps=1e-05,
|
| 1060 |
+
dropout=0.0,
|
| 1061 |
+
attention_dropout=0.0,
|
| 1062 |
+
initializer_range=0.02,
|
| 1063 |
+
initializer_factor=1.0,
|
| 1064 |
+
pad_token_id=1,
|
| 1065 |
+
bos_token_id=0,
|
| 1066 |
+
eos_token_id=2,
|
| 1067 |
+
model_type="clip_text_model",
|
| 1068 |
+
projection_dim=768,
|
| 1069 |
+
torch_dtype="float32",
|
| 1070 |
+
)
|
| 1071 |
+
text_model = CLIPTextModel._from_config(cfg)
|
| 1072 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
| 1073 |
+
logger.info(f"loading text encoder: {info}")
|
| 1074 |
+
|
| 1075 |
+
return text_model, vae, unet
|
| 1076 |
+
|
| 1077 |
+
|
| 1078 |
+
def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
|
| 1079 |
+
# only for reference
|
| 1080 |
+
version_str = "sd"
|
| 1081 |
+
if v2:
|
| 1082 |
+
version_str += "_v2"
|
| 1083 |
+
else:
|
| 1084 |
+
version_str += "_v1"
|
| 1085 |
+
if v_parameterization:
|
| 1086 |
+
version_str += "_v"
|
| 1087 |
+
return version_str
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
| 1091 |
+
def convert_key(key):
|
| 1092 |
+
# position_idsの除去
|
| 1093 |
+
if ".position_ids" in key:
|
| 1094 |
+
return None
|
| 1095 |
+
|
| 1096 |
+
# common
|
| 1097 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
| 1098 |
+
key = key.replace("text_model.", "")
|
| 1099 |
+
if "layers" in key:
|
| 1100 |
+
# resblocks conversion
|
| 1101 |
+
key = key.replace(".layers.", ".resblocks.")
|
| 1102 |
+
if ".layer_norm" in key:
|
| 1103 |
+
key = key.replace(".layer_norm", ".ln_")
|
| 1104 |
+
elif ".mlp." in key:
|
| 1105 |
+
key = key.replace(".fc1.", ".c_fc.")
|
| 1106 |
+
key = key.replace(".fc2.", ".c_proj.")
|
| 1107 |
+
elif ".self_attn.out_proj" in key:
|
| 1108 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
| 1109 |
+
elif ".self_attn." in key:
|
| 1110 |
+
key = None # 特殊なので後で処理する
|
| 1111 |
+
else:
|
| 1112 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
| 1113 |
+
elif ".position_embedding" in key:
|
| 1114 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
| 1115 |
+
elif ".token_embedding" in key:
|
| 1116 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
| 1117 |
+
elif "final_layer_norm" in key:
|
| 1118 |
+
key = key.replace("final_layer_norm", "ln_final")
|
| 1119 |
+
return key
|
| 1120 |
+
|
| 1121 |
+
keys = list(checkpoint.keys())
|
| 1122 |
+
new_sd = {}
|
| 1123 |
+
for key in keys:
|
| 1124 |
+
new_key = convert_key(key)
|
| 1125 |
+
if new_key is None:
|
| 1126 |
+
continue
|
| 1127 |
+
new_sd[new_key] = checkpoint[key]
|
| 1128 |
+
|
| 1129 |
+
# attnの変換
|
| 1130 |
+
for key in keys:
|
| 1131 |
+
if "layers" in key and "q_proj" in key:
|
| 1132 |
+
# 三つを結合
|
| 1133 |
+
key_q = key
|
| 1134 |
+
key_k = key.replace("q_proj", "k_proj")
|
| 1135 |
+
key_v = key.replace("q_proj", "v_proj")
|
| 1136 |
+
|
| 1137 |
+
value_q = checkpoint[key_q]
|
| 1138 |
+
value_k = checkpoint[key_k]
|
| 1139 |
+
value_v = checkpoint[key_v]
|
| 1140 |
+
value = torch.cat([value_q, value_k, value_v])
|
| 1141 |
+
|
| 1142 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
| 1143 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
| 1144 |
+
new_sd[new_key] = value
|
| 1145 |
+
|
| 1146 |
+
# 最後の層などを捏造するか
|
| 1147 |
+
if make_dummy_weights:
|
| 1148 |
+
logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
|
| 1149 |
+
keys = list(new_sd.keys())
|
| 1150 |
+
for key in keys:
|
| 1151 |
+
if key.startswith("transformer.resblocks.22."):
|
| 1152 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
| 1153 |
+
|
| 1154 |
+
# Diffusersに含まれない重みを作っておく
|
| 1155 |
+
new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
| 1156 |
+
new_sd["logit_scale"] = torch.tensor(1)
|
| 1157 |
+
|
| 1158 |
+
return new_sd
|
| 1159 |
+
|
| 1160 |
+
|
| 1161 |
+
def save_stable_diffusion_checkpoint(
|
| 1162 |
+
v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
|
| 1163 |
+
):
|
| 1164 |
+
if ckpt_path is not None:
|
| 1165 |
+
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
| 1166 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
| 1167 |
+
if checkpoint is None: # safetensors または state_dictのckpt
|
| 1168 |
+
checkpoint = {}
|
| 1169 |
+
strict = False
|
| 1170 |
+
else:
|
| 1171 |
+
strict = True
|
| 1172 |
+
if "state_dict" in state_dict:
|
| 1173 |
+
del state_dict["state_dict"]
|
| 1174 |
+
else:
|
| 1175 |
+
# 新しく作る
|
| 1176 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
| 1177 |
+
checkpoint = {}
|
| 1178 |
+
state_dict = {}
|
| 1179 |
+
strict = False
|
| 1180 |
+
|
| 1181 |
+
def update_sd(prefix, sd):
|
| 1182 |
+
for k, v in sd.items():
|
| 1183 |
+
key = prefix + k
|
| 1184 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
| 1185 |
+
if save_dtype is not None:
|
| 1186 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
| 1187 |
+
state_dict[key] = v
|
| 1188 |
+
|
| 1189 |
+
# Convert the UNet model
|
| 1190 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
| 1191 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
| 1192 |
+
|
| 1193 |
+
# Convert the text encoder model
|
| 1194 |
+
if v2:
|
| 1195 |
+
make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる
|
| 1196 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
| 1197 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
| 1198 |
+
else:
|
| 1199 |
+
text_enc_dict = text_encoder.state_dict()
|
| 1200 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
| 1201 |
+
|
| 1202 |
+
# Convert the VAE
|
| 1203 |
+
if vae is not None:
|
| 1204 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
| 1205 |
+
update_sd("first_stage_model.", vae_dict)
|
| 1206 |
+
|
| 1207 |
+
# Put together new checkpoint
|
| 1208 |
+
key_count = len(state_dict.keys())
|
| 1209 |
+
new_ckpt = {"state_dict": state_dict}
|
| 1210 |
+
|
| 1211 |
+
# epoch and global_step are sometimes not int
|
| 1212 |
+
try:
|
| 1213 |
+
if "epoch" in checkpoint:
|
| 1214 |
+
epochs += checkpoint["epoch"]
|
| 1215 |
+
if "global_step" in checkpoint:
|
| 1216 |
+
steps += checkpoint["global_step"]
|
| 1217 |
+
except:
|
| 1218 |
+
pass
|
| 1219 |
+
|
| 1220 |
+
new_ckpt["epoch"] = epochs
|
| 1221 |
+
new_ckpt["global_step"] = steps
|
| 1222 |
+
|
| 1223 |
+
if is_safetensors(output_file):
|
| 1224 |
+
# TODO Tensor以外のdictの値を削除したほうがいいか
|
| 1225 |
+
save_file(state_dict, output_file, metadata)
|
| 1226 |
+
else:
|
| 1227 |
+
torch.save(new_ckpt, output_file)
|
| 1228 |
+
|
| 1229 |
+
return key_count
|
| 1230 |
+
|
| 1231 |
+
|
| 1232 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
| 1233 |
+
if pretrained_model_name_or_path is None:
|
| 1234 |
+
# load default settings for v1/v2
|
| 1235 |
+
if v2:
|
| 1236 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
| 1237 |
+
else:
|
| 1238 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
| 1239 |
+
|
| 1240 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
| 1241 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
| 1242 |
+
if vae is None:
|
| 1243 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
| 1244 |
+
|
| 1245 |
+
# original U-Net cannot be saved, so we need to convert it to the Diffusers version
|
| 1246 |
+
# TODO this consumes a lot of memory
|
| 1247 |
+
diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
|
| 1248 |
+
diffusers_unet.load_state_dict(unet.state_dict())
|
| 1249 |
+
|
| 1250 |
+
pipeline = StableDiffusionPipeline(
|
| 1251 |
+
unet=diffusers_unet,
|
| 1252 |
+
text_encoder=text_encoder,
|
| 1253 |
+
vae=vae,
|
| 1254 |
+
scheduler=scheduler,
|
| 1255 |
+
tokenizer=tokenizer,
|
| 1256 |
+
safety_checker=None,
|
| 1257 |
+
feature_extractor=None,
|
| 1258 |
+
requires_safety_checker=None,
|
| 1259 |
+
)
|
| 1260 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
| 1261 |
+
|
| 1262 |
+
|
| 1263 |
+
VAE_PREFIX = "first_stage_model."
|
| 1264 |
+
|
| 1265 |
+
|
| 1266 |
+
def load_vae(vae_id, dtype):
|
| 1267 |
+
logger.info(f"load VAE: {vae_id}")
|
| 1268 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
| 1269 |
+
# Diffusers local/remote
|
| 1270 |
+
try:
|
| 1271 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
| 1272 |
+
except EnvironmentError as e:
|
| 1273 |
+
logger.error(f"exception occurs in loading vae: {e}")
|
| 1274 |
+
logger.error("retry with subfolder='vae'")
|
| 1275 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
| 1276 |
+
return vae
|
| 1277 |
+
|
| 1278 |
+
# local
|
| 1279 |
+
vae_config = create_vae_diffusers_config()
|
| 1280 |
+
|
| 1281 |
+
if vae_id.endswith(".bin"):
|
| 1282 |
+
# SD 1.5 VAE on Huggingface
|
| 1283 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
| 1284 |
+
else:
|
| 1285 |
+
# StableDiffusion
|
| 1286 |
+
vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
|
| 1287 |
+
vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
|
| 1288 |
+
|
| 1289 |
+
# vae only or full model
|
| 1290 |
+
full_model = False
|
| 1291 |
+
for vae_key in vae_sd:
|
| 1292 |
+
if vae_key.startswith(VAE_PREFIX):
|
| 1293 |
+
full_model = True
|
| 1294 |
+
break
|
| 1295 |
+
if not full_model:
|
| 1296 |
+
sd = {}
|
| 1297 |
+
for key, value in vae_sd.items():
|
| 1298 |
+
sd[VAE_PREFIX + key] = value
|
| 1299 |
+
vae_sd = sd
|
| 1300 |
+
del sd
|
| 1301 |
+
|
| 1302 |
+
# Convert the VAE model.
|
| 1303 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
| 1304 |
+
|
| 1305 |
+
vae = AutoencoderKL(**vae_config)
|
| 1306 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
| 1307 |
+
return vae
|
| 1308 |
+
|
| 1309 |
+
|
| 1310 |
+
# endregion
|
| 1311 |
+
|
| 1312 |
+
|
| 1313 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
| 1314 |
+
max_width, max_height = max_reso
|
| 1315 |
+
max_area = max_width * max_height
|
| 1316 |
+
|
| 1317 |
+
resos = set()
|
| 1318 |
+
|
| 1319 |
+
width = int(math.sqrt(max_area) // divisible) * divisible
|
| 1320 |
+
resos.add((width, width))
|
| 1321 |
+
|
| 1322 |
+
width = min_size
|
| 1323 |
+
while width <= max_size:
|
| 1324 |
+
height = min(max_size, int((max_area // width) // divisible) * divisible)
|
| 1325 |
+
if height >= min_size:
|
| 1326 |
+
resos.add((width, height))
|
| 1327 |
+
resos.add((height, width))
|
| 1328 |
+
|
| 1329 |
+
# # make additional resos
|
| 1330 |
+
# if width >= height and width - divisible >= min_size:
|
| 1331 |
+
# resos.add((width - divisible, height))
|
| 1332 |
+
# resos.add((height, width - divisible))
|
| 1333 |
+
# if height >= width and height - divisible >= min_size:
|
| 1334 |
+
# resos.add((width, height - divisible))
|
| 1335 |
+
# resos.add((height - divisible, width))
|
| 1336 |
+
|
| 1337 |
+
width += divisible
|
| 1338 |
+
|
| 1339 |
+
resos = list(resos)
|
| 1340 |
+
resos.sort()
|
| 1341 |
+
return resos
|
| 1342 |
+
|
| 1343 |
+
|
| 1344 |
+
if __name__ == "__main__":
|
| 1345 |
+
resos = make_bucket_resolutions((512, 768))
|
| 1346 |
+
logger.info(f"{len(resos)}")
|
| 1347 |
+
logger.info(f"{resos}")
|
| 1348 |
+
aspect_ratios = [w / h for w, h in resos]
|
| 1349 |
+
logger.info(f"{aspect_ratios}")
|
| 1350 |
+
|
| 1351 |
+
ars = set()
|
| 1352 |
+
for ar in aspect_ratios:
|
| 1353 |
+
if ar in ars:
|
| 1354 |
+
logger.error(f"error! duplicate ar: {ar}")
|
| 1355 |
+
ars.add(ar)
|
sd-scripts/train_util.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sd-scripts/utils.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import sys
|
| 3 |
+
import threading
|
| 4 |
+
import torch
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
from typing import *
|
| 7 |
+
from diffusers import EulerAncestralDiscreteScheduler
|
| 8 |
+
import diffusers.schedulers.scheduling_euler_ancestral_discrete
|
| 9 |
+
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
|
| 10 |
+
import cv2
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def fire_in_thread(f, *args, **kwargs):
|
| 16 |
+
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def add_logging_arguments(parser):
|
| 20 |
+
parser.add_argument(
|
| 21 |
+
"--console_log_level",
|
| 22 |
+
type=str,
|
| 23 |
+
default=None,
|
| 24 |
+
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
| 25 |
+
help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
"--console_log_file",
|
| 29 |
+
type=str,
|
| 30 |
+
default=None,
|
| 31 |
+
help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
|
| 32 |
+
)
|
| 33 |
+
parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def setup_logging(args=None, log_level=None, reset=False):
|
| 37 |
+
if logging.root.handlers:
|
| 38 |
+
if reset:
|
| 39 |
+
# remove all handlers
|
| 40 |
+
for handler in logging.root.handlers[:]:
|
| 41 |
+
logging.root.removeHandler(handler)
|
| 42 |
+
else:
|
| 43 |
+
return
|
| 44 |
+
|
| 45 |
+
# log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
|
| 46 |
+
if log_level is None and args is not None:
|
| 47 |
+
log_level = args.console_log_level
|
| 48 |
+
if log_level is None:
|
| 49 |
+
log_level = "INFO"
|
| 50 |
+
log_level = getattr(logging, log_level)
|
| 51 |
+
|
| 52 |
+
msg_init = None
|
| 53 |
+
if args is not None and args.console_log_file:
|
| 54 |
+
handler = logging.FileHandler(args.console_log_file, mode="w")
|
| 55 |
+
else:
|
| 56 |
+
handler = None
|
| 57 |
+
if not args or not args.console_log_simple:
|
| 58 |
+
try:
|
| 59 |
+
from rich.logging import RichHandler
|
| 60 |
+
from rich.console import Console
|
| 61 |
+
from rich.logging import RichHandler
|
| 62 |
+
|
| 63 |
+
handler = RichHandler(console=Console(stderr=True))
|
| 64 |
+
except ImportError:
|
| 65 |
+
# print("rich is not installed, using basic logging")
|
| 66 |
+
msg_init = "rich is not installed, using basic logging"
|
| 67 |
+
|
| 68 |
+
if handler is None:
|
| 69 |
+
handler = logging.StreamHandler(sys.stdout) # same as print
|
| 70 |
+
handler.propagate = False
|
| 71 |
+
|
| 72 |
+
formatter = logging.Formatter(
|
| 73 |
+
fmt="%(message)s",
|
| 74 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 75 |
+
)
|
| 76 |
+
handler.setFormatter(formatter)
|
| 77 |
+
logging.root.setLevel(log_level)
|
| 78 |
+
logging.root.addHandler(handler)
|
| 79 |
+
|
| 80 |
+
if msg_init is not None:
|
| 81 |
+
logger = logging.getLogger(__name__)
|
| 82 |
+
logger.info(msg_init)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def pil_resize(image, size, interpolation=Image.LANCZOS):
|
| 86 |
+
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
|
| 87 |
+
|
| 88 |
+
if has_alpha:
|
| 89 |
+
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
|
| 90 |
+
else:
|
| 91 |
+
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
| 92 |
+
|
| 93 |
+
resized_pil = pil_image.resize(size, interpolation)
|
| 94 |
+
|
| 95 |
+
# Convert back to cv2 format
|
| 96 |
+
if has_alpha:
|
| 97 |
+
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
|
| 98 |
+
else:
|
| 99 |
+
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
|
| 100 |
+
|
| 101 |
+
return resized_cv2
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# TODO make inf_utils.py
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# region Gradual Latent hires fix
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class GradualLatent:
|
| 111 |
+
def __init__(
|
| 112 |
+
self,
|
| 113 |
+
ratio,
|
| 114 |
+
start_timesteps,
|
| 115 |
+
every_n_steps,
|
| 116 |
+
ratio_step,
|
| 117 |
+
s_noise=1.0,
|
| 118 |
+
gaussian_blur_ksize=None,
|
| 119 |
+
gaussian_blur_sigma=0.5,
|
| 120 |
+
gaussian_blur_strength=0.5,
|
| 121 |
+
unsharp_target_x=True,
|
| 122 |
+
):
|
| 123 |
+
self.ratio = ratio
|
| 124 |
+
self.start_timesteps = start_timesteps
|
| 125 |
+
self.every_n_steps = every_n_steps
|
| 126 |
+
self.ratio_step = ratio_step
|
| 127 |
+
self.s_noise = s_noise
|
| 128 |
+
self.gaussian_blur_ksize = gaussian_blur_ksize
|
| 129 |
+
self.gaussian_blur_sigma = gaussian_blur_sigma
|
| 130 |
+
self.gaussian_blur_strength = gaussian_blur_strength
|
| 131 |
+
self.unsharp_target_x = unsharp_target_x
|
| 132 |
+
|
| 133 |
+
def __str__(self) -> str:
|
| 134 |
+
return (
|
| 135 |
+
f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
|
| 136 |
+
+ f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
|
| 137 |
+
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
|
| 138 |
+
+ f"unsharp_target_x={self.unsharp_target_x})"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
def apply_unshark_mask(self, x: torch.Tensor):
|
| 142 |
+
if self.gaussian_blur_ksize is None:
|
| 143 |
+
return x
|
| 144 |
+
blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
|
| 145 |
+
# mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
|
| 146 |
+
mask = (x - blurred) * self.gaussian_blur_strength
|
| 147 |
+
sharpened = x + mask
|
| 148 |
+
return sharpened
|
| 149 |
+
|
| 150 |
+
def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
|
| 151 |
+
org_dtype = x.dtype
|
| 152 |
+
if org_dtype == torch.bfloat16:
|
| 153 |
+
x = x.float()
|
| 154 |
+
|
| 155 |
+
x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
|
| 156 |
+
|
| 157 |
+
# apply unsharp mask / アンシャープマスクを適用する
|
| 158 |
+
if unsharp and self.gaussian_blur_ksize:
|
| 159 |
+
x = self.apply_unshark_mask(x)
|
| 160 |
+
|
| 161 |
+
return x
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
|
| 165 |
+
def __init__(self, *args, **kwargs):
|
| 166 |
+
super().__init__(*args, **kwargs)
|
| 167 |
+
self.resized_size = None
|
| 168 |
+
self.gradual_latent = None
|
| 169 |
+
|
| 170 |
+
def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
|
| 171 |
+
self.resized_size = size
|
| 172 |
+
self.gradual_latent = gradual_latent
|
| 173 |
+
|
| 174 |
+
def step(
|
| 175 |
+
self,
|
| 176 |
+
model_output: torch.FloatTensor,
|
| 177 |
+
timestep: Union[float, torch.FloatTensor],
|
| 178 |
+
sample: torch.FloatTensor,
|
| 179 |
+
generator: Optional[torch.Generator] = None,
|
| 180 |
+
return_dict: bool = True,
|
| 181 |
+
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
|
| 182 |
+
"""
|
| 183 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
| 184 |
+
process from the learned model outputs (most often the predicted noise).
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
model_output (`torch.FloatTensor`):
|
| 188 |
+
The direct output from learned diffusion model.
|
| 189 |
+
timestep (`float`):
|
| 190 |
+
The current discrete timestep in the diffusion chain.
|
| 191 |
+
sample (`torch.FloatTensor`):
|
| 192 |
+
A current instance of a sample created by the diffusion process.
|
| 193 |
+
generator (`torch.Generator`, *optional*):
|
| 194 |
+
A random number generator.
|
| 195 |
+
return_dict (`bool`):
|
| 196 |
+
Whether or not to return a
|
| 197 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
|
| 201 |
+
If return_dict is `True`,
|
| 202 |
+
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
|
| 203 |
+
otherwise a tuple is returned where the first element is the sample tensor.
|
| 204 |
+
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
|
| 208 |
+
raise ValueError(
|
| 209 |
+
(
|
| 210 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
| 211 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
| 212 |
+
" one of the `scheduler.timesteps` as a timestep."
|
| 213 |
+
),
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
if not self.is_scale_input_called:
|
| 217 |
+
# logger.warning(
|
| 218 |
+
print(
|
| 219 |
+
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
| 220 |
+
"See `StableDiffusionPipeline` for a usage example."
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
if self.step_index is None:
|
| 224 |
+
self._init_step_index(timestep)
|
| 225 |
+
|
| 226 |
+
sigma = self.sigmas[self.step_index]
|
| 227 |
+
|
| 228 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
| 229 |
+
if self.config.prediction_type == "epsilon":
|
| 230 |
+
pred_original_sample = sample - sigma * model_output
|
| 231 |
+
elif self.config.prediction_type == "v_prediction":
|
| 232 |
+
# * c_out + input * c_skip
|
| 233 |
+
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
| 234 |
+
elif self.config.prediction_type == "sample":
|
| 235 |
+
raise NotImplementedError("prediction_type not implemented yet: sample")
|
| 236 |
+
else:
|
| 237 |
+
raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
|
| 238 |
+
|
| 239 |
+
sigma_from = self.sigmas[self.step_index]
|
| 240 |
+
sigma_to = self.sigmas[self.step_index + 1]
|
| 241 |
+
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
|
| 242 |
+
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
| 243 |
+
|
| 244 |
+
# 2. Convert to an ODE derivative
|
| 245 |
+
derivative = (sample - pred_original_sample) / sigma
|
| 246 |
+
|
| 247 |
+
dt = sigma_down - sigma
|
| 248 |
+
|
| 249 |
+
device = model_output.device
|
| 250 |
+
if self.resized_size is None:
|
| 251 |
+
prev_sample = sample + derivative * dt
|
| 252 |
+
|
| 253 |
+
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
| 254 |
+
model_output.shape, dtype=model_output.dtype, device=device, generator=generator
|
| 255 |
+
)
|
| 256 |
+
s_noise = 1.0
|
| 257 |
+
else:
|
| 258 |
+
print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
|
| 259 |
+
s_noise = self.gradual_latent.s_noise
|
| 260 |
+
|
| 261 |
+
if self.gradual_latent.unsharp_target_x:
|
| 262 |
+
prev_sample = sample + derivative * dt
|
| 263 |
+
prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
|
| 264 |
+
else:
|
| 265 |
+
sample = self.gradual_latent.interpolate(sample, self.resized_size)
|
| 266 |
+
derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
|
| 267 |
+
prev_sample = sample + derivative * dt
|
| 268 |
+
|
| 269 |
+
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
| 270 |
+
(model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
|
| 271 |
+
dtype=model_output.dtype,
|
| 272 |
+
device=device,
|
| 273 |
+
generator=generator,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
prev_sample = prev_sample + noise * sigma_up * s_noise
|
| 277 |
+
|
| 278 |
+
# upon completion increase step index by one
|
| 279 |
+
self._step_index += 1
|
| 280 |
+
|
| 281 |
+
if not return_dict:
|
| 282 |
+
return (prev_sample,)
|
| 283 |
+
|
| 284 |
+
return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# endregion
|