KamCastle committed
Commit 9041274 · verified · 1 Parent(s): 0356f3f

Upload 4 files

Files changed (4)
  1. sdxl_train.py +819 -0
  2. sdxl_train_util.py +389 -0
  3. train_db.py +539 -0
  4. train_util.py +0 -0
sdxl_train.py ADDED
@@ -0,0 +1,819 @@
+# training with captions
+
+import argparse
+import gc
+import math
+import os
+from multiprocessing import Value
+from typing import List
+import toml
+
+from tqdm import tqdm
+import torch
+
+from library.ipex_interop import init_ipex
+
+init_ipex()
+
+from accelerate.utils import set_seed
+from diffusers import DDPMScheduler
+from library import sdxl_model_util
+
+import library.train_util as train_util
+import library.config_util as config_util
+import library.sdxl_train_util as sdxl_train_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+import library.custom_train_functions as custom_train_functions
+from library.custom_train_functions import (
+    apply_snr_weight,
+    prepare_scheduler_for_custom_training,
+    scale_v_prediction_loss_like_noise_prediction,
+    add_v_prediction_like_loss,
+    apply_debiased_estimation,
+)
+from library.sdxl_original_unet import SdxlUNet2DConditionModel
+from library.train_util import EMAModel
+
+
+UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
+
+
+def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]:
+    block_params = [[] for _ in range(len(block_lrs))]
+
+    for i, (name, param) in enumerate(unet.named_parameters()):
+        if name.startswith("time_embed.") or name.startswith("label_emb."):
+            block_index = 0  # 0
+        elif name.startswith("input_blocks."):  # 1-9
+            block_index = 1 + int(name.split(".")[1])
+        elif name.startswith("middle_block."):  # 10-12
+            block_index = 10 + int(name.split(".")[1])
+        elif name.startswith("output_blocks."):  # 13-21
+            block_index = 13 + int(name.split(".")[1])
+        elif name.startswith("out."):  # 22
+            block_index = 22
+        else:
+            raise ValueError(f"unexpected parameter name: {name}")
+
+        block_params[block_index].append(param)
+
+    params_to_optimize = []
+    for i, params in enumerate(block_params):
+        if block_lrs[i] == 0:  # do not optimize when lr is 0
+            continue
+        params_to_optimize.append({"params": params, "lr": block_lrs[i]})
+
+    return params_to_optimize
+
+
+def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
+    names = []
+    block_index = 0
+    while block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR + 2:
+        if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
+            if block_lrs[block_index] == 0:
+                block_index += 1
+                continue
+            names.append(f"block{block_index}")
+        elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
+            names.append("text_encoder1")
+        elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
+            names.append("text_encoder2")
+
+        block_index += 1
+
+    train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
+
+
+def train(args):
+    train_util.verify_training_args(args)
+    train_util.prepare_dataset_args(args, True)
+    sdxl_train_util.verify_sdxl_training_args(args)
+
+    assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
+    assert (
+        not args.train_text_encoder or not args.cache_text_encoder_outputs
+    ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
+
+    if args.block_lr:
+        block_lrs = [float(lr) for lr in args.block_lr.split(",")]
+        assert (
+            len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR
+        ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください"
+    else:
+        block_lrs = None
+
+    cache_latents = args.cache_latents
+    use_dreambooth_method = args.in_json is None
+
+    if args.seed is not None:
+        set_seed(args.seed)  # initialize the random seed
+
+    tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
+
+    # prepare the dataset
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
+        if args.dataset_config is not None:
+            print(f"Load dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            if use_dreambooth_method:
+                print("Using DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
+                                args.train_data_dir, args.reg_data_dir
+                            )
+                        }
+                    ]
+                }
+            else:
+                print("Training with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2])
+
+    current_epoch = Value("i", 0)
+    current_step = Value("i", 0)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
+
+    train_dataset_group.verify_bucket_reso_steps(32)
+
+    if args.debug_dataset:
+        train_util.debug_dataset(train_dataset_group, True)
+        return
+    if len(train_dataset_group) == 0:
+        print(
+            "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
+        )
+        return
+
+    if cache_latents:
+        assert (
+            train_dataset_group.is_latent_cacheable()
+        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
+    if args.cache_text_encoder_outputs:
+        assert (
+            train_dataset_group.is_text_encoder_output_cacheable()
+        ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
+
+    # prepare the accelerator
+    print("prepare accelerator")
+    accelerator = train_util.prepare_accelerator(args)
+
+    # prepare dtypes matching the mixed precision setting and cast as appropriate
+    weight_dtype, save_dtype = train_util.prepare_dtype(args)
+    vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
+
+    # load the model
+    (
+        load_stable_diffusion_format,
+        text_encoder1,
+        text_encoder2,
+        vae,
+        unet,
+        logit_scale,
+        ckpt_info,
+    ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
+    # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
+
+    # verify load/save model formats
+    if load_stable_diffusion_format:
+        src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
+        src_diffusers_model_path = None
+    else:
+        src_stable_diffusion_ckpt = None
+        src_diffusers_model_path = args.pretrained_model_name_or_path
+
+    if args.save_model_as is None:
+        save_stable_diffusion_format = load_stable_diffusion_format
+        use_safetensors = args.use_safetensors
+    else:
+        save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
+        use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
+        # assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります"
+
+    # function to set the Diffusers xformers flag
+    def set_diffusers_xformers_flag(model, valid):
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        fn_recursive_set_mem_eff(model)
+
+    # enable xformers / memory efficient attention in the models
+    if args.diffusers_xformers:
+        # the U-Net has been replaced with our own implementation, so this no longer works for it; xformers for the VAE should still work
+        accelerator.print("Use xformers by Diffusers")
+        # set_diffusers_xformers_flag(unet, True)
+        set_diffusers_xformers_flag(vae, True)
+    else:
+        # the Windows build of xformers sometimes cannot train in float, so disabling xformers must remain possible
+        accelerator.print("Disable Diffusers' xformers")
+        train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
+        if torch.__version__ >= "2.0.0":  # usable with xformers builds for PyTorch 2.0.0 and later
+            vae.set_use_memory_efficient_attention_xformers(args.xformers)
+
+    # prepare training
+    if cache_latents:
+        vae.to(accelerator.device, dtype=vae_dtype)
+        vae.requires_grad_(False)
+        vae.eval()
+        with torch.no_grad():
+            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
+        vae.to("cpu")
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+        accelerator.wait_for_everyone()
+
+    # prepare training: put the models into the proper state
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+    train_unet = args.learning_rate > 0
+    train_text_encoder1 = False
+    train_text_encoder2 = False
+
+    if args.train_text_encoder:
+        # TODO each option for two text encoders?
+        accelerator.print("enable text encoder training")
+        if args.gradient_checkpointing:
+            text_encoder1.gradient_checkpointing_enable()
+            text_encoder2.gradient_checkpointing_enable()
+        lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means do not train
+        lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate  # 0 means do not train
+        train_text_encoder1 = lr_te1 > 0
+        train_text_encoder2 = lr_te2 > 0
+
+        # caching one text encoder output is not supported
+        if not train_text_encoder1:
+            text_encoder1.to(weight_dtype)
+        if not train_text_encoder2:
+            text_encoder2.to(weight_dtype)
+        text_encoder1.requires_grad_(train_text_encoder1)
+        text_encoder2.requires_grad_(train_text_encoder2)
+        text_encoder1.train(train_text_encoder1)
+        text_encoder2.train(train_text_encoder2)
+    else:
+        text_encoder1.to(weight_dtype)
+        text_encoder2.to(weight_dtype)
+        text_encoder1.requires_grad_(False)
+        text_encoder2.requires_grad_(False)
+        text_encoder1.eval()
+        text_encoder2.eval()
+
+        # cache the Text Encoder outputs
+        if args.cache_text_encoder_outputs:
+            # Text Encodes are eval and no grad
+            with torch.no_grad(), accelerator.autocast():
+                train_dataset_group.cache_text_encoder_outputs(
+                    (tokenizer1, tokenizer2),
+                    (text_encoder1, text_encoder2),
+                    accelerator.device,
+                    None,
+                    args.cache_text_encoder_outputs_to_disk,
+                    accelerator.is_main_process,
+                )
+            accelerator.wait_for_everyone()
+
+    if not cache_latents:
+        vae.requires_grad_(False)
+        vae.eval()
+        vae.to(accelerator.device, dtype=vae_dtype)
+
+    unet.requires_grad_(train_unet)
+    if not train_unet:
+        unet.to(accelerator.device, dtype=weight_dtype)  # because the unet is not prepared by accelerator
+
+    training_models = []
+    params_to_optimize = []
+    if train_unet:
+        training_models.append(unet)
+        if block_lrs is None:
+            params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
+        else:
+            params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))
+
+    if train_text_encoder1:
+        training_models.append(text_encoder1)
+        params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
+    if train_text_encoder2:
+        training_models.append(text_encoder2)
+        params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
+
+    # calculate number of trainable parameters
+    n_params = 0
+    for params in params_to_optimize:
+        for p in params["params"]:
+            n_params += p.numel()
+
+    accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
+    accelerator.print(f"number of models: {len(training_models)}")
+    accelerator.print(f"number of trainable parameters: {n_params}")
+
+    # prepare the classes needed for training
+    accelerator.print("prepare optimizer, data loader etc.")
+    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+
+    # prepare the dataloader
+    # number of DataLoader workers: 0 means the main process
+    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified value
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset_group,
+        batch_size=1,
+        shuffle=True,
+        collate_fn=collator,
+        num_workers=n_workers,
+        persistent_workers=args.persistent_data_loader_workers,
+    )
+
+    # calculate the number of training steps
+    if args.max_train_epochs is not None:
+        args.max_train_steps = args.max_train_epochs * math.ceil(
+            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+        )
+        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+
+    # send the number of training steps to the dataset side as well
+    train_dataset_group.set_max_train_steps(args.max_train_steps)
+
+    # prepare the lr scheduler
+    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+
+    # experimental feature: train in fp16/bf16 including gradients; cast the whole model to fp16/bf16
+    if args.full_fp16:
+        assert (
+            args.mixed_precision == "fp16"
+        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+        accelerator.print("enable full fp16 training.")
+        unet.to(weight_dtype)
+        text_encoder1.to(weight_dtype)
+        text_encoder2.to(weight_dtype)
+    elif args.full_bf16:
+        assert (
+            args.mixed_precision == "bf16"
+        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+        accelerator.print("enable full bf16 training.")
+        unet.to(weight_dtype)
+        text_encoder1.to(weight_dtype)
+        text_encoder2.to(weight_dtype)
+
+    if args.enable_ema:
+        # ema_dtype = weight_dtype if (args.full_bf16 or args.full_fp16) else torch.float
+        ema = EMAModel(params_to_optimize, decay=args.ema_decay, beta=args.ema_exp_beta, max_train_steps=args.max_train_steps)
+        ema.to(accelerator.device, dtype=weight_dtype)
+        ema = accelerator.prepare(ema)
+    else:
+        ema = None
+    # accelerator handles the model wrapping for us
+    if train_unet:
+        unet = accelerator.prepare(unet)
+    if train_text_encoder1:
+        # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
+        text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
+        text_encoder1.text_model.final_layer_norm.requires_grad_(False)
+        text_encoder1 = accelerator.prepare(text_encoder1)
+    if train_text_encoder2:
+        text_encoder2 = accelerator.prepare(text_encoder2)
+
+    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+
+    # when caching Text Encoder outputs, move the encoders to CPU
+    if args.cache_text_encoder_outputs:
+        # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
+        text_encoder1.to("cpu", dtype=torch.float32)
+        text_encoder2.to("cpu", dtype=torch.float32)
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+    else:
+        # make sure Text Encoders are on GPU
+        text_encoder1.to(accelerator.device)
+        text_encoder2.to(accelerator.device)
+
+    # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
+    if args.full_fp16:
+        train_util.patch_accelerator_for_fp16_training(accelerator)
+
+    # resume training if specified
+    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
+
+    # calculate the number of epochs
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
+        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
+
+    # start training
+    # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+    accelerator.print("running training / 学習開始")
+    accelerator.print(f"  num examples / サンプル数: {train_dataset_group.num_train_images}")
+    accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
+    accelerator.print(f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
+    # accelerator.print(
+    #     f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
+    # )
+    accelerator.print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    accelerator.print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
+
+    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+    global_step = 0
+
+    noise_scheduler = DDPMScheduler(
+        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
+    )
+    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
+    if args.zero_terminal_snr:
+        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
+
+    if accelerator.is_main_process:
+        init_kwargs = {}
+        if args.wandb_run_name:
+            init_kwargs["wandb"] = {"name": args.wandb_run_name}
+        if args.log_tracker_config is not None:
+            init_kwargs = toml.load(args.log_tracker_config)
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+
+    # For --sample_at_first
+    sdxl_train_util.sample_images(
+        accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
+    )
+
+    loss_recorder = train_util.LossRecorder()
+    for epoch in range(num_train_epochs):
+        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
+        current_epoch.value = epoch + 1
+
+        for m in training_models:
+            m.train()
+
+        for step, batch in enumerate(train_dataloader):
+            current_step.value = global_step
+            with accelerator.accumulate(*training_models):
+                if "latents" in batch and batch["latents"] is not None:
+                    latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
+                else:
+                    with torch.no_grad():
+                        # encode images to latents
+                        latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
+
+                        # if NaNs are found, warn and replace them with zeros
+                        if torch.any(torch.isnan(latents)):
+                            accelerator.print("NaN found in latents, replacing with zeros")
+                            latents = torch.nan_to_num(latents, 0, out=latents)
+                latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
+
+                if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
+                    input_ids1 = batch["input_ids"]
+                    input_ids2 = batch["input_ids2"]
+                    with torch.set_grad_enabled(args.train_text_encoder):
+                        # Get the text embedding for conditioning
+                        # TODO support weighted captions
+                        # if args.weighted_captions:
+                        #     encoder_hidden_states = get_weighted_text_embeddings(
+                        #         tokenizer,
+                        #         text_encoder,
+                        #         batch["captions"],
+                        #         accelerator.device,
+                        #         args.max_token_length // 75 if args.max_token_length else 1,
+                        #         clip_skip=args.clip_skip,
+                        #     )
+                        # else:
+                        input_ids1 = input_ids1.to(accelerator.device)
+                        input_ids2 = input_ids2.to(accelerator.device)
+                        # unwrap_model is fine for models not wrapped by accelerator
+                        encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
+                            args.max_token_length,
+                            input_ids1,
+                            input_ids2,
+                            tokenizer1,
+                            tokenizer2,
+                            text_encoder1,
+                            text_encoder2,
+                            None if not args.full_fp16 else weight_dtype,
+                            accelerator=accelerator,
+                        )
+                else:
+                    encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
+                    encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
+                    pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
+
+                    # # verify that the text encoder outputs are correct
+                    # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
+                    #     args.max_token_length,
+                    #     batch["input_ids"].to(text_encoder1.device),
+                    #     batch["input_ids2"].to(text_encoder1.device),
+                    #     tokenizer1,
+                    #     tokenizer2,
+                    #     text_encoder1,
+                    #     text_encoder2,
+                    #     None if not args.full_fp16 else weight_dtype,
+                    # )
+                    # b_size = encoder_hidden_states1.shape[0]
+                    # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+                    # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+                    # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
+                    # print("text encoder outputs verified")
+
+                # get size embeddings
+                orig_size = batch["original_sizes_hw"]
+                crop_size = batch["crop_top_lefts"]
+                target_size = batch["target_sizes_hw"]
+                embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
+
+                # concat embeddings
+                vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
+                text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
+
+                # Sample noise, sample a random timestep for each image, and add noise to the latents,
+                # with noise offset and/or multires noise if specified
+                noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+
+                noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
+
+                # Predict the noise residual
+                with accelerator.autocast():
+                    noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
+
+                target = noise
+
+                if (
+                    args.min_snr_gamma
+                    or args.scale_v_pred_loss_like_noise_pred
+                    or args.v_pred_like_loss
+                    or args.debiased_estimation_loss
+                ):
+                    # do not mean over batch dimension for snr weight or scale v-pred loss
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                    loss = loss.mean([1, 2, 3])
+
+                    if args.min_snr_gamma:
+                        loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    if args.scale_v_pred_loss_like_noise_pred:
+                        loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
+                    if args.v_pred_like_loss:
+                        loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
+                    if args.debiased_estimation_loss:
+                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+
+                    loss = loss.mean()  # mean over batch dimension
+                else:
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+                accelerator.backward(loss)
+                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                    params_to_clip = []
+                    for m in training_models:
+                        params_to_clip.extend(m.parameters())
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad(set_to_none=True)
+                if args.enable_ema:
+                    with torch.no_grad(), accelerator.autocast():
+                        ema.step(params_to_optimize)
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+                sdxl_train_util.sample_images(
+                    accelerator,
+                    args,
+                    None,
+                    global_step,
+                    accelerator.device,
+                    vae,
+                    [tokenizer1, tokenizer2],
+                    [text_encoder1, text_encoder2],
+                    unet,
+                )
+
+                # save the model at the specified step interval
+                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
+                    accelerator.wait_for_everyone()
+                    if accelerator.is_main_process:
+                        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
+                        sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
+                            args,
+                            False,
+                            accelerator,
+                            src_path,
+                            save_stable_diffusion_format,
+                            use_safetensors,
+                            save_dtype,
+                            epoch,
+                            num_train_epochs,
+                            global_step,
+                            accelerator.unwrap_model(text_encoder1),
+                            accelerator.unwrap_model(text_encoder2),
+                            accelerator.unwrap_model(unet),
+                            vae,
+                            logit_scale,
+                            ckpt_info,
+                            ema=ema,
+                            params_to_replace=params_to_optimize,
+                        )
+
+            current_loss = loss.detach().item()  # this is a mean, so batch size should not matter
+            if args.logging_dir is not None:
+                logs = {"loss": current_loss}
+                if block_lrs is None:
+                    train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
+                else:
+                    append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)  # U-Net is included in block_lrs
+
+                accelerator.log(logs, step=global_step)
+
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+
+            if global_step >= args.max_train_steps:
+                break
+
+        if args.logging_dir is not None:
+            logs = {"loss/epoch": loss_recorder.moving_average}
+            accelerator.log(logs, step=epoch + 1)
+
+        accelerator.wait_for_everyone()
+
+        if args.save_every_n_epochs is not None:
+            if accelerator.is_main_process:
+                src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
+                sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
+                    args,
+                    True,
+                    accelerator,
+                    src_path,
+                    save_stable_diffusion_format,
+                    use_safetensors,
+                    save_dtype,
+                    epoch,
+                    num_train_epochs,
+                    global_step,
+                    accelerator.unwrap_model(text_encoder1),
+                    accelerator.unwrap_model(text_encoder2),
+                    accelerator.unwrap_model(unet),
+                    vae,
+                    logit_scale,
+                    ckpt_info,
+                    ema=ema,
+                    params_to_replace=params_to_optimize,
+                )
+
+        sdxl_train_util.sample_images(
+            accelerator,
+            args,
+            epoch + 1,
+            global_step,
+            accelerator.device,
+            vae,
+            [tokenizer1, tokenizer2],
+            [text_encoder1, text_encoder2],
+            unet,
+        )
+
+    is_main_process = accelerator.is_main_process
+    # if is_main_process:
+    unet = accelerator.unwrap_model(unet)
+    text_encoder1 = accelerator.unwrap_model(text_encoder1)
+    text_encoder2 = accelerator.unwrap_model(text_encoder2)
+    if args.enable_ema:
+        ema = accelerator.unwrap_model(ema)
+
+    accelerator.end_training()
+
+    if args.save_state:  # and is_main_process:
+        train_util.save_state_on_train_end(args, accelerator)
+
+    del accelerator  # delete this to free memory, which is needed below
+
+    if is_main_process:
+        src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
+        if args.enable_ema and not args.ema_save_only_ema_weights:
+            temp_name = args.output_name
+            args.output_name = args.output_name + "-non-EMA"
+            sdxl_train_util.save_sd_model_on_train_end(
+                args,
+                src_path,
+                save_stable_diffusion_format,
+                use_safetensors,
+                save_dtype,
+                epoch,
+                global_step,
+                text_encoder1,
+                text_encoder2,
+                unet,
+                vae,
+                logit_scale,
+                ckpt_info,
+            )
+            args.output_name = temp_name
+        if args.enable_ema:
+            print("Saving EMA:")
+            ema.copy_to(params_to_optimize)
+
+        sdxl_train_util.save_sd_model_on_train_end(
+            args,
+            src_path,
+            save_stable_diffusion_format,
+            use_safetensors,
+            save_dtype,
+            epoch,
+            global_step,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            vae,
+            logit_scale,
+            ckpt_info,
+        )
+        print("model saved.")
+
+
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+
+    train_util.add_sd_models_arguments(parser)
+    train_util.add_dataset_arguments(parser, True, True, True)
+    train_util.add_training_arguments(parser, False)
+    train_util.add_sd_saving_arguments(parser)
+    train_util.add_optimizer_arguments(parser)
+    config_util.add_config_arguments(parser)
+    custom_train_functions.add_custom_train_arguments(parser)
+    sdxl_train_util.add_sdxl_training_arguments(parser)
+
+    parser.add_argument(
+        "--learning_rate_te1",
+        type=float,
+        default=None,
+        help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
+    )
+    parser.add_argument(
+        "--learning_rate_te2",
+        type=float,
+        default=None,
+        help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
+    )
+
+    parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
+    parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
+    parser.add_argument(
+        "--no_half_vae",
+        action="store_true",
+        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
+    )
+    parser.add_argument(
+        "--block_lr",
+        type=str,
+        default=None,
+        help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
+        + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
+    )
+
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    args = train_util.read_config_from_file(args, parser)
+
+    train(args)
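
Note on --block_lr above: the option expects exactly UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23 comma-separated values, mapped as block 0 = time_embed/label_emb, blocks 1-9 = input_blocks, 10-12 = middle_block, 13-21 = output_blocks, and 22 = out. A minimal sketch of how train() parses the option; the learning-rate values below are illustrative, not recommendations:

# illustrative only: freeze the embedding block, train all other blocks at 1e-5
block_lr_arg = ",".join(["0"] + ["1e-5"] * 22)
block_lrs = [float(lr) for lr in block_lr_arg.split(",")]
assert len(block_lrs) == 23
# get_block_params_to_optimize() then skips every block whose lr is 0
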
sdxl_train_util.py ADDED
@@ -0,0 +1,389 @@
+import argparse
+import gc
+import math
+import os
+from typing import Optional
+import torch
+from accelerate import init_empty_weights
+from tqdm import tqdm
+from transformers import CLIPTokenizer
+from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
+from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
+
+TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
+TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+
+# DEFAULT_NOISE_OFFSET = 0.0357
+
+
+def load_target_model(args, accelerator, model_version: str, weight_dtype):
+    # load models for each process
+    model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16
+    for pi in range(accelerator.state.num_processes):
+        if pi == accelerator.state.local_process_index:
+            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
+
+            (
+                load_stable_diffusion_format,
+                text_encoder1,
+                text_encoder2,
+                vae,
+                unet,
+                logit_scale,
+                ckpt_info,
+            ) = _load_target_model(
+                args.pretrained_model_name_or_path,
+                args.vae,
+                model_version,
+                weight_dtype,
+                accelerator.device if args.lowram else "cpu",
+                model_dtype,
+            )
+
+            # work on low-ram device
+            if args.lowram:
+                text_encoder1.to(accelerator.device)
+                text_encoder2.to(accelerator.device)
+                unet.to(accelerator.device)
+                vae.to(accelerator.device)
+
+            gc.collect()
+            torch.cuda.empty_cache()
+        accelerator.wait_for_everyone()
+
+    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
+
+
+def _load_target_model(
+    name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
+):
+    # model_dtype only works with full fp16/bf16
+    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
+    load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
+
+    if load_stable_diffusion_format:
+        print(f"load StableDiffusion checkpoint: {name_or_path}")
+        (
+            text_encoder1,
+            text_encoder2,
+            vae,
+            unet,
+            logit_scale,
+            ckpt_info,
+        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
+    else:
+        # Diffusers model is loaded to CPU
+        from diffusers import StableDiffusionXLPipeline
+
+        variant = "fp16" if weight_dtype == torch.float16 else None
+        print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
+        try:
+            try:
+                pipe = StableDiffusionXLPipeline.from_pretrained(
+                    name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
+                )
+            except EnvironmentError as ex:
+                if variant is not None:
+                    print("try to load fp32 model")
+                    pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
+                else:
+                    raise ex
+        except EnvironmentError as ex:
+            print(
+                f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
+            )
+            raise ex
+
+        text_encoder1 = pipe.text_encoder
+        text_encoder2 = pipe.text_encoder_2
+
+        # convert to fp32 for caching text encoder outputs
+        if text_encoder1.dtype != torch.float32:
+            text_encoder1 = text_encoder1.to(dtype=torch.float32)
+        if text_encoder2.dtype != torch.float32:
+            text_encoder2 = text_encoder2.to(dtype=torch.float32)
+
+        vae = pipe.vae
+        unet = pipe.unet
+        del pipe
+
+        # Diffusers U-Net to original U-Net
+        state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
+        with init_empty_weights():
+            unet = sdxl_original_unet.SdxlUNet2DConditionModel()  # overwrite unet
+        sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
+        print("U-Net converted to original U-Net")
+
+        logit_scale = None
+        ckpt_info = None
+
+    # load the VAE
+    if vae_path is not None:
+        vae = model_util.load_vae(vae_path, weight_dtype)
+        print("additional VAE loaded")
+
+    return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
+
+
+def load_tokenizers(args: argparse.Namespace):
+    print("prepare tokenizers")
+
+    original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
+    tokenizers = []
+    for i, original_path in enumerate(original_paths):
+        tokenizer: CLIPTokenizer = None
+        if args.tokenizer_cache_dir:
+            local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
+            if os.path.exists(local_tokenizer_path):
+                print(f"load tokenizer from cache: {local_tokenizer_path}")
+                tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
+
+        if tokenizer is None:
+            tokenizer = CLIPTokenizer.from_pretrained(original_path)
+
+        if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
+            print(f"save Tokenizer to cache: {local_tokenizer_path}")
+            tokenizer.save_pretrained(local_tokenizer_path)
+
+        if i == 1:
+            tokenizer.pad_token_id = 0  # fix pad token id to make same as open clip tokenizer
+
+        tokenizers.append(tokenizer)
+
+    if hasattr(args, "max_token_length") and args.max_token_length is not None:
+        print(f"update token length: {args.max_token_length}")
+
+    return tokenizers
+
+
+def match_mixed_precision(args, weight_dtype):
+    if args.full_fp16:
+        assert (
+            weight_dtype == torch.float16
+        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+        return weight_dtype
+    elif args.full_bf16:
+        assert (
+            weight_dtype == torch.bfloat16
+        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+        return weight_dtype
+    else:
+        return None
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+    """
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    half = dim // 2
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+        device=timesteps.device
+    )
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
+
+
+def get_timestep_embedding(x, outdim):
+    assert len(x.shape) == 2
+    b, dims = x.shape[0], x.shape[1]
+    x = torch.flatten(x)
+    emb = timestep_embedding(x, outdim)
+    emb = torch.reshape(emb, (b, dims * outdim))
+    return emb
+
+
+def get_size_embeddings(orig_size, crop_size, target_size, device):
+    emb1 = get_timestep_embedding(orig_size, 256)
+    emb2 = get_timestep_embedding(crop_size, 256)
+    emb3 = get_timestep_embedding(target_size, 256)
+    vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
+    return vector
+
+
+def save_sd_model_on_train_end(
+    args: argparse.Namespace,
+    src_path: str,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    save_dtype: torch.dtype,
+    epoch: int,
+    global_step: int,
+    text_encoder1,
+    text_encoder2,
+    unet,
+    vae,
+    logit_scale,
+    ckpt_info,
+):
+    def sd_saver(ckpt_file, epoch_no, global_step):
+        sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
+        sdxl_model_util.save_stable_diffusion_checkpoint(
+            ckpt_file,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            epoch_no,
+            global_step,
+            ckpt_info,
+            vae,
+            logit_scale,
+            sai_metadata,
+            save_dtype,
+        )
+
+    def diffusers_saver(out_dir):
+        sdxl_model_util.save_diffusers_checkpoint(
+            out_dir,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            src_path,
+            vae,
+            use_safetensors=use_safetensors,
+            save_dtype=save_dtype,
+        )
+
+    train_util.save_sd_model_on_train_end_common(
+        args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
+    )
+
+
+# epoch-end and stepwise saving are unified here because the metadata includes epoch/step and the arguments are otherwise the same
+# on_epoch_end: True = save at epoch end, False = save at a step interval
+def save_sd_model_on_epoch_end_or_stepwise(
+    args: argparse.Namespace,
+    on_epoch_end: bool,
+    accelerator,
+    src_path,
+    save_stable_diffusion_format: bool,
+    use_safetensors: bool,
+    save_dtype: torch.dtype,
+    epoch: int,
+    num_train_epochs: int,
+    global_step: int,
+    text_encoder1,
+    text_encoder2,
+    unet,
+    vae,
+    logit_scale,
+    ckpt_info,
+    ema=None,
+    params_to_replace=None,
+):
+    def sd_saver(ckpt_file, epoch_no, global_step):
+        sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
+        sdxl_model_util.save_stable_diffusion_checkpoint(
+            ckpt_file,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            epoch_no,
+            global_step,
+            ckpt_info,
+            vae,
+            logit_scale,
+            sai_metadata,
+            save_dtype,
+        )
+
+    def diffusers_saver(out_dir):
+        sdxl_model_util.save_diffusers_checkpoint(
+            out_dir,
+            text_encoder1,
+            text_encoder2,
+            unet,
+            src_path,
+            vae,
+            use_safetensors=use_safetensors,
+            save_dtype=save_dtype,
+        )
+
+    temp_name = None  # so the restore below is a no-op when the non-EMA rename did not happen
+    if args.enable_ema and not args.ema_save_only_ema_weights and ema:
+        temp_name = args.output_name
+        args.output_name = args.output_name + "-non-EMA"
+
+    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
+        args,
+        on_epoch_end,
+        accelerator,
+        save_stable_diffusion_format,
+        use_safetensors,
+        epoch,
+        num_train_epochs,
+        global_step,
+        sd_saver,
+        diffusers_saver,
+    )
+    args.output_name = temp_name if temp_name else args.output_name
+    if args.enable_ema and ema:
+        with ema.ema_parameters(params_to_replace):
+            print("Saving EMA:")
+            train_util.save_sd_model_on_epoch_end_or_stepwise_common(
+                args,
+                on_epoch_end,
+                accelerator,
+                save_stable_diffusion_format,
+                use_safetensors,
+                epoch,
+                num_train_epochs,
+                global_step,
+                sd_saver,
+                diffusers_saver,
+            )
+
+
+def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
+    )
+    parser.add_argument(
+        "--cache_text_encoder_outputs_to_disk",
+        action="store_true",
+        help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
+    )
+
+
+def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
+    assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
+    if args.v_parameterization:
+        print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
+
+    if args.clip_skip is not None:
+        print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
+
+    # if args.multires_noise_iterations:
+    #     print(
+    #         f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
+    #     )
+    # else:
+    #     if args.noise_offset is None:
+    #         args.noise_offset = DEFAULT_NOISE_OFFSET
+    #     elif args.noise_offset != DEFAULT_NOISE_OFFSET:
+    #         print(
+    #             f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
+    #         )
+    #     print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
+
+    assert (
+        not hasattr(args, "weighted_captions") or not args.weighted_captions
+    ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
+
+    if supportTextEncoderCaching:
+        if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
+            args.cache_text_encoder_outputs = True
+            print(
+                "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+                + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
+            )
+
+
+def sample_images(*args, **kwargs):
+    return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
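
For reference, get_size_embeddings() above embeds each (h, w) component with timestep_embedding(..., 256), so every size pair yields a [B, 512] vector and the three concatenated sizes give the [B, 1536] SDXL size conditioning that sdxl_train.py concatenates with pool2. A minimal shape-check sketch, assuming this module is importable as library.sdxl_train_util:

import torch
from library import sdxl_train_util

orig = torch.tensor([[1024, 1024]])    # original size (h, w)
crop = torch.tensor([[0, 0]])          # crop top-left
target = torch.tensor([[1024, 1024]])  # target size (h, w)

emb = sdxl_train_util.get_size_embeddings(orig, crop, target, "cpu")
print(emb.shape)  # torch.Size([1, 1536]) = 3 sizes x 2 components x 256 dims
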
train_db.py ADDED
@@ -0,0 +1,539 @@
+# DreamBooth training
+# XXX dropped option: fine_tune
+
+import gc
+import argparse
+import itertools
+import math
+import os
+from multiprocessing import Value
+import toml
+
+from tqdm import tqdm
+import torch
+
+from library.ipex_interop import init_ipex
+
+init_ipex()
+
+from accelerate.utils import set_seed
+from diffusers import DDPMScheduler
+
+import library.train_util as train_util
+import library.config_util as config_util
+from library.config_util import (
+    ConfigSanitizer,
+    BlueprintGenerator,
+)
+import library.custom_train_functions as custom_train_functions
+from library.custom_train_functions import (
+    apply_snr_weight,
+    get_weighted_text_embeddings,
+    prepare_scheduler_for_custom_training,
+    pyramid_noise_like,
+    apply_noise_offset,
+    scale_v_prediction_loss_like_noise_prediction,
+    apply_debiased_estimation,
+)
+
+# perlin_noise,
+
+
+def train(args):
+    train_util.verify_training_args(args)
+    train_util.prepare_dataset_args(args, False)
+
+    cache_latents = args.cache_latents
+
+    if args.seed is not None:
+        set_seed(args.seed)  # initialize the random seed
+
+    tokenizer = train_util.load_tokenizer(args)
+
+    # prepare the dataset
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
+        if args.dataset_config is not None:
+            print(f"Load dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
+                )
+        else:
+            user_config = {
+                "datasets": [
+                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
+                ]
+            }
+
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
+
+    current_epoch = Value("i", 0)
+    current_step = Value("i", 0)
+    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
+
+    if args.no_token_padding:
+        train_dataset_group.disable_token_padding()
+
+    if args.debug_dataset:
+        train_util.debug_dataset(train_dataset_group)
+        return
+
+    if cache_latents:
+        assert (
+            train_dataset_group.is_latent_cacheable()
+        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
+    # prepare the accelerator
+    print("prepare accelerator")
+
+    if args.gradient_accumulation_steps > 1:
+        print(
+            f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
+        )
+        print(
+            f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
+        )
+
+    accelerator = train_util.prepare_accelerator(args)
+
+    # prepare dtypes matching the mixed precision setting and cast as appropriate
+    weight_dtype, save_dtype = train_util.prepare_dtype(args)
+    vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
+
+    # load the model
+    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
+
+    # verify load/save model formats
+    if load_stable_diffusion_format:
+        src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
+        src_diffusers_model_path = None
+    else:
+        src_stable_diffusion_ckpt = None
+        src_diffusers_model_path = args.pretrained_model_name_or_path
+
+    if args.save_model_as is None:
+        save_stable_diffusion_format = load_stable_diffusion_format
+        use_safetensors = args.use_safetensors
+    else:
+        save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
+        use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
+
130
+ # モデルに xformers とか memory efficient attention を組み込む
131
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
132
+
133
+ # 学習を準備する
134
+ if cache_latents:
135
+ vae.to(accelerator.device, dtype=vae_dtype)
136
+ vae.requires_grad_(False)
137
+ vae.eval()
138
+ with torch.no_grad():
139
+ train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
140
+ vae.to("cpu")
141
+ if torch.cuda.is_available():
142
+ torch.cuda.empty_cache()
143
+ gc.collect()
144
+
145
+ accelerator.wait_for_everyone()
146
+
147
+ # 学習を準備する:モデルを適切な状態にする
148
+ train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
149
+ unet.requires_grad_(True) # 念のため追加
150
+ text_encoder.requires_grad_(train_text_encoder)
151
+ if not train_text_encoder:
152
+ accelerator.print("Text Encoder is not trained.")
153
+
154
+ if args.gradient_checkpointing:
155
+ unet.enable_gradient_checkpointing()
156
+ text_encoder.gradient_checkpointing_enable()
157
+
158
+ if not cache_latents:
159
+ vae.requires_grad_(False)
160
+ vae.eval()
161
+ vae.to(accelerator.device, dtype=weight_dtype)
162
+
163
+ # 学習に必要なクラスを準備する
164
+ accelerator.print("prepare optimizer, data loader etc.")
165
+ if train_text_encoder:
166
+ if args.learning_rate_te is None:
167
+ # wightout list, adamw8bit is crashed
168
+ trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
169
+ else:
170
+ trainable_params = [
171
+ {"params": list(unet.parameters()), "lr": args.learning_rate},
172
+ {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
173
+ ]
174
+ else:
175
+ trainable_params = unet.parameters()
176
+
177
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
178
+
179
+ # dataloaderを準備する
180
+ # DataLoaderのプロセス数:0はメインプロセスになる
181
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
182
+ train_dataloader = torch.utils.data.DataLoader(
183
+ train_dataset_group,
184
+ batch_size=1,
185
+ shuffle=True,
186
+ collate_fn=collator,
187
+ num_workers=n_workers,
188
+ persistent_workers=args.persistent_data_loader_workers,
189
+ )
190
+
191
+ # 学習ステップ数を計算する
192
+ if args.max_train_epochs is not None:
193
+ args.max_train_steps = args.max_train_epochs * math.ceil(
194
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
195
+ )
196
+ accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
197
+
198
+ # データセット側にも学習ステップを送信
199
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
200
+
201
+ if args.stop_text_encoder_training is None:
202
+ args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
203
+
204
+ # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
205
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
206
+
207
+     # experimental feature: perform fp16 training including gradients; cast the entire model to fp16
+     if args.full_fp16:
+         assert (
+             args.mixed_precision == "fp16"
+         ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
+         accelerator.print("enable full fp16 training.")
+         unet.to(weight_dtype)
+         text_encoder.to(weight_dtype)
+
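+     # maintain an exponential moving average (EMA) of the trained weights; when enabled,
+     # the EMA copy is what gets saved as the main output at the end of training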
+     if args.enable_ema:
+         # ema_dtype = weight_dtype if (args.full_bf16 or args.full_fp16) else torch.float
+         # params_to_optimize is referenced by the EMA and the save calls below but is not
+         # defined elsewhere in this script; a flat list of the trained parameters is assumed here
+         params_to_optimize = list(unet.parameters()) + (list(text_encoder.parameters()) if train_text_encoder else [])
+         ema = EMAModel(params_to_optimize, decay=args.ema_decay, beta=args.ema_exp_beta, max_train_steps=args.max_train_steps)
+         ema.to(accelerator.device, dtype=weight_dtype)
+         ema = accelerator.prepare(ema)
+     else:
+         ema = None
+         params_to_optimize = None
+
+     # accelerator is supposed to take care of all of this for us
+     if train_text_encoder:
+         unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+             unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+         )
+     else:
+         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+
+     if not train_text_encoder:
+         text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error
+
+     # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
+     if args.full_fp16:
+         train_util.patch_accelerator_for_fp16_training(accelerator)
+
+     # resume training if specified
+     train_util.resume_from_local_or_hf_if_specified(accelerator, args)
+
+     # calculate the number of epochs
+     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+     num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+     if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
+         args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
+
+     # train
+     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+     accelerator.print("running training / 学習開始")
+     accelerator.print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+     accelerator.print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+     accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+     accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
+     accelerator.print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
+     accelerator.print(
+         f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
+     )
+     accelerator.print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+     accelerator.print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
+
+     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+     global_step = 0
+
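+     # DDPM scheduler with the Stable Diffusion beta schedule; used during training only to
+     # add noise to the latents and to compute prediction targets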
+     noise_scheduler = DDPMScheduler(
+         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
+     )
+     prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
+     if args.zero_terminal_snr:
+         custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
+
+     if accelerator.is_main_process:
+         init_kwargs = {}
+         if args.wandb_run_name:
+             init_kwargs["wandb"] = {"name": args.wandb_run_name}
+         if args.log_tracker_config is not None:
+             init_kwargs = toml.load(args.log_tracker_config)
+         accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+
+     # For --sample_at_first
+     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
+     loss_recorder = train_util.LossRecorder()
+     for epoch in range(num_train_epochs):
+         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
+         current_epoch.value = epoch + 1
+
+         # train the Text Encoder for the specified number of steps: state at the start of each epoch
+         unet.train()
+         # train==True is required to enable gradient_checkpointing
+         if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
+             text_encoder.train()
+
+         for step, batch in enumerate(train_dataloader):
+             current_step.value = global_step
+             # stop Text Encoder training at the specified step
+             if global_step == args.stop_text_encoder_training:
+                 accelerator.print(f"stop text encoder training at step {global_step}")
+                 if not args.gradient_checkpointing:
+                     text_encoder.train(False)
+                 text_encoder.requires_grad_(False)
+
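+             # gradients accumulate across micro-batches inside this context; accelerate syncs
+             # and applies them every gradient_accumulation_steps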
+             with accelerator.accumulate(unet):
+                 with torch.no_grad():
+                     # convert images to latents
+                     if cache_latents:
+                         latents = batch["latents"].to(accelerator.device)
+                     else:
+                         latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
+                         latents = latents * 0.18215  # SD VAE scaling factor
+                     b_size = latents.shape[0]
+
+                 # Get the text embedding for conditioning
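+                 # gradients flow into the Text Encoder only while it is still being trained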
+                 with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
+                     if args.weighted_captions:
+                         encoder_hidden_states = get_weighted_text_embeddings(
+                             tokenizer,
+                             text_encoder,
+                             batch["captions"],
+                             accelerator.device,
+                             args.max_token_length // 75 if args.max_token_length else 1,
+                             clip_skip=args.clip_skip,
+                         )
+                     else:
+                         input_ids = batch["input_ids"].to(accelerator.device)
+                         encoder_hidden_states = train_util.get_hidden_states(
+                             args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
+                         )
+
+                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
+                 # with noise offset and/or multires noise if specified
+                 noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+
+                 # Predict the noise residual
+                 with accelerator.autocast():
+                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                 if args.v_parameterization:
+                     # v-parameterization training
+                     target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                 else:
+                     target = noise
+
+                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                 loss = loss.mean([1, 2, 3])
+
+                 loss_weights = batch["loss_weights"]  # per-sample weights
+                 loss = loss * loss_weights
+
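+                 # optional loss reweighting: min-SNR weighting, v-pred loss scaling, debiased estimation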
+                 if args.min_snr_gamma:
+                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
+                 if args.scale_v_pred_loss_like_noise_pred:
+                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
+                 if args.debiased_estimation_loss:
+                     loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+
+                 loss = loss.mean()  # already a mean, so no need to divide by batch_size
+
+                 accelerator.backward(loss)
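+                 # clip gradients only when they are synced, i.e. on an actual optimization step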
+                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                     if train_text_encoder:
+                         params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
+                     else:
+                         params_to_clip = unet.parameters()
+                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad(set_to_none=True)
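+                 # advance the EMA shadow weights toward the current parameters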
+                 if args.enable_ema:
+                     with torch.no_grad(), accelerator.autocast():
+                         ema.step(params_to_optimize)
+
+             # Checks if the accelerator has performed an optimization step behind the scenes
+             if accelerator.sync_gradients:
+                 progress_bar.update(1)
+                 global_step += 1
+
+                 train_util.sample_images(
+                     accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
+                 )
+
+                 # save the model every specified number of steps
+                 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
+                     accelerator.wait_for_everyone()
+                     if accelerator.is_main_process:
+                         src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
+                         train_util.save_sd_model_on_epoch_end_or_stepwise(
+                             args,
+                             False,
+                             accelerator,
+                             src_path,
+                             save_stable_diffusion_format,
+                             use_safetensors,
+                             save_dtype,
+                             epoch,
+                             num_train_epochs,
+                             global_step,
+                             accelerator.unwrap_model(text_encoder),
+                             accelerator.unwrap_model(unet),
+                             vae,
+                             logit_scale,
+                             ckpt_info,
+                             ema=ema,
+                             params_to_replace=params_to_optimize,
+                         )
+
+             current_loss = loss.detach().item()
+             if args.logging_dir is not None:
+                 logs = {"loss": current_loss}
+                 train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
+                 accelerator.log(logs, step=global_step)
+
+             loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+             avr_loss: float = loss_recorder.moving_average
+             logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+             progress_bar.set_postfix(**logs)
+
+             if global_step >= args.max_train_steps:
+                 break
+
+         if args.logging_dir is not None:
+             logs = {"loss/epoch": loss_recorder.moving_average}
+             accelerator.log(logs, step=epoch + 1)
+
+         accelerator.wait_for_everyone()
+
+         if args.save_every_n_epochs is not None:
+             if accelerator.is_main_process:
+                 # the check for whether to save this epoch is done inside the util
+                 src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
+                 train_util.save_sd_model_on_epoch_end_or_stepwise(
+                     args,
+                     True,
+                     accelerator,
+                     src_path,
+                     save_stable_diffusion_format,
+                     use_safetensors,
+                     save_dtype,
+                     epoch,
+                     num_train_epochs,
+                     global_step,
+                     accelerator.unwrap_model(text_encoder),
+                     accelerator.unwrap_model(unet),
+                     vae,
+                     logit_scale,
+                     ckpt_info,
+                     ema=ema,
+                     params_to_replace=params_to_optimize,
+                 )
+
+         train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
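+     # unwrap the accelerate wrappers so that plain modules are saved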
+     is_main_process = accelerator.is_main_process
+     if is_main_process:
+         unet = accelerator.unwrap_model(unet)
+         text_encoder = accelerator.unwrap_model(text_encoder)
+         if args.enable_ema:
+             ema = accelerator.unwrap_model(ema)
+
+     accelerator.end_training()
+
+     if args.save_state and is_main_process:
+         train_util.save_state_on_train_end(args, accelerator)
+
+     del accelerator  # delete this to free memory, since it is needed afterwards
+
+     if is_main_process:
+         src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
+         if args.enable_ema and not args.ema_save_only_ema_weights:
+             # first save the non-EMA (live) weights under a suffixed name
+             temp_name = args.output_name
+             args.output_name = args.output_name + "-non-EMA"
+             train_util.save_sd_model_on_train_end(
+                 args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
+             )
+             args.output_name = temp_name
+         if args.enable_ema:
+             print("Saving EMA:")
+             # copy the EMA shadow weights into the model before the final save
+             ema.copy_to(params_to_optimize)
+         train_util.save_sd_model_on_train_end(
+             args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
+         )
+         print("model saved.")
+
+
+ def setup_parser() -> argparse.ArgumentParser:
+     parser = argparse.ArgumentParser()
+
+     train_util.add_sd_models_arguments(parser)
+     train_util.add_dataset_arguments(parser, True, False, True)
+     train_util.add_training_arguments(parser, True)
+     train_util.add_sd_saving_arguments(parser)
+     train_util.add_optimizer_arguments(parser)
+     config_util.add_config_arguments(parser)
+     custom_train_functions.add_custom_train_arguments(parser)
+
+     parser.add_argument(
+         "--learning_rate_te",
+         type=float,
+         default=None,
+         help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
+     )
+     parser.add_argument(
+         "--no_token_padding",
+         action="store_true",
+         help="disable token padding (same as Diffusers' DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)",
+     )
+     parser.add_argument(
+         "--stop_text_encoder_training",
+         type=int,
+         default=None,
+         help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
+     )
+     parser.add_argument(
+         "--no_half_vae",
+         action="store_true",
+         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
+     )
+
+     return parser
+
+
+ if __name__ == "__main__":
534
+ parser = setup_parser()
535
+
536
+ args = parser.parse_args()
537
+ args = train_util.read_config_from_file(args, parser)
538
+
539
+ train(args)
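+
+ # Example invocation (a sketch; the EMA flags --enable_ema, --ema_decay, --ema_exp_beta and
+ # --ema_save_only_ema_weights are assumed to be registered by the modified train_util in this upload):
+ #   accelerate launch train_db.py \
+ #     --pretrained_model_name_or_path <base model> --train_data_dir <images> \
+ #     --output_dir <out> --output_name my_model --save_model_as safetensors \
+ #     --enable_ema --ema_decay 0.9999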
train_util.py ADDED
The diff for this file is too large to render. See raw diff