Allex21 committed · verified
Commit 216a44c · Parent(s): 6fb5a28

Upload 4 files
sd-scripts/config_util.py ADDED
@@ -0,0 +1,721 @@
import argparse
from dataclasses import (
    asdict,
    dataclass,
)
import functools
import random
from textwrap import dedent, indent
import json
from pathlib import Path

# from toolz import curry
from typing import (
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import toml
import voluptuous
from voluptuous import (
    Any,
    ExactSequence,
    MultipleInvalid,
    Object,
    Required,
    Schema,
)
from transformers import CLIPTokenizer

from . import train_util
from .train_util import (
    DreamBoothSubset,
    FineTuningSubset,
    ControlNetSubset,
    DreamBoothDataset,
    FineTuningDataset,
    ControlNetDataset,
    DatasetGroup,
)
from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def add_config_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--dataset_config", type=Path, default=None, help="config file for detailed settings / 詳細な設定用の設定ファイル"
    )


# TODO: inherit Params class in Subset, Dataset


@dataclass
class BaseSubsetParams:
    image_dir: Optional[str] = None
    num_repeats: int = 1
    shuffle_caption: bool = False
    caption_separator: str = ","
    keep_tokens: int = 0
    keep_tokens_separator: Optional[str] = None
    secondary_separator: Optional[str] = None
    enable_wildcard: bool = False
    color_aug: bool = False
    flip_aug: bool = False
    face_crop_aug_range: Optional[Tuple[float, float]] = None
    random_crop: bool = False
    caption_prefix: Optional[str] = None
    caption_suffix: Optional[str] = None
    caption_dropout_rate: float = 0.0
    caption_dropout_every_n_epochs: int = 0
    caption_tag_dropout_rate: float = 0.0
    token_warmup_min: int = 1
    token_warmup_step: float = 0.0


@dataclass
class DreamBoothSubsetParams(BaseSubsetParams):
    is_reg: bool = False
    class_tokens: Optional[str] = None
    caption_extension: str = ".caption"
    cache_info: bool = False
    alpha_mask: bool = False


@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
    metadata_file: Optional[str] = None
    alpha_mask: bool = False


@dataclass
class ControlNetSubsetParams(BaseSubsetParams):
    conditioning_data_dir: Optional[str] = None
    caption_extension: str = ".caption"
    cache_info: bool = False


@dataclass
class BaseDatasetParams:
    tokenizer: Optional[Union[CLIPTokenizer, List[CLIPTokenizer]]] = None
    max_token_length: Optional[int] = None
    resolution: Optional[Tuple[int, int]] = None
    network_multiplier: float = 1.0
    debug_dataset: bool = False


@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False
    prior_loss_weight: float = 1.0


@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False


@dataclass
class ControlNetDatasetParams(BaseDatasetParams):
    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False


@dataclass
class SubsetBlueprint:
    params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]


@dataclass
class DatasetBlueprint:
    is_dreambooth: bool
    is_controlnet: bool
    params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
    subsets: Sequence[SubsetBlueprint]


@dataclass
class DatasetGroupBlueprint:
    datasets: Sequence[DatasetBlueprint]


@dataclass
class Blueprint:
    dataset_group: DatasetGroupBlueprint


class ConfigSanitizer:
    # @curry
    @staticmethod
    def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
        Schema(ExactSequence([klass, klass]))(value)
        return tuple(value)

    # @curry
    @staticmethod
    def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
        Schema(Any(klass, ExactSequence([klass, klass])))(value)
        try:
            Schema(klass)(value)
            return (value, value)
        except voluptuous.Invalid:
            return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
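
    # Example (illustrative): __validate_and_convert_scalar_or_twodim(int, 512) -> (512, 512),
    # while __validate_and_convert_scalar_or_twodim(int, [512, 768]) -> (512, 768).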

    # subset schema
    SUBSET_ASCENDABLE_SCHEMA = {
        "color_aug": bool,
        "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
        "flip_aug": bool,
        "num_repeats": int,
        "random_crop": bool,
        "shuffle_caption": bool,
        "keep_tokens": int,
        "keep_tokens_separator": str,
        "secondary_separator": str,
        "caption_separator": str,
        "enable_wildcard": bool,
        "token_warmup_min": int,
        "token_warmup_step": Any(float, int),
        "caption_prefix": str,
        "caption_suffix": str,
    }
    # DO means DropOut
    DO_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_dropout_every_n_epochs": int,
        "caption_dropout_rate": Any(float, int),
        "caption_tag_dropout_rate": Any(float, int),
    }
    # DB means DreamBooth
    DB_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "class_tokens": str,
        "cache_info": bool,
    }
    DB_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        "is_reg": bool,
        "alpha_mask": bool,
    }
    # FT means FineTuning
    FT_SUBSET_DISTINCT_SCHEMA = {
        Required("metadata_file"): str,
        "image_dir": str,
        "alpha_mask": bool,
    }
    CN_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "cache_info": bool,
    }
    CN_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        Required("conditioning_data_dir"): str,
    }

    # datasets schema
    DATASET_ASCENDABLE_SCHEMA = {
        "batch_size": int,
        "bucket_no_upscale": bool,
        "bucket_reso_steps": int,
        "enable_bucket": bool,
        "max_bucket_reso": int,
        "min_bucket_reso": int,
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
        "network_multiplier": float,
    }

    # options handled by argparse but not handled by user config
    ARGPARSE_SPECIFIC_SCHEMA = {
        "debug_dataset": bool,
        "max_token_length": Any(None, int),
        "prior_loss_weight": Any(float, int),
    }
    # for handling argparse's default None values
    ARGPARSE_NULLABLE_OPTNAMES = [
        "face_crop_aug_range",
        "resolution",
    ]
    # prepare a map because option names may differ between argparse and user config
    ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
        "train_batch_size": "batch_size",
        "dataset_repeats": "num_repeats",
    }

    def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
        assert support_dreambooth or support_finetuning or support_controlnet, (
            "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
            + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
        )

        self.db_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_DISTINCT_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.ft_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.FT_SUBSET_DISTINCT_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.cn_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.CN_SUBSET_DISTINCT_SCHEMA,
            self.CN_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.db_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.db_subset_schema]},
        )

        self.ft_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.ft_subset_schema]},
        )

        self.cn_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.CN_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.cn_subset_schema]},
        )

        if support_dreambooth and support_finetuning:

            def validate_flex_dataset(dataset_config: dict):
                subsets_config = dataset_config.get("subsets", [])

                if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
                    return Schema(self.cn_dataset_schema)(dataset_config)
                # check dataset meets FT style
                # NOTE: all FT subsets should have "metadata_file"
                elif all(["metadata_file" in subset for subset in subsets_config]):
                    return Schema(self.ft_dataset_schema)(dataset_config)
                # check dataset meets DB style
                # NOTE: all DB subsets should have no "metadata_file"
                elif all(["metadata_file" not in subset for subset in subsets_config]):
                    return Schema(self.db_dataset_schema)(dataset_config)
                else:
                    raise voluptuous.Invalid(
                        "DreamBooth subsets and fine tuning subsets cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuningのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
                    )

            self.dataset_schema = validate_flex_dataset
        elif support_dreambooth:
            if support_controlnet:
                self.dataset_schema = self.cn_dataset_schema
            else:
                self.dataset_schema = self.db_dataset_schema
        elif support_finetuning:
            self.dataset_schema = self.ft_dataset_schema
        elif support_controlnet:
            self.dataset_schema = self.cn_dataset_schema

        self.general_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
            self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.user_config_validator = Schema(
            {
                "general": self.general_schema,
                "datasets": [self.dataset_schema],
            }
        )

        self.argparse_schema = self.__merge_dict(
            self.general_schema,
            self.ARGPARSE_SPECIFIC_SCHEMA,
            {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
            {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
        )

        self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)

    def sanitize_user_config(self, user_config: dict) -> dict:
        try:
            return self.user_config_validator(user_config)
        except MultipleInvalid:
            # TODO: make the error message on validation failure easier to understand
            logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
            raise

    # NOTE: strictly speaking, the argument parser result does not need to be sanitized,
    # but doing so helps us detect program bugs
    def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
        try:
            return self.argparse_config_validator(argparse_namespace)
        except MultipleInvalid:
            # XXX: this should be a bug
            logger.error(
                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
            )
            raise

    # NOTE: values are overwritten by later dicts if the same key already exists
    @staticmethod
    def __merge_dict(*dict_list: dict) -> dict:
        merged = {}
        for schema in dict_list:
            # merged |= schema
            for k, v in schema.items():
                merged[k] = v
        return merged
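
    # Example (illustrative): __merge_dict({"a": 1}, {"a": 2, "b": 3}) == {"a": 2, "b": 3}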


class BlueprintGenerator:
    BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}

    def __init__(self, sanitizer: ConfigSanitizer):
        self.sanitizer = sanitizer

    # runtime_params is for parameters that are only configurable at runtime, such as the tokenizer
    def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
        sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
        sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)

        # convert argparse namespace to a dict-like config
        # NOTE: it is ok to have extra entries in the dict
        optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
        argparse_config = {
            optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
        }

        general_config = sanitized_user_config.get("general", {})

        dataset_blueprints = []
        for dataset_config in sanitized_user_config.get("datasets", []):
            # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
            subsets = dataset_config.get("subsets", [])
            is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
            is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
            if is_controlnet:
                subset_params_klass = ControlNetSubsetParams
                dataset_params_klass = ControlNetDatasetParams
            elif is_dreambooth:
                subset_params_klass = DreamBoothSubsetParams
                dataset_params_klass = DreamBoothDatasetParams
            else:
                subset_params_klass = FineTuningSubsetParams
                dataset_params_klass = FineTuningDatasetParams

            subset_blueprints = []
            for subset_config in subsets:
                params = self.generate_params_by_fallbacks(
                    subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
                )
                subset_blueprints.append(SubsetBlueprint(params))

            params = self.generate_params_by_fallbacks(
                dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
            )
            dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))

        dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)

        return Blueprint(dataset_group_blueprint)

    @staticmethod
    def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
        name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
        search_value = BlueprintGenerator.search_value
        default_params = asdict(param_klass())
        param_names = default_params.keys()

        params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}

        return param_klass(**params)

    @staticmethod
    def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
        for cand in fallbacks:
            value = cand.get(key)
            if value is not None:
                return value

        return default_value
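
    # Example (illustrative): search_value("batch_size", [{"batch_size": None}, {"batch_size": 4}])
    # returns 4, because None values fall through to the later fallback dicts.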


def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
    datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

    for dataset_blueprint in dataset_group_blueprint.datasets:
        if dataset_blueprint.is_controlnet:
            subset_klass = ControlNetSubset
            dataset_klass = ControlNetDataset
        elif dataset_blueprint.is_dreambooth:
            subset_klass = DreamBoothSubset
            dataset_klass = DreamBoothDataset
        else:
            subset_klass = FineTuningSubset
            dataset_klass = FineTuningDataset

        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
        dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
        datasets.append(dataset)

    # print info
    info = ""
    for i, dataset in enumerate(datasets):
        is_dreambooth = isinstance(dataset, DreamBoothDataset)
        is_controlnet = isinstance(dataset, ControlNetDataset)
        info += dedent(
            f"""\
            [Dataset {i}]
              batch_size: {dataset.batch_size}
              resolution: {(dataset.width, dataset.height)}
              enable_bucket: {dataset.enable_bucket}
              network_multiplier: {dataset.network_multiplier}
            """
        )

        if dataset.enable_bucket:
            info += indent(
                dedent(
                    f"""\
                    min_bucket_reso: {dataset.min_bucket_reso}
                    max_bucket_reso: {dataset.max_bucket_reso}
                    bucket_reso_steps: {dataset.bucket_reso_steps}
                    bucket_no_upscale: {dataset.bucket_no_upscale}
                    \n"""
                ),
                "  ",
            )
        else:
            info += "\n"

        for j, subset in enumerate(dataset.subsets):
            info += indent(
                dedent(
                    f"""\
                    [Subset {j} of Dataset {i}]
                      image_dir: "{subset.image_dir}"
                      image_count: {subset.img_count}
                      num_repeats: {subset.num_repeats}
                      shuffle_caption: {subset.shuffle_caption}
                      keep_tokens: {subset.keep_tokens}
                      keep_tokens_separator: {subset.keep_tokens_separator}
                      caption_separator: {subset.caption_separator}
                      secondary_separator: {subset.secondary_separator}
                      enable_wildcard: {subset.enable_wildcard}
                      caption_dropout_rate: {subset.caption_dropout_rate}
                      caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
                      caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
                      caption_prefix: {subset.caption_prefix}
                      caption_suffix: {subset.caption_suffix}
                      color_aug: {subset.color_aug}
                      flip_aug: {subset.flip_aug}
                      face_crop_aug_range: {subset.face_crop_aug_range}
                      random_crop: {subset.random_crop}
                      token_warmup_min: {subset.token_warmup_min}
                      token_warmup_step: {subset.token_warmup_step}
                      alpha_mask: {subset.alpha_mask}
                    """
                ),
                "  ",
            )

            if is_dreambooth:
                info += indent(
                    dedent(
                        f"""\
                        is_reg: {subset.is_reg}
                        class_tokens: {subset.class_tokens}
                        caption_extension: {subset.caption_extension}
                        \n"""
                    ),
                    "    ",
                )
            elif not is_controlnet:
                info += indent(
                    dedent(
                        f"""\
                        metadata_file: {subset.metadata_file}
                        \n"""
                    ),
                    "    ",
                )

    logger.info(f"{info}")

    # make buckets first because that determines the length of the dataset,
    # and set the same seed for all datasets
    seed = random.randint(0, 2**31)  # the actual seed is seed + epoch_no
    for i, dataset in enumerate(datasets):
        logger.info(f"[Dataset {i}]")
        dataset.make_buckets()
        dataset.set_seed(seed)

    return DatasetGroup(datasets)


def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
    def extract_dreambooth_params(name: str) -> Tuple[int, str]:
        tokens = name.split("_")
        try:
            n_repeats = int(tokens[0])
        except ValueError:
            logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
            return 0, ""
        caption_by_folder = "_".join(tokens[1:])
        return n_repeats, caption_by_folder
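
    # Example (illustrative): a subdirectory named "20_sks dog" yields num_repeats=20 and
    # class_tokens="sks dog"; a name without a leading repeat count is skipped with a warning.
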
    def generate(base_dir: Optional[str], is_reg: bool):
        if base_dir is None:
            return []

        base_dir: Path = Path(base_dir)
        if not base_dir.is_dir():
            return []

        subsets_config = []
        for subdir in base_dir.iterdir():
            if not subdir.is_dir():
                continue

            num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
            if num_repeats < 1:
                continue

            subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
            subsets_config.append(subset_config)

        return subsets_config

    subsets_config = []
    subsets_config += generate(train_data_dir, False)
    subsets_config += generate(reg_data_dir, True)

    return subsets_config


def generate_controlnet_subsets_config_by_subdirs(
    train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
):
    def generate(base_dir: Optional[str]):
        if base_dir is None:
            return []

        base_dir: Path = Path(base_dir)
        if not base_dir.is_dir():
            return []

        subsets_config = []
        subset_config = {
            "image_dir": train_data_dir,
            "conditioning_data_dir": conditioning_data_dir,
            "caption_extension": caption_extension,
            "num_repeats": 1,
        }
        subsets_config.append(subset_config)

        return subsets_config

    subsets_config = []
    subsets_config += generate(train_data_dir)

    return subsets_config


def load_user_config(file: str) -> dict:
    file: Path = Path(file)
    if not file.is_file():
        raise ValueError(f"file not found / ファイルが見つかりません: {file}")

    if file.name.lower().endswith(".json"):
        try:
            with open(file, "r") as f:
                config = json.load(f)
        except Exception:
            logger.error(
                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
            )
            raise
    elif file.name.lower().endswith(".toml"):
        try:
            config = toml.load(file)
        except Exception:
            logger.error(
                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
            )
            raise
    else:
        raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")

    return config


# for config test
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--support_dreambooth", action="store_true")
    parser.add_argument("--support_finetuning", action="store_true")
    parser.add_argument("--support_controlnet", action="store_true")
    parser.add_argument("--support_dropout", action="store_true")
    parser.add_argument("dataset_config")
    config_args, remain = parser.parse_known_args()

    parser = argparse.ArgumentParser()
    train_util.add_dataset_arguments(
        parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
    )
    train_util.add_training_arguments(parser, config_args.support_dreambooth)
    argparse_namespace = parser.parse_args(remain)
    train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)

    logger.info("[argparse_namespace]")
    logger.info(f"{vars(argparse_namespace)}")

    user_config = load_user_config(config_args.dataset_config)

    logger.info("")
    logger.info("[user_config]")
    logger.info(f"{user_config}")

    sanitizer = ConfigSanitizer(
        config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
    )
    sanitized_user_config = sanitizer.sanitize_user_config(user_config)

    logger.info("")
    logger.info("[sanitized_user_config]")
    logger.info(f"{sanitized_user_config}")

    blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

    logger.info("")
    logger.info("[blueprint]")
    logger.info(f"{blueprint}")
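
For reference, a minimal sketch of a user config of the shape the sanitizer above accepts; the path and values below are illustrative only, not part of this commit:

    # illustrative only: a DreamBooth-style config (no "metadata_file" in any subset)
    example_user_config = {
        "general": {"enable_bucket": True, "resolution": 512},
        "datasets": [
            {
                "batch_size": 2,
                "subsets": [
                    {"image_dir": "/path/to/img/20_sks dog", "class_tokens": "sks dog", "num_repeats": 20, "is_reg": False},
                ],
            }
        ],
    }
    sanitizer = ConfigSanitizer(True, True, False, True)
    sanitized = sanitizer.sanitize_user_config(example_user_config)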
sd-scripts/model_util.py ADDED
@@ -0,0 +1,1355 @@
# v1: split from train_db_fixed.py.
# v2: support safetensors

import math
import os

import torch
from library.device_utils import init_ipex
init_ipex()

import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline  # , UNet2DConditionModel
from safetensors.torch import load_file, save_file
from library.original_unet import UNet2DConditionModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

# model parameters for the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120

UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8
# UNET_PARAMS_USE_LINEAR_PROJECTION = False

VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 256
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

# V2
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True

# reference models for loading the Diffusers configs
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
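
# Illustrative: with UNET_PARAMS_MODEL_CHANNELS = 320 and UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4],
# create_unet_diffusers_config() below derives block_out_channels = (320, 640, 1280, 1280).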


# region conversion code: StableDiffusion -> Diffusers
# copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])
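
# Example (illustrative):
#   shave_segments("model.diffusion_model.input_blocks.0.0.weight", 2) -> "input_blocks.0.0.weight"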


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        if diffusers.__version__ < "0.17.0":
            new_item = new_item.replace("q.weight", "query.weight")
            new_item = new_item.replace("q.bias", "query.bias")

            new_item = new_item.replace("k.weight", "key.weight")
            new_item = new_item.replace("k.bias", "key.bias")

            new_item = new_item.replace("v.weight", "value.weight")
            new_item = new_item.replace("v.bias", "value.bias")

            new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
            new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
        else:
            new_item = new_item.replace("q.weight", "to_q.weight")
            new_item = new_item.replace("q.bias", "to_q.bias")

            new_item = new_item.replace("k.weight", "to_k.weight")
            new_item = new_item.replace("k.bias", "to_k.bias")

            new_item = new_item.replace("v.weight", "to_v.weight")
            new_item = new_item.replace("v.bias", "to_v.bias")

            new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
            new_item = new_item.replace("proj_out.bias", "to_out.0.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        reshaping = False
        if diffusers.__version__ < "0.17.0":
            if "proj_attn.weight" in new_path:
                reshaping = True
        else:
            if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
                reshaping = True

        if reshaping:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def linear_transformer_to_conv(checkpoint):
    keys = list(checkpoint.keys())
    tf_keys = ["proj_in.weight", "proj_out.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in tf_keys:
            if checkpoint[key].ndim == 2:
                checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
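
# Illustrative: linear_transformer_to_conv() turns a (320, 320) linear proj_in/proj_out weight
# into a (320, 320, 1, 1) 1x1-conv weight; conv_transformer_to_linear() further below performs
# the reverse slicing.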
251
+
252
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
253
+ """
254
+ Takes a state dict and a config, and returns a converted checkpoint.
255
+ """
256
+
257
+ # extract state_dict for UNet
258
+ unet_state_dict = {}
259
+ unet_key = "model.diffusion_model."
260
+ keys = list(checkpoint.keys())
261
+ for key in keys:
262
+ if key.startswith(unet_key):
263
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
264
+
265
+ new_checkpoint = {}
266
+
267
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
268
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
269
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
270
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
271
+
272
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
273
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
274
+
275
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
276
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
277
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
278
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
279
+
280
+ # Retrieves the keys for the input blocks only
281
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
282
+ input_blocks = {
283
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
284
+ }
285
+
286
+ # Retrieves the keys for the middle blocks only
287
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
288
+ middle_blocks = {
289
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
290
+ }
291
+
292
+ # Retrieves the keys for the output blocks only
293
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
294
+ output_blocks = {
295
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
296
+ }
297
+
298
+ for i in range(1, num_input_blocks):
299
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
300
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
301
+
302
+ resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
303
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
304
+
305
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
306
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
307
+ f"input_blocks.{i}.0.op.weight"
308
+ )
309
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
310
+
311
+ paths = renew_resnet_paths(resnets)
312
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
313
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
314
+
315
+ if len(attentions):
316
+ paths = renew_attention_paths(attentions)
317
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
318
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
319
+
320
+ resnet_0 = middle_blocks[0]
321
+ attentions = middle_blocks[1]
322
+ resnet_1 = middle_blocks[2]
323
+
324
+ resnet_0_paths = renew_resnet_paths(resnet_0)
325
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
326
+
327
+ resnet_1_paths = renew_resnet_paths(resnet_1)
328
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
329
+
330
+ attentions_paths = renew_attention_paths(attentions)
331
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
332
+ assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
333
+
334
+ for i in range(num_output_blocks):
335
+ block_id = i // (config["layers_per_block"] + 1)
336
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
337
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
338
+ output_block_list = {}
339
+
340
+ for layer in output_block_layers:
341
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
342
+ if layer_id in output_block_list:
343
+ output_block_list[layer_id].append(layer_name)
344
+ else:
345
+ output_block_list[layer_id] = [layer_name]
346
+
347
+ if len(output_block_list) > 1:
348
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
349
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
350
+
351
+ resnet_0_paths = renew_resnet_paths(resnets)
352
+ paths = renew_resnet_paths(resnets)
353
+
354
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
355
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
356
+
357
+ # オリジナル:
358
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
359
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
360
+
361
+ # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
362
+ for l in output_block_list.values():
363
+ l.sort()
364
+
365
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
366
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
367
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
368
+ f"output_blocks.{i}.{index}.conv.bias"
369
+ ]
370
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
371
+ f"output_blocks.{i}.{index}.conv.weight"
372
+ ]
373
+
374
+ # Clear attentions as they have been attributed above.
375
+ if len(attentions) == 2:
376
+ attentions = []
377
+
378
+ if len(attentions):
379
+ paths = renew_attention_paths(attentions)
380
+ meta_path = {
381
+ "old": f"output_blocks.{i}.1",
382
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
383
+ }
384
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
385
+ else:
386
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
387
+ for path in resnet_0_paths:
388
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
389
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
390
+
391
+ new_checkpoint[new_path] = unet_state_dict[old_path]
392
+
393
+ # SDのv2では1*1のconv2dがlinearに変わっている
394
+ # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
395
+ if v2 and not config.get("use_linear_projection", False):
396
+ linear_transformer_to_conv(new_checkpoint)
397
+
398
+ return new_checkpoint
399
+
400
+
401
+ def convert_ldm_vae_checkpoint(checkpoint, config):
402
+ # extract state dict for VAE
403
+ vae_state_dict = {}
404
+ vae_key = "first_stage_model."
405
+ keys = list(checkpoint.keys())
406
+ for key in keys:
407
+ if key.startswith(vae_key):
408
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
409
+ # if len(vae_state_dict) == 0:
410
+ # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
411
+ # vae_state_dict = checkpoint
412
+
413
+ new_checkpoint = {}
414
+
415
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
416
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
417
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
418
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
419
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
420
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
421
+
422
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
423
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
424
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
425
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
426
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
427
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
428
+
429
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
430
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
431
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
432
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
433
+
434
+ # Retrieves the keys for the encoder down blocks only
435
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
436
+ down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
437
+
438
+ # Retrieves the keys for the decoder up blocks only
439
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
440
+ up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
441
+
442
+ for i in range(num_down_blocks):
443
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
444
+
445
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
446
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
447
+ f"encoder.down.{i}.downsample.conv.weight"
448
+ )
449
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
450
+ f"encoder.down.{i}.downsample.conv.bias"
451
+ )
452
+
453
+ paths = renew_vae_resnet_paths(resnets)
454
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
455
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
456
+
457
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
458
+ num_mid_res_blocks = 2
459
+ for i in range(1, num_mid_res_blocks + 1):
460
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
461
+
462
+ paths = renew_vae_resnet_paths(resnets)
463
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
464
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
465
+
466
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
467
+ paths = renew_vae_attention_paths(mid_attentions)
468
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
469
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
470
+ conv_attn_to_linear(new_checkpoint)
471
+
472
+ for i in range(num_up_blocks):
473
+ block_id = num_up_blocks - 1 - i
474
+ resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
475
+
476
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
477
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
478
+ f"decoder.up.{block_id}.upsample.conv.weight"
479
+ ]
480
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
481
+ f"decoder.up.{block_id}.upsample.conv.bias"
482
+ ]
483
+
484
+ paths = renew_vae_resnet_paths(resnets)
485
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
486
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
487
+
488
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
489
+ num_mid_res_blocks = 2
490
+ for i in range(1, num_mid_res_blocks + 1):
491
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
492
+
493
+ paths = renew_vae_resnet_paths(resnets)
494
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
495
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
496
+
497
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
498
+ paths = renew_vae_attention_paths(mid_attentions)
499
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
500
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
501
+ conv_attn_to_linear(new_checkpoint)
502
+ return new_checkpoint
503
+
504
+
505
+ def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
506
+ """
507
+ Creates a config for the diffusers based on the config of the LDM model.
508
+ """
509
+ # unet_params = original_config.model.params.unet_config.params
510
+
511
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
512
+
513
+ down_block_types = []
514
+ resolution = 1
515
+ for i in range(len(block_out_channels)):
516
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
517
+ down_block_types.append(block_type)
518
+ if i != len(block_out_channels) - 1:
519
+ resolution *= 2
520
+
521
+ up_block_types = []
522
+ for i in range(len(block_out_channels)):
523
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
524
+ up_block_types.append(block_type)
525
+ resolution //= 2
526
+
527
+ config = dict(
528
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
529
+ in_channels=UNET_PARAMS_IN_CHANNELS,
530
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
531
+ down_block_types=tuple(down_block_types),
532
+ up_block_types=tuple(up_block_types),
533
+ block_out_channels=tuple(block_out_channels),
534
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
535
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
536
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
537
+ # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
538
+ )
539
+ if v2 and use_linear_projection_in_v2:
540
+ config["use_linear_projection"] = True
541
+
542
+ return config
543
+
544
+
545
+ def create_vae_diffusers_config():
546
+ """
547
+ Creates a config for the diffusers based on the config of the LDM model.
548
+ """
549
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
550
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
551
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
552
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
553
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
554
+
555
+ config = dict(
556
+ sample_size=VAE_PARAMS_RESOLUTION,
557
+ in_channels=VAE_PARAMS_IN_CHANNELS,
558
+ out_channels=VAE_PARAMS_OUT_CH,
559
+ down_block_types=tuple(down_block_types),
560
+ up_block_types=tuple(up_block_types),
561
+ block_out_channels=tuple(block_out_channels),
562
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
563
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
564
+ )
565
+ return config
566
+
567
+
568
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
569
+ keys = list(checkpoint.keys())
570
+ text_model_dict = {}
571
+ for key in keys:
572
+ if key.startswith("cond_stage_model.transformer"):
573
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
574
+
575
+ # remove position_ids for newer transformer, which causes error :(
576
+ if "text_model.embeddings.position_ids" in text_model_dict:
577
+ text_model_dict.pop("text_model.embeddings.position_ids")
578
+
579
+ return text_model_dict
580
+
581
+
582
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
+     # the two layouts are maddeningly different!
+     def convert_key(key):
+         if not key.startswith("cond_stage_model"):
+             return None
+
+         # common conversion
+         key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
+         key = key.replace("cond_stage_model.model.", "text_model.")
+
+         if "resblocks" in key:
+             # resblocks conversion
+             key = key.replace(".resblocks.", ".layers.")
+             if ".ln_" in key:
+                 key = key.replace(".ln_", ".layer_norm")
+             elif ".mlp." in key:
+                 key = key.replace(".c_fc.", ".fc1.")
+                 key = key.replace(".c_proj.", ".fc2.")
+             elif ".attn.out_proj" in key:
+                 key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
+             elif ".attn.in_proj" in key:
+                 key = None  # special case, handled separately below
+             else:
+                 raise ValueError(f"unexpected key in SD: {key}")
+         elif ".positional_embedding" in key:
+             key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
+         elif ".text_projection" in key:
+             key = None  # not used???
+         elif ".logit_scale" in key:
+             key = None  # not used???
+         elif ".token_embedding" in key:
+             key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
+         elif ".ln_final" in key:
+             key = key.replace(".ln_final", ".final_layer_norm")
+         return key
+
+     keys = list(checkpoint.keys())
+     new_sd = {}
+     for key in keys:
+         # remove resblocks 23 (the Diffusers text encoder uses only 23 hidden layers, see the config below)
+         if ".resblocks.23." in key:
+             continue
+         new_key = convert_key(key)
+         if new_key is None:
+             continue
+         new_sd[new_key] = checkpoint[key]
+
+     # convert the attention weights
+     for key in keys:
+         if ".resblocks.23." in key:
+             continue
+         if ".resblocks" in key and ".attn.in_proj_" in key:
+             # split the fused tensor into three
+             values = torch.chunk(checkpoint[key], 3)
+
+             key_suffix = ".weight" if "weight" in key else ".bias"
+             key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
+             key_pfx = key_pfx.replace("_weight", "")
+             key_pfx = key_pfx.replace("_bias", "")
+             key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
+             new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
+             new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
+             new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
+
+     # remove position_ids for newer transformers; leaving it in causes an error :(
+     ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
+     if ANOTHER_POSITION_IDS_KEY in new_sd:
+         # waifu diffusion v1.4
+         del new_sd[ANOTHER_POSITION_IDS_KEY]
+
+     if "text_model.embeddings.position_ids" in new_sd:
+         del new_sd["text_model.embeddings.position_ids"]
+
+     return new_sd
+
+
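+ # Illustrative sketch of the in_proj handling above: open-CLIP stores the
+ # query/key/value projections as one fused tensor, so a (3*dim, dim) weight
+ # chunks into three (dim, dim) pieces. The size is a made-up example.
+ def _example_split_fused_in_proj(dim=1024):
+     fused_weight = torch.zeros(3 * dim, dim)  # stands in for "...attn.in_proj_weight"
+     q_w, k_w, v_w = torch.chunk(fused_weight, 3)
+     return q_w.shape, k_w.shape, v_w.shape  # each (dim, dim)
+
+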
658
+ # endregion
+
+
+ # region Diffusers -> StableDiffusion conversion code
+ # copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
+
+
665
+ def conv_transformer_to_linear(checkpoint):
+     keys = list(checkpoint.keys())
+     tf_keys = ["proj_in.weight", "proj_out.weight"]
+     for key in keys:
+         if ".".join(key.split(".")[-2:]) in tf_keys:
+             if checkpoint[key].ndim > 2:
+                 checkpoint[key] = checkpoint[key][:, :, 0, 0]
+
+
674
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
+     unet_conversion_map = [
+         # (stable-diffusion, HF Diffusers)
+         ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+         ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+         ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+         ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+         ("input_blocks.0.0.weight", "conv_in.weight"),
+         ("input_blocks.0.0.bias", "conv_in.bias"),
+         ("out.0.weight", "conv_norm_out.weight"),
+         ("out.0.bias", "conv_norm_out.bias"),
+         ("out.2.weight", "conv_out.weight"),
+         ("out.2.bias", "conv_out.bias"),
+     ]
+
+     unet_conversion_map_resnet = [
+         # (stable-diffusion, HF Diffusers)
+         ("in_layers.0", "norm1"),
+         ("in_layers.2", "conv1"),
+         ("out_layers.0", "norm2"),
+         ("out_layers.3", "conv2"),
+         ("emb_layers.1", "time_emb_proj"),
+         ("skip_connection", "conv_shortcut"),
+     ]
+
+     unet_conversion_map_layer = []
+     for i in range(4):
+         # loop over downblocks/upblocks
+
+         for j in range(2):
+             # loop over resnets/attentions for downblocks
+             hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+             sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+             unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+             if i < 3:
+                 # no attention layers in down_blocks.3
+                 hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                 sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                 unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+         for j in range(3):
+             # loop over resnets/attentions for upblocks
+             hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+             sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+             unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+             if i > 0:
+                 # no attention layers in up_blocks.0
+                 hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+                 sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+                 unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+         if i < 3:
+             # no downsample in down_blocks.3
+             hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+             sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+             unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+             # no upsample in up_blocks.3
+             hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+             sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+             unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+     hf_mid_atn_prefix = "mid_block.attentions.0."
+     sd_mid_atn_prefix = "middle_block.1."
+     unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+     for j in range(2):
+         hf_mid_res_prefix = f"mid_block.resnets.{j}."
+         sd_mid_res_prefix = f"middle_block.{2*j}."
+         unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+     # buyer beware: this is a *brittle* function,
+     # and correct output requires that all of these pieces interact in
+     # the exact order in which I have arranged them.
+     mapping = {k: k for k in unet_state_dict.keys()}
+     for sd_name, hf_name in unet_conversion_map:
+         mapping[hf_name] = sd_name
+     for k, v in mapping.items():
+         if "resnets" in k:
+             for sd_part, hf_part in unet_conversion_map_resnet:
+                 v = v.replace(hf_part, sd_part)
+             mapping[k] = v
+     for k, v in mapping.items():
+         for sd_part, hf_part in unet_conversion_map_layer:
+             v = v.replace(hf_part, sd_part)
+         mapping[k] = v
+     new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+
+     if v2:
+         conv_transformer_to_linear(new_state_dict)
+
+     return new_state_dict
+
+
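+ # Illustrative sketch (assumes "unet" is a Diffusers UNet2DConditionModel):
+ # convert its weights back to the original stable-diffusion key layout, as
+ # save_stable_diffusion_checkpoint does further below.
+ def _example_unet_to_sd_keys(unet, v2=False):
+     return convert_unet_state_dict_to_sd(v2, unet.state_dict())
+
+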
770
+ def controlnet_conversion_map():
+     unet_conversion_map = [
+         ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+         ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+         ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+         ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+         ("input_blocks.0.0.weight", "conv_in.weight"),
+         ("input_blocks.0.0.bias", "conv_in.bias"),
+         ("middle_block_out.0.weight", "controlnet_mid_block.weight"),
+         ("middle_block_out.0.bias", "controlnet_mid_block.bias"),
+     ]
+
+     unet_conversion_map_resnet = [
+         ("in_layers.0", "norm1"),
+         ("in_layers.2", "conv1"),
+         ("out_layers.0", "norm2"),
+         ("out_layers.3", "conv2"),
+         ("emb_layers.1", "time_emb_proj"),
+         ("skip_connection", "conv_shortcut"),
+     ]
+
+     unet_conversion_map_layer = []
+     for i in range(4):
+         for j in range(2):
+             hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+             sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+             unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+             if i < 3:
+                 hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                 sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                 unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+         if i < 3:
+             hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+             sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+             unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+     hf_mid_atn_prefix = "mid_block.attentions.0."
+     sd_mid_atn_prefix = "middle_block.1."
+     unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+     for j in range(2):
+         hf_mid_res_prefix = f"mid_block.resnets.{j}."
+         sd_mid_res_prefix = f"middle_block.{2*j}."
+         unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+     controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
+     for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
+         hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
+         sd_prefix = f"input_hint_block.{i*2}."
+         unet_conversion_map_layer.append((sd_prefix, hf_prefix))
+
+     for i in range(12):
+         hf_prefix = f"controlnet_down_blocks.{i}."
+         sd_prefix = f"zero_convs.{i}.0."
+         unet_conversion_map_layer.append((sd_prefix, hf_prefix))
+
+     return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
+
+
831
+ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
+     unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
+
+     mapping = {k: k for k in controlnet_state_dict.keys()}
+     for sd_name, diffusers_name in unet_conversion_map:
+         mapping[diffusers_name] = sd_name
+     for k, v in mapping.items():
+         if "resnets" in k:
+             for sd_part, diffusers_part in unet_conversion_map_resnet:
+                 v = v.replace(diffusers_part, sd_part)
+             mapping[k] = v
+     for k, v in mapping.items():
+         for sd_part, diffusers_part in unet_conversion_map_layer:
+             v = v.replace(diffusers_part, sd_part)
+         mapping[k] = v
+     new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
+     return new_state_dict
+
+
850
+ def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
+     unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
+
+     mapping = {k: k for k in controlnet_state_dict.keys()}
+     for sd_name, diffusers_name in unet_conversion_map:
+         mapping[sd_name] = diffusers_name
+     for k, v in mapping.items():
+         for sd_part, diffusers_part in unet_conversion_map_layer:
+             v = v.replace(sd_part, diffusers_part)
+         mapping[k] = v
+     for k, v in mapping.items():
+         if "resnets" in v:
+             for sd_part, diffusers_part in unet_conversion_map_resnet:
+                 v = v.replace(sd_part, diffusers_part)
+             mapping[k] = v
+     new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
+     return new_state_dict
+
+
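+ # Illustrative sketch (assumes "sd_controlnet_sd" is a ControlNet state dict in
+ # the original stable-diffusion layout): the two converters above map between
+ # the layouts, so converting to Diffusers and back should reproduce the mapped keys.
+ def _example_controlnet_roundtrip(sd_controlnet_sd):
+     diffusers_sd = convert_controlnet_state_dict_to_diffusers(sd_controlnet_sd)
+     return convert_controlnet_state_dict_to_sd(diffusers_sd)
+
+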
869
+ # ================#
+ # VAE Conversion #
+ # ================#
+
+
+ def reshape_weight_for_sd(w):
+     # convert HF linear weights to SD conv2d weights
+     return w.reshape(*w.shape, 1, 1)
+
+
879
+ def convert_vae_state_dict(vae_state_dict):
+     vae_conversion_map = [
+         # (stable-diffusion, HF Diffusers)
+         ("nin_shortcut", "conv_shortcut"),
+         ("norm_out", "conv_norm_out"),
+         ("mid.attn_1.", "mid_block.attentions.0."),
+     ]
+
+     for i in range(4):
+         # down_blocks have two resnets
+         for j in range(2):
+             hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+             sd_down_prefix = f"encoder.down.{i}.block.{j}."
+             vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+         if i < 3:
+             hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+             sd_downsample_prefix = f"down.{i}.downsample."
+             vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+             hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+             sd_upsample_prefix = f"up.{3-i}.upsample."
+             vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+         # up_blocks have three resnets
+         # also, up blocks in hf are numbered in reverse from sd
+         for j in range(3):
+             hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+             sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+             vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+     # this part accounts for mid blocks in both the encoder and the decoder
+     for i in range(2):
+         hf_mid_res_prefix = f"mid_block.resnets.{i}."
+         sd_mid_res_prefix = f"mid.block_{i+1}."
+         vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+     if diffusers.__version__ < "0.17.0":
+         vae_conversion_map_attn = [
+             # (stable-diffusion, HF Diffusers)
+             ("norm.", "group_norm."),
+             ("q.", "query."),
+             ("k.", "key."),
+             ("v.", "value."),
+             ("proj_out.", "proj_attn."),
+         ]
+     else:
+         vae_conversion_map_attn = [
+             # (stable-diffusion, HF Diffusers)
+             ("norm.", "group_norm."),
+             ("q.", "to_q."),
+             ("k.", "to_k."),
+             ("v.", "to_v."),
+             ("proj_out.", "to_out.0."),
+         ]
+
+     mapping = {k: k for k in vae_state_dict.keys()}
+     for k, v in mapping.items():
+         for sd_part, hf_part in vae_conversion_map:
+             v = v.replace(hf_part, sd_part)
+         mapping[k] = v
+     for k, v in mapping.items():
+         if "attentions" in k:
+             for sd_part, hf_part in vae_conversion_map_attn:
+                 v = v.replace(hf_part, sd_part)
+             mapping[k] = v
+     new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+     weights_to_convert = ["q", "k", "v", "proj_out"]
+     for k, v in new_state_dict.items():
+         for weight_name in weights_to_convert:
+             if f"mid.attn_1.{weight_name}.weight" in k:
+                 # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
+                 new_state_dict[k] = reshape_weight_for_sd(v)
+
+     return new_state_dict
+
+
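+ # Illustrative sketch (assumes "vae" is a Diffusers AutoencoderKL): convert its
+ # weights to the original SD layout; the mid-block attention projections are
+ # reshaped to conv2d form by reshape_weight_for_sd above.
+ def _example_vae_to_sd_keys(vae):
+     return convert_vae_state_dict(vae.state_dict())
+
+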
956
+ # endregion
+
+ # region custom model loading/saving
+
+
961
+ def is_safetensors(path):
+     return os.path.splitext(path)[1].lower() == ".safetensors"
+
+
965
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
+     # support models that store the text encoder in a different layout (missing 'text_model')
+     TEXT_ENCODER_KEY_REPLACEMENTS = [
+         ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
+         ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
+         ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
+     ]
+
+     if is_safetensors(ckpt_path):
+         checkpoint = None
+         state_dict = load_file(ckpt_path)  # , device) # may cause an error
+     else:
+         checkpoint = torch.load(ckpt_path, map_location=device)
+         if "state_dict" in checkpoint:
+             state_dict = checkpoint["state_dict"]
+         else:
+             state_dict = checkpoint
+             checkpoint = None
+
+     key_reps = []
+     for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
+         for key in state_dict.keys():
+             if key.startswith(rep_from):
+                 new_key = rep_to + key[len(rep_from) :]
+                 key_reps.append((key, new_key))
+
+     for key, new_key in key_reps:
+         state_dict[new_key] = state_dict[key]
+         del state_dict[key]
+
+     return checkpoint, state_dict
+
+
998
+ # TODO: the dtype handling looks unreliable and should be verified; it is also unconfirmed whether the text_encoder can be created in the requested dtype
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
+     _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
+
+     # Convert the UNet2DConditionModel model.
+     unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
+     converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
+
+     unet = UNet2DConditionModel(**unet_config).to(device)
+     info = unet.load_state_dict(converted_unet_checkpoint)
+     logger.info(f"loading u-net: {info}")
+
+     # Convert the VAE model.
+     vae_config = create_vae_diffusers_config()
+     converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
+
+     vae = AutoencoderKL(**vae_config).to(device)
+     info = vae.load_state_dict(converted_vae_checkpoint)
+     logger.info(f"loading vae: {info}")
+
+     # convert text_model
+     if v2:
+         converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
+         cfg = CLIPTextConfig(
+             vocab_size=49408,
+             hidden_size=1024,
+             intermediate_size=4096,
+             num_hidden_layers=23,
+             num_attention_heads=16,
+             max_position_embeddings=77,
+             hidden_act="gelu",
+             layer_norm_eps=1e-05,
+             dropout=0.0,
+             attention_dropout=0.0,
+             initializer_range=0.02,
+             initializer_factor=1.0,
+             pad_token_id=1,
+             bos_token_id=0,
+             eos_token_id=2,
+             model_type="clip_text_model",
+             projection_dim=512,
+             torch_dtype="float32",
+             transformers_version="4.25.0.dev0",
+         )
+         text_model = CLIPTextModel._from_config(cfg)
+         info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+     else:
+         converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
+
+         # logging.set_verbosity_error()  # don't show annoying warning
+         # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
+         # logging.set_verbosity_warning()
+         # logger.info(f"config: {text_model.config}")
+         cfg = CLIPTextConfig(
+             vocab_size=49408,
+             hidden_size=768,
+             intermediate_size=3072,
+             num_hidden_layers=12,
+             num_attention_heads=12,
+             max_position_embeddings=77,
+             hidden_act="quick_gelu",
+             layer_norm_eps=1e-05,
+             dropout=0.0,
+             attention_dropout=0.0,
+             initializer_range=0.02,
+             initializer_factor=1.0,
+             pad_token_id=1,
+             bos_token_id=0,
+             eos_token_id=2,
+             model_type="clip_text_model",
+             projection_dim=768,
+             torch_dtype="float32",
+         )
+         text_model = CLIPTextModel._from_config(cfg)
+         info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+     logger.info(f"loading text encoder: {info}")
+
+     return text_model, vae, unet
+
+
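+ # Illustrative usage sketch; the checkpoint path is a made-up example. Loads a
+ # v1 checkpoint into Diffusers modules on the CPU.
+ def _example_load_v1_models(ckpt_path="model.safetensors"):
+     return load_models_from_stable_diffusion_checkpoint(False, ckpt_path, device="cpu")
+
+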
1078
+ def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
+     # only for reference
+     version_str = "sd"
+     if v2:
+         version_str += "_v2"
+     else:
+         version_str += "_v1"
+     if v_parameterization:
+         version_str += "_v"
+     return version_str
+
+
1090
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
+     def convert_key(key):
+         # remove position_ids
+         if ".position_ids" in key:
+             return None
+
+         # common
+         key = key.replace("text_model.encoder.", "transformer.")
+         key = key.replace("text_model.", "")
+         if "layers" in key:
+             # resblocks conversion
+             key = key.replace(".layers.", ".resblocks.")
+             if ".layer_norm" in key:
+                 key = key.replace(".layer_norm", ".ln_")
+             elif ".mlp." in key:
+                 key = key.replace(".fc1.", ".c_fc.")
+                 key = key.replace(".fc2.", ".c_proj.")
+             elif ".self_attn.out_proj" in key:
+                 key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
+             elif ".self_attn." in key:
+                 key = None  # special case, handled separately below
+             else:
+                 raise ValueError(f"unexpected key in Diffusers model: {key}")
+         elif ".position_embedding" in key:
+             key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
+         elif ".token_embedding" in key:
+             key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
+         elif "final_layer_norm" in key:
+             key = key.replace("final_layer_norm", "ln_final")
+         return key
+
+     keys = list(checkpoint.keys())
+     new_sd = {}
+     for key in keys:
+         new_key = convert_key(key)
+         if new_key is None:
+             continue
+         new_sd[new_key] = checkpoint[key]
+
+     # convert the attention weights
+     for key in keys:
+         if "layers" in key and "q_proj" in key:
+             # concatenate the three projections back into one fused tensor
+             key_q = key
+             key_k = key.replace("q_proj", "k_proj")
+             key_v = key.replace("q_proj", "v_proj")
+
+             value_q = checkpoint[key_q]
+             value_k = checkpoint[key_k]
+             value_v = checkpoint[key_v]
+             value = torch.cat([value_q, value_k, value_v])
+
+             new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
+             new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
+             new_sd[new_key] = value
+
+     # optionally fabricate the final layer and other missing weights
+     if make_dummy_weights:
+         logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
+         keys = list(new_sd.keys())
+         for key in keys:
+             if key.startswith("transformer.resblocks.22."):
+                 new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone()  # must be cloned, otherwise saving with safetensors fails
+
+         # create weights that are not present in the Diffusers model
+         new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
+         new_sd["logit_scale"] = torch.tensor(1)
+
+     return new_sd
+
+
1161
+ def save_stable_diffusion_checkpoint(
+     v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
+ ):
+     if ckpt_path is not None:
+         # reuse epoch/step from the existing checkpoint; also reload it (including the VAE), e.g. when the VAE is not in memory
+         checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
+         if checkpoint is None:  # safetensors file, or a ckpt that is a bare state_dict
+             checkpoint = {}
+             strict = False
+         else:
+             strict = True
+         if "state_dict" in state_dict:
+             del state_dict["state_dict"]
+     else:
+         # create a new checkpoint from scratch
+         assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
+         checkpoint = {}
+         state_dict = {}
+         strict = False
+
+     def update_sd(prefix, sd):
+         for k, v in sd.items():
+             key = prefix + k
+             assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
+             if save_dtype is not None:
+                 v = v.detach().clone().to("cpu").to(save_dtype)
+             state_dict[key] = v
+
+     # Convert the UNet model
+     unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
+     update_sd("model.diffusion_model.", unet_state_dict)
+
+     # Convert the text encoder model
+     if v2:
+         make_dummy = ckpt_path is None  # with no source checkpoint, insert dummy weights, e.g. duplicating the last layer from the previous one
+         text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
+         update_sd("cond_stage_model.model.", text_enc_dict)
+     else:
+         text_enc_dict = text_encoder.state_dict()
+         update_sd("cond_stage_model.transformer.", text_enc_dict)
+
+     # Convert the VAE
+     if vae is not None:
+         vae_dict = convert_vae_state_dict(vae.state_dict())
+         update_sd("first_stage_model.", vae_dict)
+
+     # Put together new checkpoint
+     key_count = len(state_dict.keys())
+     new_ckpt = {"state_dict": state_dict}
+
+     # epoch and global_step are sometimes not int
+     try:
+         if "epoch" in checkpoint:
+             epochs += checkpoint["epoch"]
+         if "global_step" in checkpoint:
+             steps += checkpoint["global_step"]
+     except Exception:
+         pass
+
+     new_ckpt["epoch"] = epochs
+     new_ckpt["global_step"] = steps
+
+     if is_safetensors(output_file):
+         # TODO should non-tensor values be removed from the dict?
+         save_file(state_dict, output_file, metadata)
+     else:
+         torch.save(new_ckpt, output_file)
+
+     return key_count
+
+
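+ # Illustrative usage sketch; the paths and counters are made-up examples. Saves
+ # trained modules back into a single checkpoint, carrying epoch/step over from
+ # the source checkpoint when one is given.
+ def _example_save_v1_checkpoint(text_encoder, unet, vae, src_ckpt="model.safetensors"):
+     return save_stable_diffusion_checkpoint(
+         False, "out.safetensors", text_encoder, unet, src_ckpt, epochs=1, steps=1000, metadata=None, vae=vae
+     )
+
+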
1232
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
+     if pretrained_model_name_or_path is None:
+         # load default settings for v1/v2
+         if v2:
+             pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
+         else:
+             pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
+
+     scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+     tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
+     if vae is None:
+         vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+
+     # original U-Net cannot be saved, so we need to convert it to the Diffusers version
+     # TODO this consumes a lot of memory
+     diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
+     diffusers_unet.load_state_dict(unet.state_dict())
+
+     pipeline = StableDiffusionPipeline(
+         unet=diffusers_unet,
+         text_encoder=text_encoder,
+         vae=vae,
+         scheduler=scheduler,
+         tokenizer=tokenizer,
+         safety_checker=None,
+         feature_extractor=None,
+         requires_safety_checker=None,
+     )
+     pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
+
+
+ VAE_PREFIX = "first_stage_model."
+
+
1266
+ def load_vae(vae_id, dtype):
+     logger.info(f"load VAE: {vae_id}")
+     if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
+         # Diffusers local/remote
+         try:
+             vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
+         except EnvironmentError as e:
+             logger.error(f"exception occurred while loading the VAE: {e}")
+             logger.error("retry with subfolder='vae'")
+             vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
+         return vae
+
+     # local
+     vae_config = create_vae_diffusers_config()
+
+     if vae_id.endswith(".bin"):
+         # SD 1.5 VAE on Huggingface
+         converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
+     else:
+         # StableDiffusion
+         vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
+         vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
+
+         # vae only or full model
+         full_model = False
+         for vae_key in vae_sd:
+             if vae_key.startswith(VAE_PREFIX):
+                 full_model = True
+                 break
+         if not full_model:
+             sd = {}
+             for key, value in vae_sd.items():
+                 sd[VAE_PREFIX + key] = value
+             vae_sd = sd
+             del sd
+
+         # Convert the VAE model.
+         converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
+
+     vae = AutoencoderKL(**vae_config)
+     vae.load_state_dict(converted_vae_checkpoint)
+     return vae
+
+
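+ # Illustrative usage sketch; the Hub ID is only an example. load_vae accepts a
+ # Diffusers folder/Hub ID as well as a single-file VAE or a full checkpoint.
+ def _example_load_vae():
+     return load_vae("stabilityai/sd-vae-ft-mse", dtype=torch.float32)
+
+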
1310
+ # endregion
+
+
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
+     max_width, max_height = max_reso
+     max_area = max_width * max_height
+
+     resos = set()
+
+     width = int(math.sqrt(max_area) // divisible) * divisible
+     resos.add((width, width))
+
+     width = min_size
+     while width <= max_size:
+         height = min(max_size, int((max_area // width) // divisible) * divisible)
+         if height >= min_size:
+             resos.add((width, height))
+             resos.add((height, width))
+
+         # # make additional resos
+         # if width >= height and width - divisible >= min_size:
+         #     resos.add((width - divisible, height))
+         #     resos.add((height, width - divisible))
+         # if height >= width and height - divisible >= min_size:
+         #     resos.add((width, height - divisible))
+         #     resos.add((height - divisible, width))
+
+         width += divisible
+
+     resos = list(resos)
+     resos.sort()
+     return resos
+
+
1344
+ if __name__ == "__main__":
+     resos = make_bucket_resolutions((512, 768))
+     logger.info(f"{len(resos)}")
+     logger.info(f"{resos}")
+     aspect_ratios = [w / h for w, h in resos]
+     logger.info(f"{aspect_ratios}")
+
+     ars = set()
+     for ar in aspect_ratios:
+         if ar in ars:
+             logger.error(f"error! duplicate ar: {ar}")
+         ars.add(ar)
sd-scripts/train_util.py ADDED
The diff for this file is too large to render. See raw diff
 
sd-scripts/utils.py ADDED
@@ -0,0 +1,287 @@
+ import logging
+ import sys
+ import threading
+ import torch
+ from torchvision import transforms
+ from typing import *
+ from diffusers import EulerAncestralDiscreteScheduler
+ import diffusers.schedulers.scheduling_euler_ancestral_discrete
+ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
+ import cv2
+ from PIL import Image
+ import numpy as np
+
+
+ def fire_in_thread(f, *args, **kwargs):
+     threading.Thread(target=f, args=args, kwargs=kwargs).start()
+
+
+ def add_logging_arguments(parser):
+     parser.add_argument(
+         "--console_log_level",
+         type=str,
+         default=None,
+         choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+         help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
+     )
+     parser.add_argument(
+         "--console_log_file",
+         type=str,
+         default=None,
+         help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
+     )
+     parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")
+
+
+ def setup_logging(args=None, log_level=None, reset=False):
+     if logging.root.handlers:
+         if reset:
+             # remove all handlers
+             for handler in logging.root.handlers[:]:
+                 logging.root.removeHandler(handler)
+         else:
+             return
+
+     # log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
+     if log_level is None and args is not None:
+         log_level = args.console_log_level
+     if log_level is None:
+         log_level = "INFO"
+     log_level = getattr(logging, log_level)
+
+     msg_init = None
+     if args is not None and args.console_log_file:
+         handler = logging.FileHandler(args.console_log_file, mode="w")
+     else:
+         handler = None
+         if not args or not args.console_log_simple:
+             try:
+                 from rich.logging import RichHandler
+                 from rich.console import Console
+
+                 handler = RichHandler(console=Console(stderr=True))
+             except ImportError:
+                 # print("rich is not installed, using basic logging")
+                 msg_init = "rich is not installed, using basic logging"
+
+     if handler is None:
+         handler = logging.StreamHandler(sys.stdout)  # same as print
+         handler.propagate = False
+
+     formatter = logging.Formatter(
+         fmt="%(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+     handler.setFormatter(formatter)
+     logging.root.setLevel(log_level)
+     logging.root.addHandler(handler)
+
+     if msg_init is not None:
+         logger = logging.getLogger(__name__)
+         logger.info(msg_init)
+
+
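+ # Illustrative usage sketch: wire the logging options into an argparse parser
+ # and initialize logging from the parsed arguments.
+ def _example_setup_logging_from_cli(argv=None):
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     add_logging_arguments(parser)
+     args = parser.parse_args(argv)
+     setup_logging(args, reset=True)
+     return logging.getLogger(__name__)
+
+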
+ def pil_resize(image, size, interpolation=Image.LANCZOS):
+     has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
+
+     if has_alpha:
+         pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
+     else:
+         pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+
+     resized_pil = pil_image.resize(size, interpolation)
+
+     # Convert back to cv2 format
+     if has_alpha:
+         resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
+     else:
+         resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
+
+     return resized_cv2
+
+
+ # TODO make inf_utils.py
+
+
+ # region Gradual Latent hires fix
+
+
+ class GradualLatent:
+     def __init__(
+         self,
+         ratio,
+         start_timesteps,
+         every_n_steps,
+         ratio_step,
+         s_noise=1.0,
+         gaussian_blur_ksize=None,
+         gaussian_blur_sigma=0.5,
+         gaussian_blur_strength=0.5,
+         unsharp_target_x=True,
+     ):
+         self.ratio = ratio
+         self.start_timesteps = start_timesteps
+         self.every_n_steps = every_n_steps
+         self.ratio_step = ratio_step
+         self.s_noise = s_noise
+         self.gaussian_blur_ksize = gaussian_blur_ksize
+         self.gaussian_blur_sigma = gaussian_blur_sigma
+         self.gaussian_blur_strength = gaussian_blur_strength
+         self.unsharp_target_x = unsharp_target_x
+
+     def __str__(self) -> str:
+         return (
+             f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
+             + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
+             + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
+             + f"unsharp_target_x={self.unsharp_target_x})"
+         )
+
+     def apply_unsharp_mask(self, x: torch.Tensor):
+         if self.gaussian_blur_ksize is None:
+             return x
+         blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
+         # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
+         mask = (x - blurred) * self.gaussian_blur_strength
+         sharpened = x + mask
+         return sharpened
+
+     def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
+         org_dtype = x.dtype
+         if org_dtype == torch.bfloat16:
+             x = x.float()
+
+         x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
+
+         # apply unsharp mask / アンシャープマスクを適用する
+         if unsharp and self.gaussian_blur_ksize:
+             x = self.apply_unsharp_mask(x)
+
+         return x
+
+
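+ # Illustrative sketch (sizes and parameters are made-up examples): interpolate()
+ # upscales a latent batch to a new spatial size and, because gaussian_blur_ksize
+ # is set here, re-sharpens the result with the unsharp mask above.
+ def _example_gradual_latent_upscale():
+     gl = GradualLatent(ratio=0.5, start_timesteps=650, every_n_steps=5, ratio_step=0.125, gaussian_blur_ksize=3)
+     latents = torch.zeros(1, 4, 48, 48)
+     return gl.interpolate(latents, (64, 64)).shape  # torch.Size([1, 4, 64, 64])
+
+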
+ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.resized_size = None
+         self.gradual_latent = None
+
+     def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
+         self.resized_size = size
+         self.gradual_latent = gradual_latent
+
+     def step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: Union[float, torch.FloatTensor],
+         sample: torch.FloatTensor,
+         generator: Optional[torch.Generator] = None,
+         return_dict: bool = True,
+     ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
+         """
+         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+         process from the learned model outputs (most often the predicted noise).
+
+         Args:
+             model_output (`torch.FloatTensor`):
+                 The direct output from learned diffusion model.
+             timestep (`float`):
+                 The current discrete timestep in the diffusion chain.
+             sample (`torch.FloatTensor`):
+                 A current instance of a sample created by the diffusion process.
+             generator (`torch.Generator`, *optional*):
+                 A random number generator.
+             return_dict (`bool`):
+                 Whether or not to return a
+                 [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
+
+         Returns:
+             [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
+                 If return_dict is `True`,
+                 [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
+                 otherwise a tuple is returned where the first element is the sample tensor.
+
+         """
+
+         if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
+             raise ValueError(
+                 (
+                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                     " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+                     " one of the `scheduler.timesteps` as a timestep."
+                 ),
+             )
+
+         if not self.is_scale_input_called:
+             # logger.warning(
+             print(
+                 "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+                 "See `StableDiffusionPipeline` for a usage example."
+             )
+
+         if self.step_index is None:
+             self._init_step_index(timestep)
+
+         sigma = self.sigmas[self.step_index]
+
+         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+         if self.config.prediction_type == "epsilon":
+             pred_original_sample = sample - sigma * model_output
+         elif self.config.prediction_type == "v_prediction":
+             # * c_out + input * c_skip
+             pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+         elif self.config.prediction_type == "sample":
+             raise NotImplementedError("prediction_type not implemented yet: sample")
+         else:
+             raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
+
+         sigma_from = self.sigmas[self.step_index]
+         sigma_to = self.sigmas[self.step_index + 1]
+         sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
+         sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+
+         # 2. Convert to an ODE derivative
+         derivative = (sample - pred_original_sample) / sigma
+
+         dt = sigma_down - sigma
+
+         device = model_output.device
+         if self.resized_size is None:
+             prev_sample = sample + derivative * dt
+
+             noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
+                 model_output.shape, dtype=model_output.dtype, device=device, generator=generator
+             )
+             s_noise = 1.0
+         else:
+             print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
+             s_noise = self.gradual_latent.s_noise
+
+             if self.gradual_latent.unsharp_target_x:
+                 prev_sample = sample + derivative * dt
+                 prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
+             else:
+                 sample = self.gradual_latent.interpolate(sample, self.resized_size)
+                 derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
+                 prev_sample = sample + derivative * dt
+
+             noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
+                 (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
+                 dtype=model_output.dtype,
+                 device=device,
+                 generator=generator,
+             )
+
+         prev_sample = prev_sample + noise * sigma_up * s_noise
+
+         # upon completion increase step index by one
+         self._step_index += 1
+
+         if not return_dict:
+             return (prev_sample,)
+
+         return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+
+ # endregion
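+
+
+ # Illustrative usage sketch (assumes "pipe" is a Diffusers StableDiffusionPipeline;
+ # the resized latent size and the GradualLatent numbers are made-up examples):
+ # swap in the gradual-latent scheduler and register its parameters before sampling.
+ def _example_enable_gradual_latent(pipe, resized_size=(96, 96)):
+     pipe.scheduler = EulerAncestralDiscreteSchedulerGL.from_config(pipe.scheduler.config)
+     gradual_latent = GradualLatent(ratio=0.5, start_timesteps=650, every_n_steps=5, ratio_step=0.125)
+     pipe.scheduler.set_gradual_latent_params(resized_size, gradual_latent)
+     return pipe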