nroggendorff committed on
Commit bb88d92 · verified · 1 Parent(s): 7271fdc

Update train.py

Files changed (1):
  1. train.py +391 -419

train.py CHANGED
@@ -13,11 +13,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Fine-tuning script for Stable Diffusion XL for text2image."""
+"""Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA."""

 import argparse
-import functools
-import gc
 import logging
 import math
 import os
@@ -26,7 +24,6 @@ import shutil
 from contextlib import nullcontext
 from pathlib import Path

-import accelerate
 import datasets
 import numpy as np
 import torch
@@ -35,45 +32,55 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
-from datasets import concatenate_datasets, load_dataset
+from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
+from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig

 import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    StableDiffusionXLPipeline,
+    UNet2DConditionModel,
+)
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel, compute_snr
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module


+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.36.0.dev0")

 logger = get_logger(__name__)
 if is_torch_npu_available():
-    import torch_npu
-
     torch.npu.config.allow_internal_format = False

-DATASET_NAME_MAPPING = {
-    "lambdalabs/naruto-blip-captions": ("image", "text"),
-}
-

 def save_model_card(
     repo_id: str,
     images: list = None,
-    validation_prompt: str = None,
     base_model: str = None,
     dataset_name: str = None,
+    train_text_encoder: bool = False,
     repo_folder: str = None,
     vae_path: str = None,
 ):
@@ -84,14 +91,15 @@ def save_model_card(
         img_str += f"![img_{i}](./image_{i}.png)\n"

     model_description = f"""
-# Text-to-image finetuning - {repo_id}
+# LoRA text2image fine-tuning - {repo_id}

-This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n
+These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}

+LoRA for the text encoder was enabled: {train_text_encoder}.
+
Special VAE used for training: {vae_path}.
"""
-
     model_card = load_or_create_model_card(
         repo_id_or_path=repo_id,
         from_training=True,
@@ -105,14 +113,56 @@ Special VAE used for training: {vae_path}.
         "stable-diffusion-xl",
         "stable-diffusion-xl-diffusers",
         "text-to-image",
-        "diffusers-training",
         "diffusers",
+        "diffusers-training",
+        "lora",
     ]
     model_card = populate_model_card(model_card, tags=tags)

     model_card.save(os.path.join(repo_folder, "README.md"))


+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    epoch,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+    pipeline_args = {"prompt": args.validation_prompt}
+    if torch.backends.mps.is_available():
+        autocast_ctx = nullcontext()
+    else:
+        autocast_ctx = torch.autocast(accelerator.device.type)
+
+    with autocast_ctx:
+        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+    for tracker in accelerator.trackers:
+        phase_name = "test" if is_final_validation else "validation"
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    phase_name: [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+    return images
+
+
 def import_model_class_from_model_name_or_path(
     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):
@@ -226,16 +276,10 @@ def parse_args(input_args=None):
             "value if set."
         ),
     )
-    parser.add_argument(
-        "--proportion_empty_prompts",
-        type=float,
-        default=0,
-        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
-    )
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="sdxl-model-finetuned",
+        default="sd-model-finetuned-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
@@ -268,6 +312,11 @@
         action="store_true",
         help="whether to randomly flip images horizontally",
     )
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+    )
     parser.add_argument(
         "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
     )
@@ -338,55 +387,6 @@
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
-    parser.add_argument(
-        "--timestep_bias_strategy",
-        type=str,
-        default="none",
-        choices=["earlier", "later", "range", "none"],
-        help=(
-            "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
-            " Choices: ['earlier', 'later', 'range', 'none']."
-            " The default is 'none', which means no bias is applied, and training proceeds normally."
-            " The value of 'later' will increase the frequency of the model's final training timesteps."
-        ),
-    )
-    parser.add_argument(
-        "--timestep_bias_multiplier",
-        type=float,
-        default=1.0,
-        help=(
-            "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
-            " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
-        ),
-    )
-    parser.add_argument(
-        "--timestep_bias_begin",
-        type=int,
-        default=0,
-        help=(
-            "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
-            " Defaults to zero, which equates to having no specific bias."
-        ),
-    )
-    parser.add_argument(
-        "--timestep_bias_end",
-        type=int,
-        default=1000,
-        help=(
-            "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
-            " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
-        ),
-    )
-    parser.add_argument(
-        "--timestep_bias_portion",
-        type=float,
-        default=0.25,
-        help=(
-            "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased."
-            " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
-            " whether the biased portions are in the earlier or later timesteps."
-        ),
-    )
     parser.add_argument(
         "--snr_gamma",
         type=float,
@@ -394,7 +394,6 @@
         help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
         "More details here: https://huggingface.co/papers/2303.09556.",
     )
-    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
     parser.add_argument(
         "--allow_tf32",
         action="store_true",
@@ -464,12 +463,23 @@
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument(
-        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
     parser.add_argument(
-        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
     )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
+    parser.add_argument(
+        "--debug_loss",
+        action="store_true",
+        help="debug loss for each image, if filenames are available in the dataset",
+    )
     parser.add_argument(
         "--image_interpolation_mode",
         type=str,
@@ -492,109 +502,53 @@
     # Sanity checks
     if args.dataset_name is None and args.train_data_dir is None:
         raise ValueError("Need either a dataset name or a training folder.")
-    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
-        raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")

     return args


+DATASET_NAME_MAPPING = {
+    "lambdalabs/naruto-blip-captions": ("image", "text"),
+}
+
+
+def tokenize_prompt(tokenizer, prompt):
+    text_inputs = tokenizer(
+        prompt,
+        padding="max_length",
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    )
+    text_input_ids = text_inputs.input_ids
+    return text_input_ids
+
+
 # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
-def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
     prompt_embeds_list = []
-    prompt_batch = batch[caption_column]
-
-    captions = []
-    for caption in prompt_batch:
-        if random.random() < proportion_empty_prompts:
-            captions.append("")
-        elif isinstance(caption, str):
-            captions.append(caption)
-        elif isinstance(caption, (list, np.ndarray)):
-            # take a random caption if there are multiple
-            captions.append(random.choice(caption) if is_train else caption[0])
-
-    with torch.no_grad():
-        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
-            text_inputs = tokenizer(
-                captions,
-                padding="max_length",
-                max_length=tokenizer.model_max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            prompt_embeds = text_encoder(
-                text_input_ids.to(text_encoder.device),
-                output_hidden_states=True,
-                return_dict=False,
-            )
-
-            # We are only ALWAYS interested in the pooled output of the final text encoder
-            pooled_prompt_embeds = prompt_embeds[0]
-            prompt_embeds = prompt_embeds[-1][-2]
-            bs_embed, seq_len, _ = prompt_embeds.shape
-            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
-            prompt_embeds_list.append(prompt_embeds)
-
-    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
-    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
-    return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
-
-
-def compute_vae_encodings(batch, vae):
-    images = batch.pop("pixel_values")
-    pixel_values = torch.stack(list(images))
-    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
-
-    with torch.no_grad():
-        model_input = vae.encode(pixel_values).latent_dist.sample()
-        model_input = model_input * vae.config.scaling_factor
-
-    # There might have slightly performance improvement
-    # by changing model_input.cpu() to accelerator.gather(model_input)
-    return {"model_input": model_input.cpu()}
+    for i, text_encoder in enumerate(text_encoders):
+        if tokenizers is not None:
+            tokenizer = tokenizers[i]
+            text_input_ids = tokenize_prompt(tokenizer, prompt)
+        else:
+            assert text_input_ids_list is not None
+            text_input_ids = text_input_ids_list[i]

+        prompt_embeds = text_encoder(
+            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
+        )

-def generate_timestep_weights(args, num_timesteps):
-    weights = torch.ones(num_timesteps)
-
-    # Determine the indices to bias
-    num_to_bias = int(args.timestep_bias_portion * num_timesteps)
-
-    if args.timestep_bias_strategy == "later":
-        bias_indices = slice(-num_to_bias, None)
-    elif args.timestep_bias_strategy == "earlier":
-        bias_indices = slice(0, num_to_bias)
-    elif args.timestep_bias_strategy == "range":
-        # Out of the possible 1000 timesteps, we might want to focus on eg. 200-500.
-        range_begin = args.timestep_bias_begin
-        range_end = args.timestep_bias_end
-        if range_begin < 0:
-            raise ValueError(
-                "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
-            )
-        if range_end > num_timesteps:
-            raise ValueError(
-                "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
-            )
-        bias_indices = slice(range_begin, range_end)
-    else:  # 'none' or any other string
-        return weights
-    if args.timestep_bias_multiplier <= 0:
-        return ValueError(
-            "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
-            " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
-            " A timestep bias multiplier less than or equal to 0 is not allowed."
-        )
-
-    # Apply the bias
-    weights[bias_indices] *= args.timestep_bias_multiplier
-
-    # Normalize
-    weights /= weights.sum()
-
-    return weights
+        # We are only ALWAYS interested in the pooled output of the final text encoder
+        pooled_prompt_embeds = prompt_embeds[0]
+        prompt_embeds = prompt_embeds[-1][-2]
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+        prompt_embeds_list.append(prompt_embeds)

+    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+    return prompt_embeds, pooled_prompt_embeds


 def main(args):
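Note on the hunk above: the old script precomputed prompt embeddings and VAE latents once and dropped the text encoders from memory; with LoRA adapters on the text encoders, embeddings must be recomputed inside the training loop, so `encode_prompt` now consumes pre-tokenized ids. A minimal sketch of driving the new helpers, assuming the standard SDXL checkpoint layout (the model id is illustrative, and `tokenize_prompt`/`encode_prompt` are the functions added in this hunk):

```python
import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection

ckpt = "stabilityai/stable-diffusion-xl-base-1.0"  # illustrative
tok_one = AutoTokenizer.from_pretrained(ckpt, subfolder="tokenizer")
tok_two = AutoTokenizer.from_pretrained(ckpt, subfolder="tokenizer_2")
enc_one = CLIPTextModel.from_pretrained(ckpt, subfolder="text_encoder")
enc_two = CLIPTextModelWithProjection.from_pretrained(ckpt, subfolder="text_encoder_2")

# Tokenize once, encode from the ids (the tokenizers=None branch above).
ids_one = tokenize_prompt(tok_one, ["a photo of a naruto character"])
ids_two = tokenize_prompt(tok_two, ["a photo of a naruto character"])
with torch.no_grad():
    prompt_embeds, pooled = encode_prompt(
        text_encoders=[enc_one, enc_two],
        tokenizers=None,
        prompt=None,
        text_input_ids_list=[ids_one, ids_two],
    )
print(prompt_embeds.shape)  # (1, 77, 2048): per-encoder hidden states concatenated
print(pooled.shape)         # (1, 1280): pooled output of the second encoder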
@@ -606,30 +560,22 @@ def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)

-    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-
     if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
         # due to pytorch#99272, MPS does not yet support bfloat16.
         raise ValueError(
             "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
         )

+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
+        kwargs_handlers=[kwargs],
     )

-    # Disable AMP for MPS.
-    if torch.backends.mps.is_available():
-        accelerator.native_amp = False
-
-    if args.report_to == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -684,7 +630,6 @@
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-    # Check for terminal SNR in combination with SNR Gamma
     text_encoder_one = text_encoder_cls_one.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
     )
@@ -706,14 +651,13 @@
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )

-    # Freeze vae and text encoders.
+    # We only train the additional adapter LoRA layers
     vae.requires_grad_(False)
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)
-    # Set unet as trainable.
-    unet.train()
+    unet.requires_grad_(False)

-    # For mixed precision training we cast all non-trainable weights to half-precision
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
@@ -723,22 +667,22 @@
     # Move unet, vae and text_encoder to device and cast to weight_dtype
     # The VAE is in float32 to avoid NaN losses.
-    vae.to(accelerator.device, dtype=torch.float32)
+    unet.to(accelerator.device, dtype=weight_dtype)
+
+    if args.pretrained_vae_model_name_or_path is None:
+        vae.to(accelerator.device, dtype=torch.float32)
+    else:
+        vae.to(accelerator.device, dtype=weight_dtype)
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)

-    # Create EMA for the unet.
-    if args.use_ema:
-        ema_unet = UNet2DConditionModel.from_pretrained(
-            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
-        )
-        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
     if args.enable_npu_flash_attention:
         if is_torch_npu_available():
             logger.info("npu flash attention enabled.")
             unet.enable_npu_flash_attention()
         else:
             raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -752,44 +696,122 @@
     else:
         raise ValueError("xformers is not available. Make sure it is installed correctly")

-    # `accelerate` 0.16.0 will have better support for customized saving
-    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
-        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
-        def save_model_hook(models, weights, output_dir):
-            if accelerator.is_main_process:
-                if args.use_ema:
-                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
-
-                for i, model in enumerate(models):
-                    model.save_pretrained(os.path.join(output_dir, "unet"))
-
-                    # make sure to pop weight so that corresponding model is not saved again
-                    if weights:
-                        weights.pop()
-
-        def load_model_hook(models, input_dir):
-            if args.use_ema:
-                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
-                ema_unet.load_state_dict(load_model.state_dict())
-                ema_unet.to(accelerator.device)
-                del load_model
-
-            for _ in range(len(models)):
-                # pop models so that they are not loaded again
-                model = models.pop()
-
-                # load diffusers style into model
-                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
-                model.register_to_config(**load_model.config)
-
-                model.load_state_dict(load_model.state_dict())
-                del load_model
-
-        accelerator.register_save_state_pre_hook(save_model_hook)
-        accelerator.register_load_state_pre_hook(load_model_hook)
+    # now we will add new LoRA weights to the attention layers
+    # Set correct lora layers
+    unet_lora_config = LoraConfig(
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+    )
+
+    unet.add_adapter(unet_lora_config)
+
+    # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
+    if args.train_text_encoder:
+        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+        text_lora_config = LoraConfig(
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+        )
+        text_encoder_one.add_adapter(text_lora_config)
+        text_encoder_two.add_adapter(text_lora_config)
+
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
+    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+    def save_model_hook(models, weights, output_dir):
+        if accelerator.is_main_process:
+            # there are only two options here. Either are just the unet attn processor layers
+            # or there are the unet and text encoder attn layers
+            unet_lora_layers_to_save = None
+            text_encoder_one_lora_layers_to_save = None
+            text_encoder_two_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(unwrap_model(model), type(unwrap_model(unet))):
+                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+                # make sure to pop weight so that corresponding model is not saved again
+                if weights:
+                    weights.pop()
+
+            StableDiffusionXLPipeline.save_lora_weights(
+                output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+            )
+
+    def load_model_hook(models, input_dir):
+        unet_ = None
+        text_encoder_one_ = None
+        text_encoder_two_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+
+            if isinstance(model, type(unwrap_model(unet))):
+                unet_ = model
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
+                text_encoder_one_ = model
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
+                text_encoder_two_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
+        unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        if args.train_text_encoder:
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            cast_training_params(models, dtype=torch.float32)
+
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder_one.gradient_checkpointing_enable()
+            text_encoder_two.gradient_checkpointing_enable()

     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
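For readers unfamiliar with the adapter mechanics in the hunk above: `add_adapter(LoraConfig(...))` wraps the named attention projections with low-rank update matrices and leaves only those matrices trainable. A minimal standalone sketch (model id illustrative; the printed counts depend on the checkpoint):

```python
import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"  # illustrative
)
unet.requires_grad_(False)  # freeze the base weights, as the script does
unet.add_adapter(
    LoraConfig(
        r=4,                            # --rank
        lora_alpha=4,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
)

# Only the injected LoRA matrices now require grad.
trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
total = sum(p.numel() for p in unet.parameters())
print(f"trainable params: {trainable:,} of {total:,}")
```

This is why the optimizer below filters on `p.requires_grad`: everything except the rank-`r` update matrices stays frozen in `weight_dtype`.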
@@ -801,6 +823,13 @@
         args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
     )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        cast_training_params(models, dtype=torch.float32)
+
     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
         try:
@@ -815,7 +844,13 @@
     optimizer_class = torch.optim.AdamW

     # Optimizer creation
-    params_to_optimize = unet.parameters()
+    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+    if args.train_text_encoder:
+        params_to_optimize = (
+            params_to_optimize
+            + list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+            + list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+        )
     optimizer = optimizer_class(
         params_to_optimize,
         lr=args.learning_rate,
@@ -870,13 +905,39 @@
     )

     # Preprocessing the datasets.
+    # We need to tokenize input captions and transform the images.
+    def tokenize_captions(examples, is_train=True):
+        captions = []
+        for caption in examples[caption_column]:
+            if isinstance(caption, str):
+                captions.append(caption)
+            elif isinstance(caption, (list, np.ndarray)):
+                # take a random caption if there are multiple
+                captions.append(random.choice(caption) if is_train else caption[0])
+            else:
+                raise ValueError(
+                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
+                )
+        tokens_one = tokenize_prompt(tokenizer_one, captions)
+        tokens_two = tokenize_prompt(tokenizer_two, captions)
+        return tokens_one, tokens_two
+
+    # Get the specified interpolation method from the args
     interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+
+    # Raise an error if the interpolation method is invalid
     if interpolation is None:
-        raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
-    train_resize = transforms.Resize(args.resolution, interpolation=interpolation)
+        raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
+    # Preprocessing the datasets.
+    train_resize = transforms.Resize(args.resolution, interpolation=interpolation)  # Use dynamic interpolation method
     train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
     train_flip = transforms.RandomHorizontalFlip(p=1.0)
-    train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+    train_transforms = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ]
+    )

     def preprocess_train(examples):
         images = [image.convert("RGB") for image in examples[image_column]]
@@ -905,73 +966,44 @@
         examples["original_sizes"] = original_sizes
         examples["crop_top_lefts"] = crop_top_lefts
         examples["pixel_values"] = all_images
+        tokens_one, tokens_two = tokenize_captions(examples)
+        examples["input_ids_one"] = tokens_one
+        examples["input_ids_two"] = tokens_two
+        if args.debug_loss:
+            fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
+            if fnames:
+                examples["filenames"] = fnames
         return examples

     with accelerator.main_process_first():
         if args.max_train_samples is not None:
             dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
         # Set the training transforms
-        train_dataset = dataset["train"].with_transform(preprocess_train)
-
-    # Let's first compute all the embeddings so that we can free up the text encoders
-    # from memory. We will pre-compute the VAE encodings too.
-    text_encoders = [text_encoder_one, text_encoder_two]
-    tokenizers = [tokenizer_one, tokenizer_two]
-    compute_embeddings_fn = functools.partial(
-        encode_prompt,
-        text_encoders=text_encoders,
-        tokenizers=tokenizers,
-        proportion_empty_prompts=args.proportion_empty_prompts,
-        caption_column=args.caption_column,
-    )
-    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
-    with accelerator.main_process_first():
-        from datasets.fingerprint import Hasher
-
-        # fingerprint used by the cache for the other processes to load the result
-        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
-        new_fingerprint = Hasher.hash(args)
-        new_fingerprint_for_vae = Hasher.hash((vae_path, args))
-        train_dataset_with_embeddings = train_dataset.map(
-            compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
-        )
-        train_dataset_with_vae = train_dataset.map(
-            compute_vae_encodings_fn,
-            batched=True,
-            batch_size=args.train_batch_size,
-            new_fingerprint=new_fingerprint_for_vae,
-        )
-        precomputed_dataset = concatenate_datasets(
-            [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
-        )
-        precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
-
-    del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
-    del text_encoders, tokenizers, vae
-    gc.collect()
-    if is_torch_npu_available():
-        torch_npu.npu.empty_cache()
-    elif torch.cuda.is_available():
-        torch.cuda.empty_cache()
+        train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)

     def collate_fn(examples):
-        model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
+        pixel_values = torch.stack([example["pixel_values"] for example in examples])
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
         original_sizes = [example["original_sizes"] for example in examples]
         crop_top_lefts = [example["crop_top_lefts"] for example in examples]
-        prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
-        pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
-
-        return {
-            "model_input": model_input,
-            "prompt_embeds": prompt_embeds,
-            "pooled_prompt_embeds": pooled_prompt_embeds,
+        input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
+        input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
+        result = {
+            "pixel_values": pixel_values,
+            "input_ids_one": input_ids_one,
+            "input_ids_two": input_ids_two,
             "original_sizes": original_sizes,
             "crop_top_lefts": crop_top_lefts,
         }

+        filenames = [example["filenames"] for example in examples if "filenames" in example]
+        if filenames:
+            result["filenames"] = filenames
+        return result
+
     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
-        precomputed_dataset,
+        train_dataset,
         shuffle=True,
         collate_fn=collate_fn,
         batch_size=args.train_batch_size,
993
  )
994
 
995
  # Prepare everything with our `accelerator`.
996
- unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
997
- unet, optimizer, train_dataloader, lr_scheduler
998
- )
999
-
1000
- if args.use_ema:
1001
- ema_unet.to(accelerator.device)
 
 
1002
 
1003
  # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1004
  num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1010,24 +1044,13 @@
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
-
-    # Function for unwrapping if torch.compile() was used in accelerate.
-    def unwrap_model(model):
-        model = accelerator.unwrap_model(model)
-        model = model._orig_mod if is_compiled_module(model) else model
-        return model
-
-    if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
-        autocast_ctx = nullcontext()
-    else:
-        autocast_ctx = torch.autocast(accelerator.device.type)
+        accelerator.init_trackers("text2image-fine-tune", config=vars(args))

     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

     logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(precomputed_dataset)}")
+    logger.info(f"  Num examples = {len(train_dataset)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
@@ -1073,11 +1096,25 @@
     )

     for epoch in range(first_epoch, args.num_train_epochs):
+        unet.train()
+        if args.train_text_encoder:
+            text_encoder_one.train()
+            text_encoder_two.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
+                # Convert images to latent space
+                if args.pretrained_vae_model_name_or_path is not None:
+                    pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+                else:
+                    pixel_values = batch["pixel_values"]
+
+                model_input = vae.encode(pixel_values).latent_dist.sample()
+                model_input = model_input * vae.config.scaling_factor
+                if args.pretrained_vae_model_name_or_path is None:
+                    model_input = model_input.to(weight_dtype)
+
                 # Sample noise that we'll add to the latents
-                model_input = batch["model_input"].to(accelerator.device)
                 noise = torch.randn_like(model_input)
                 if args.noise_offset:
                     # https://www.crosslabs.org//blog/diffusion-with-offset-noise
@@ -1086,29 +1123,23 @@
                     )

                 bsz = model_input.shape[0]
-                if args.timestep_bias_strategy == "none":
-                    # Sample a random timestep for each image without bias.
-                    timesteps = torch.randint(
-                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
-                    )
-                else:
-                    # Sample a random timestep for each image, potentially biased by the timestep weights.
-                    # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
-                    weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
-                        model_input.device
-                    )
-                    timesteps = torch.multinomial(weights, bsz, replacement=True).long()
+                # Sample a random timestep for each image
+                timesteps = torch.randint(
+                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+                )
+                timesteps = timesteps.long()

                 # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
+                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

                 # time ids
                 def compute_time_ids(original_size, crops_coords_top_left):
                     # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                     target_size = (args.resolution, args.resolution)
                     add_time_ids = list(original_size + crops_coords_top_left + target_size)
-                    add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
+                    add_time_ids = torch.tensor([add_time_ids])
+                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
                     return add_time_ids

                 add_time_ids = torch.cat(
@@ -1117,8 +1148,12 @@

                 # Predict the noise residual
                 unet_added_conditions = {"time_ids": add_time_ids}
-                prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
-                pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
+                prompt_embeds, pooled_prompt_embeds = encode_prompt(
+                    text_encoders=[text_encoder_one, text_encoder_two],
+                    tokenizers=None,
+                    prompt=None,
+                    text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
+                )
                 unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
                 model_pred = unet(
                     noisy_model_input,
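As a worked example of the SDXL micro-conditioning that `compute_time_ids` assembles two hunks above (values illustrative; it is simply original_size + crop_top_left + target_size flattened into one row per sample):

```python
# One 1024x1024 sample with no crop offset.
original_size = (1024, 1024)
crop_top_left = (0, 0)
target_size = (1024, 1024)  # (args.resolution, args.resolution)
add_time_ids = list(original_size + crop_top_left + target_size)
print(add_time_ids)  # [1024, 1024, 0, 0, 1024, 1024]
```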
@@ -1137,11 +1172,6 @@
                     target = noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
                     target = noise_scheduler.get_velocity(model_input, noise, timesteps)
-                elif noise_scheduler.config.prediction_type == "sample":
-                    # We set the target to latents here, but the model_pred will return the noise sample prediction.
-                    target = model_input
-                    # We will have to subtract the noise residual from the prediction to get the target sample.
-                    model_pred = model_pred - noise
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
@@ -1163,7 +1193,9 @@
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                 loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                 loss = loss.mean()
+                if args.debug_loss and "filenames" in batch:
+                    for fname in batch["filenames"]:
+                        accelerator.log({"loss_for_" + fname: loss}, step=global_step)
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                 train_loss += avg_loss.item() / args.gradient_accumulation_steps
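The `mse_loss_weights` used just above come from unchanged context that this diff does not show. A sketch of how the diffusers training scripts typically derive them when `--snr_gamma` is set, following the min-SNR paper linked in the argument help (symbols as in the training loop; this is an assumption about the elided context, not part of this commit):

```python
import torch
from diffusers.training_utils import compute_snr

snr = compute_snr(noise_scheduler, timesteps)
# Clamp each sample's weight at snr_gamma (min-SNR weighting).
mse_loss_weights = torch.stack(
    [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1
).min(dim=1)[0]
if noise_scheduler.config.prediction_type == "epsilon":
    mse_loss_weights = mse_loss_weights / snr
elif noise_scheduler.config.prediction_type == "v_prediction":
    mse_loss_weights = mse_loss_weights / (snr + 1)
```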
@@ -1171,16 +1203,13 @@
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = unet.parameters()
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
-                if args.use_ema:
-                    ema_unet.step(unet.parameters())
                 progress_bar.update(1)
                 global_step += 1
                 accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -1221,137 +1250,80 @@

         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-                logger.info(
-                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                    f" {args.validation_prompt}."
-                )
-                if args.use_ema:
-                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
-                    ema_unet.store(unet.parameters())
-                    ema_unet.copy_to(unet.parameters())
-
                 # create pipeline
-                vae = AutoencoderKL.from_pretrained(
-                    vae_path,
-                    subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
-                    revision=args.revision,
-                    variant=args.variant,
-                )
                 pipeline = StableDiffusionXLPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
-                    unet=accelerator.unwrap_model(unet),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
-                if args.prediction_type is not None:
-                    scheduler_args = {"prediction_type": args.prediction_type}
-                    pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
-
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
-
-                # run inference
-                generator = (
-                    torch.Generator(device=accelerator.device).manual_seed(args.seed)
-                    if args.seed is not None
-                    else None
-                )
-                pipeline_args = {"prompt": args.validation_prompt}
-
-                with autocast_ctx:
-                    images = [
-                        pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
-                        for _ in range(args.num_validation_images)
-                    ]
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "validation": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )

                 del pipeline
-                if is_torch_npu_available():
-                    torch_npu.npu.empty_cache()
-                elif torch.cuda.is_available():
-                    torch.cuda.empty_cache()
-
-                if args.use_ema:
-                    # Switch back to the original UNet parameters.
-                    ema_unet.restore(unet.parameters())

     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unwrap_model(unet)
-        if args.use_ema:
-            ema_unet.copy_to(unet.parameters())

-        # Serialize pipeline.
-        vae = AutoencoderKL.from_pretrained(
-            vae_path,
-            subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
-            revision=args.revision,
-            variant=args.variant,
-            torch_dtype=weight_dtype,
         )
         pipeline = StableDiffusionXLPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            unet=unet,
             vae=vae,
             revision=args.revision,
             variant=args.variant,
             torch_dtype=weight_dtype,
         )
-        if args.prediction_type is not None:
-            scheduler_args = {"prediction_type": args.prediction_type}
-            pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
-        pipeline.save_pretrained(args.output_dir)

         # run inference
-        images = []
         if args.validation_prompt and args.num_validation_images > 0:
-            pipeline = pipeline.to(accelerator.device)
-            generator = (
-                torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
-            )
-
-            with autocast_ctx:
-                images = [
-                    pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                    for _ in range(args.num_validation_images)
-                ]
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "test": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )

         if args.push_to_hub:
             save_model_card(
-                repo_id=repo_id,
                 images=images,
-                validation_prompt=args.validation_prompt,
                 base_model=args.pretrained_model_name_or_path,
                 dataset_name=args.dataset_name,
                 repo_folder=args.output_dir,
                 vae_path=args.pretrained_vae_model_name_or_path,
             )
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
+ """Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA."""
17
 
18
  import argparse
 
 
19
  import logging
20
  import math
21
  import os
 
24
  from contextlib import nullcontext
25
  from pathlib import Path
26
 
 
27
  import datasets
28
  import numpy as np
29
  import torch
 
32
  import transformers
33
  from accelerate import Accelerator
34
  from accelerate.logging import get_logger
35
+ from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
36
+ from datasets import load_dataset
37
  from huggingface_hub import create_repo, upload_folder
38
  from packaging import version
39
+ from peft import LoraConfig, set_peft_model_state_dict
40
+ from peft.utils import get_peft_model_state_dict
41
  from torchvision import transforms
42
  from torchvision.transforms.functional import crop
43
  from tqdm.auto import tqdm
44
  from transformers import AutoTokenizer, PretrainedConfig
45
 
46
  import diffusers
47
+ from diffusers import (
48
+ AutoencoderKL,
49
+ DDPMScheduler,
50
+ StableDiffusionXLPipeline,
51
+ UNet2DConditionModel,
52
+ )
53
+ from diffusers.loaders import StableDiffusionLoraLoaderMixin
54
  from diffusers.optimization import get_scheduler
55
+ from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
56
+ from diffusers.utils import (
57
+ check_min_version,
58
+ convert_state_dict_to_diffusers,
59
+ convert_unet_state_dict_to_peft,
60
+ is_wandb_available,
61
+ )
62
  from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
63
  from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
64
  from diffusers.utils.torch_utils import is_compiled_module
65
 
66
 
67
+ if is_wandb_available():
68
+ import wandb
69
+
70
  # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
71
  check_min_version("0.36.0.dev0")
72
 
73
  logger = get_logger(__name__)
74
  if is_torch_npu_available():
 
 
75
  torch.npu.config.allow_internal_format = False
76
 
 
 
 
 
77
 
78
  def save_model_card(
79
  repo_id: str,
80
  images: list = None,
 
81
  base_model: str = None,
82
  dataset_name: str = None,
83
+ train_text_encoder: bool = False,
84
  repo_folder: str = None,
85
  vae_path: str = None,
86
  ):
 
91
  img_str += f"![img_{i}](./image_{i}.png)\n"
92
 
93
  model_description = f"""
94
+ # LoRA text2image fine-tuning - {repo_id}
95
 
96
+ These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
97
  {img_str}
98
 
99
+ LoRA for the text encoder was enabled: {train_text_encoder}.
100
+
101
  Special VAE used for training: {vae_path}.
102
  """
 
103
  model_card = load_or_create_model_card(
104
  repo_id_or_path=repo_id,
105
  from_training=True,
 
113
  "stable-diffusion-xl",
114
  "stable-diffusion-xl-diffusers",
115
  "text-to-image",
 
116
  "diffusers",
117
+ "diffusers-training",
118
+ "lora",
119
  ]
120
  model_card = populate_model_card(model_card, tags=tags)
121
 
122
  model_card.save(os.path.join(repo_folder, "README.md"))
123
 
124
 
125
+ def log_validation(
126
+ pipeline,
127
+ args,
128
+ accelerator,
129
+ epoch,
130
+ is_final_validation=False,
131
+ ):
132
+ logger.info(
133
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
134
+ f" {args.validation_prompt}."
135
+ )
136
+ pipeline = pipeline.to(accelerator.device)
137
+ pipeline.set_progress_bar_config(disable=True)
138
+
139
+ # run inference
140
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
141
+ pipeline_args = {"prompt": args.validation_prompt}
142
+ if torch.backends.mps.is_available():
143
+ autocast_ctx = nullcontext()
144
+ else:
145
+ autocast_ctx = torch.autocast(accelerator.device.type)
146
+
147
+ with autocast_ctx:
148
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
149
+
150
+ for tracker in accelerator.trackers:
151
+ phase_name = "test" if is_final_validation else "validation"
152
+ if tracker.name == "tensorboard":
153
+ np_images = np.stack([np.asarray(img) for img in images])
154
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
155
+ if tracker.name == "wandb":
156
+ tracker.log(
157
+ {
158
+ phase_name: [
159
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
160
+ ]
161
+ }
162
+ )
163
+ return images
164
+
165
+
166
  def import_model_class_from_model_name_or_path(
167
  pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
168
  ):
 
276
  "value if set."
277
  ),
278
  )
 
 
 
 
 
 
279
  parser.add_argument(
280
  "--output_dir",
281
  type=str,
282
+ default="sd-model-finetuned-lora",
283
  help="The output directory where the model predictions and checkpoints will be written.",
284
  )
285
  parser.add_argument(
 
312
  action="store_true",
313
  help="whether to randomly flip images horizontally",
314
  )
315
+ parser.add_argument(
316
+ "--train_text_encoder",
317
+ action="store_true",
318
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
319
+ )
320
  parser.add_argument(
321
  "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
322
  )
 
387
  parser.add_argument(
388
  "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
389
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
  parser.add_argument(
391
  "--snr_gamma",
392
  type=float,
 
394
  help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
395
  "More details here: https://huggingface.co/papers/2303.09556.",
396
  )
 
397
  parser.add_argument(
398
  "--allow_tf32",
399
  action="store_true",
 
463
  )
464
  parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
465
  parser.add_argument(
466
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
467
  )
468
  parser.add_argument(
469
+ "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
470
  )
471
  parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
472
+ parser.add_argument(
473
+ "--rank",
474
+ type=int,
475
+ default=4,
476
+ help=("The dimension of the LoRA update matrices."),
477
+ )
478
+ parser.add_argument(
479
+ "--debug_loss",
480
+ action="store_true",
481
+ help="debug loss for each image, if filenames are available in the dataset",
482
+ )
483
  parser.add_argument(
484
  "--image_interpolation_mode",
485
  type=str,
 
502
  # Sanity checks
503
  if args.dataset_name is None and args.train_data_dir is None:
504
  raise ValueError("Need either a dataset name or a training folder.")
 
 
505
 
506
  return args
507
 
508
 
509
+ DATASET_NAME_MAPPING = {
510
+ "lambdalabs/naruto-blip-captions": ("image", "text"),
511
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
 
 
 
 
513
 
514
+ def tokenize_prompt(tokenizer, prompt):
515
+ text_inputs = tokenizer(
516
+ prompt,
517
+ padding="max_length",
518
+ max_length=tokenizer.model_max_length,
519
+ truncation=True,
520
+ return_tensors="pt",
521
+ )
522
+ text_input_ids = text_inputs.input_ids
523
+ return text_input_ids
524
 
525
 
526
+ # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
527
+ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
528
+ prompt_embeds_list = []
529
 
530
+ for i, text_encoder in enumerate(text_encoders):
531
+ if tokenizers is not None:
532
+ tokenizer = tokenizers[i]
533
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
534
+ else:
535
+ assert text_input_ids_list is not None
536
+ text_input_ids = text_input_ids_list[i]
537
 
538
+ prompt_embeds = text_encoder(
539
+ text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
  )
541
 
542
+ # We are only ALWAYS interested in the pooled output of the final text encoder
543
+ pooled_prompt_embeds = prompt_embeds[0]
544
+ prompt_embeds = prompt_embeds[-1][-2]
545
+ bs_embed, seq_len, _ = prompt_embeds.shape
546
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
547
+ prompt_embeds_list.append(prompt_embeds)
548
 
549
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
550
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
551
+ return prompt_embeds, pooled_prompt_embeds
552
 
553
 
554
  def main(args):
 
560
 
561
  logging_dir = Path(args.output_dir, args.logging_dir)
562
 
 
 
563
  if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
564
  # due to pytorch#99272, MPS does not yet support bfloat16.
565
  raise ValueError(
566
  "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
567
  )
568
 
569
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
570
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
571
  accelerator = Accelerator(
572
  gradient_accumulation_steps=args.gradient_accumulation_steps,
573
  mixed_precision=args.mixed_precision,
574
  log_with=args.report_to,
575
  project_config=accelerator_project_config,
576
+ kwargs_handlers=[kwargs],
577
  )
578
 
 
 
 
 
 
 
 
 
 
579
  # Make one log on every process with the configuration for debugging.
580
  logging.basicConfig(
581
  format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
 
630
 
631
  # Load scheduler and models
632
  noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
 
633
  text_encoder_one = text_encoder_cls_one.from_pretrained(
634
  args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
635
  )
 
651
  args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
652
  )
653
 
654
+ # We only train the additional adapter LoRA layers
655
  vae.requires_grad_(False)
656
  text_encoder_one.requires_grad_(False)
657
  text_encoder_two.requires_grad_(False)
658
+ unet.requires_grad_(False)
 
659
 
660
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
661
  # as these weights are only used for inference, keeping weights in full precision is not required.
662
  weight_dtype = torch.float32
663
  if accelerator.mixed_precision == "fp16":
 
667
 
668
  # Move unet, vae and text_encoder to device and cast to weight_dtype
669
  # The VAE is in float32 to avoid NaN losses.
670
+ unet.to(accelerator.device, dtype=weight_dtype)
671
+
672
+ if args.pretrained_vae_model_name_or_path is None:
673
+ vae.to(accelerator.device, dtype=torch.float32)
674
+ else:
675
+ vae.to(accelerator.device, dtype=weight_dtype)
676
  text_encoder_one.to(accelerator.device, dtype=weight_dtype)
677
  text_encoder_two.to(accelerator.device, dtype=weight_dtype)
678
 
 
 
 
 
 
 
679
  if args.enable_npu_flash_attention:
680
  if is_torch_npu_available():
681
  logger.info("npu flash attention enabled.")
682
  unet.enable_npu_flash_attention()
683
  else:
684
  raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
685
+
686
  if args.enable_xformers_memory_efficient_attention:
687
  if is_xformers_available():
688
  import xformers
 
696
  else:
697
  raise ValueError("xformers is not available. Make sure it is installed correctly")
698
 
699
+ # now we will add new LoRA weights to the attention layers
700
+ # Set correct lora layers
701
+ unet_lora_config = LoraConfig(
702
+ r=args.rank,
703
+ lora_alpha=args.rank,
704
+ init_lora_weights="gaussian",
705
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
706
+ )
707
+
708
+ unet.add_adapter(unet_lora_config)
709
+
710
+ # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
711
+ if args.train_text_encoder:
712
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
713
+ text_lora_config = LoraConfig(
714
+ r=args.rank,
715
+ lora_alpha=args.rank,
716
+ init_lora_weights="gaussian",
717
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
718
+ )
719
+ text_encoder_one.add_adapter(text_lora_config)
720
+ text_encoder_two.add_adapter(text_lora_config)
721
 
722
+ def unwrap_model(model):
723
+ model = accelerator.unwrap_model(model)
724
+ model = model._orig_mod if is_compiled_module(model) else model
725
+ return model
726
 
727
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
728
+ def save_model_hook(models, weights, output_dir):
729
+ if accelerator.is_main_process:
730
+ # there are only two options here. Either are just the unet attn processor layers
731
+ # or there are the unet and text encoder attn layers
732
+ unet_lora_layers_to_save = None
733
+ text_encoder_one_lora_layers_to_save = None
734
+ text_encoder_two_lora_layers_to_save = None
735
+
736
+ for model in models:
737
+ if isinstance(unwrap_model(model), type(unwrap_model(unet))):
738
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
739
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
740
+ text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
741
+ get_peft_model_state_dict(model)
742
+ )
743
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
744
+ text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
745
+ get_peft_model_state_dict(model)
746
+ )
747
+ else:
748
+ raise ValueError(f"unexpected save model: {model.__class__}")
749
 
750
+ # make sure to pop weight so that corresponding model is not saved again
751
+ if weights:
752
+ weights.pop()
 
 
 
753
 
754
+ StableDiffusionXLPipeline.save_lora_weights(
755
+ output_dir,
756
+ unet_lora_layers=unet_lora_layers_to_save,
757
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
758
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
759
+ )
760
+
761
+ def load_model_hook(models, input_dir):
762
+ unet_ = None
763
+ text_encoder_one_ = None
764
+ text_encoder_two_ = None
765
 
766
+ while len(models) > 0:
767
+ model = models.pop()
 
768
 
769
+ if isinstance(model, type(unwrap_model(unet))):
770
+ unet_ = model
771
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
772
+ text_encoder_one_ = model
773
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
774
+ text_encoder_two_ = model
775
+ else:
776
+ raise ValueError(f"unexpected save model: {model.__class__}")
777
+
778
+ lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
779
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
780
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
781
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
782
+ if incompatible_keys is not None:
783
+ # check only for unexpected keys
784
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
785
+ if unexpected_keys:
786
+ logger.warning(
787
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
788
+ f" {unexpected_keys}. "
789
+ )
790
 
791
+ if args.train_text_encoder:
792
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
793
+
794
+ _set_state_dict_into_text_encoder(
795
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
796
+ )
797
+
798
+ # Make sure the trainable params are in float32. This is again needed since the base models
799
+ # are in `weight_dtype`. More details:
800
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
801
+ if args.mixed_precision == "fp16":
802
+ models = [unet_]
803
+ if args.train_text_encoder:
804
+ models.extend([text_encoder_one_, text_encoder_two_])
805
+ cast_training_params(models, dtype=torch.float32)
806
+
807
+ accelerator.register_save_state_pre_hook(save_model_hook)
808
+ accelerator.register_load_state_pre_hook(load_model_hook)
809
 
810
  if args.gradient_checkpointing:
811
  unet.enable_gradient_checkpointing()
812
+ if args.train_text_encoder:
813
+ text_encoder_one.gradient_checkpointing_enable()
814
+ text_encoder_two.gradient_checkpointing_enable()
815
 
816
  # Enable TF32 for faster training on Ampere GPUs,
817
  # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
 
823
  args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
824
  )
825
 
826
+ # Make sure the trainable params are in float32.
827
+ if args.mixed_precision == "fp16":
828
+ models = [unet]
829
+ if args.train_text_encoder:
830
+ models.extend([text_encoder_one, text_encoder_two])
831
+ cast_training_params(models, dtype=torch.float32)
832
+
833
  # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
834
  if args.use_8bit_adam:
835
  try:
 
844
  optimizer_class = torch.optim.AdamW
845
 
846
  # Optimizer creation
847
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
848
+ if args.train_text_encoder:
849
+ params_to_optimize = (
850
+ params_to_optimize
851
+ + list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
852
+ + list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
853
+ )
854
  optimizer = optimizer_class(
855
  params_to_optimize,
856
  lr=args.learning_rate,
 
905
  )
906
 
907
  # Preprocessing the datasets.
908
+ # We need to tokenize input captions and transform the images.
909
+ def tokenize_captions(examples, is_train=True):
910
+ captions = []
911
+ for caption in examples[caption_column]:
912
+ if isinstance(caption, str):
913
+ captions.append(caption)
914
+ elif isinstance(caption, (list, np.ndarray)):
915
+ # take a random caption if there are multiple
916
+ captions.append(random.choice(caption) if is_train else caption[0])
917
+ else:
918
+ raise ValueError(
919
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
920
+ )
921
+ tokens_one = tokenize_prompt(tokenizer_one, captions)
922
+ tokens_two = tokenize_prompt(tokenizer_two, captions)
923
+ return tokens_one, tokens_two
924
+
925
+ # Get the specified interpolation method from the args
926
  interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
927
+
928
+ # Raise an error if the interpolation method is invalid
929
  if interpolation is None:
930
+ raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
931
+ # Preprocessing the datasets.
932
+ train_resize = transforms.Resize(args.resolution, interpolation=interpolation) # Use dynamic interpolation method
933
  train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
934
  train_flip = transforms.RandomHorizontalFlip(p=1.0)
935
+ train_transforms = transforms.Compose(
936
+ [
937
+ transforms.ToTensor(),
938
+ transforms.Normalize([0.5], [0.5]),
939
+ ]
940
+ )
941
 
942
  def preprocess_train(examples):
943
  images = [image.convert("RGB") for image in examples[image_column]]
 
966
  examples["original_sizes"] = original_sizes
967
  examples["crop_top_lefts"] = crop_top_lefts
968
  examples["pixel_values"] = all_images
969
+ tokens_one, tokens_two = tokenize_captions(examples)
970
+ examples["input_ids_one"] = tokens_one
971
+ examples["input_ids_two"] = tokens_two
972
+ if args.debug_loss:
973
+ fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
974
+ if fnames:
975
+ examples["filenames"] = fnames
976
  return examples
977
 
978
  with accelerator.main_process_first():
979
  if args.max_train_samples is not None:
980
  dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
981
  # Set the training transforms
982
+ train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
983
 
984
  def collate_fn(examples):
985
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
986
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
987
  original_sizes = [example["original_sizes"] for example in examples]
988
  crop_top_lefts = [example["crop_top_lefts"] for example in examples]
989
+ input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
990
+ input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
991
+ result = {
992
+ "pixel_values": pixel_values,
993
+ "input_ids_one": input_ids_one,
994
+ "input_ids_two": input_ids_two,
 
995
  "original_sizes": original_sizes,
996
  "crop_top_lefts": crop_top_lefts,
997
  }
998
 
999
+ filenames = [example["filenames"] for example in examples if "filenames" in example]
1000
+ if filenames:
1001
+ result["filenames"] = filenames
1002
+ return result
1003
+
1004
  # DataLoaders creation:
1005
  train_dataloader = torch.utils.data.DataLoader(
1006
+ train_dataset,
1007
  shuffle=True,
1008
  collate_fn=collate_fn,
1009
  batch_size=args.train_batch_size,
 
1025
  )
1026
 
1027
  # Prepare everything with our `accelerator`.
1028
+ if args.train_text_encoder:
1029
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1030
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
1031
+ )
1032
+ else:
1033
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1034
+ unet, optimizer, train_dataloader, lr_scheduler
1035
+ )
1036
 
1037
  # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1038
  num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 
1044
  # We need to initialize the trackers we use, and also store our configuration.
1045
  # The trackers initializes automatically on the main process.
1046
  if accelerator.is_main_process:
1047
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
 
 
 
 
 
 
 
 
 
 
 
1048
 
1049
  # Train!
1050
  total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

  logger.info("***** Running training *****")
+ logger.info(f"  Num examples = {len(train_dataset)}")
  logger.info(f"  Num Epochs = {args.num_train_epochs}")
  logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
  logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")

  )

  for epoch in range(first_epoch, args.num_train_epochs):
+     unet.train()
+     if args.train_text_encoder:
+         text_encoder_one.train()
+         text_encoder_two.train()
      train_loss = 0.0
      for step, batch in enumerate(train_dataloader):
          with accelerator.accumulate(unet):
+             # Convert images to latent space
+             if args.pretrained_vae_model_name_or_path is not None:
+                 pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+             else:
+                 pixel_values = batch["pixel_values"]
+
+             model_input = vae.encode(pixel_values).latent_dist.sample()
+             model_input = model_input * vae.config.scaling_factor
+             if args.pretrained_vae_model_name_or_path is None:
+                 model_input = model_input.to(weight_dtype)
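+             # Without a custom --pretrained_vae_model_name_or_path the VAE stays in fp32 (the stock SDXL VAE is unstable in fp16), so only the resulting latents are cast to weight_dtype.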
+
              # Sample noise that we'll add to the latents
              noise = torch.randn_like(model_input)
              if args.noise_offset:
                  # https://www.crosslabs.org//blog/diffusion-with-offset-noise
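+                 # Offset noise adds a small per-channel constant to the noise so the model also learns very dark and very bright images.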
 
                  )

              bsz = model_input.shape[0]
+             # Sample a random timestep for each image
+             timesteps = torch.randint(
+                 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+             )
+             timesteps = timesteps.long()

              # Add noise to the model input according to the noise magnitude at each timestep
              # (this is the forward diffusion process)
+             noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
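+             # add_noise computes sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise for each sampled timestep t.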

              # time ids
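+             # SDXL micro-conditioning: original size, crop coordinates and target size are packed into six values that the UNet embeds alongside the timestep.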
              def compute_time_ids(original_size, crops_coords_top_left):
                  # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                  target_size = (args.resolution, args.resolution)
                  add_time_ids = list(original_size + crops_coords_top_left + target_size)
+                 add_time_ids = torch.tensor([add_time_ids])
+                 add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
                  return add_time_ids

              add_time_ids = torch.cat(
 
              # Predict the noise residual
              unet_added_conditions = {"time_ids": add_time_ids}
+             prompt_embeds, pooled_prompt_embeds = encode_prompt(
+                 text_encoders=[text_encoder_one, text_encoder_two],
+                 tokenizers=None,
+                 prompt=None,
+                 text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
+             )
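+             # Captions were already tokenized in preprocess_train, so encode_prompt receives the precomputed ids and tokenizers/prompt are None.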
              unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
              model_pred = unet(
                  noisy_model_input,
 
                  target = noise
              elif noise_scheduler.config.prediction_type == "v_prediction":
                  target = noise_scheduler.get_velocity(model_input, noise, timesteps)
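+                 # get_velocity returns sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0, the v-prediction target of Salimans & Ho (2022).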
              else:
                  raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
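+                 # When --snr_gamma is set, mse_loss_weights applies min-SNR weighting: min(SNR(t), snr_gamma) / SNR(t) for epsilon prediction.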
                  loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                  loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                  loss = loss.mean()
+             if args.debug_loss and "filenames" in batch:
+                 for fname in batch["filenames"]:
+                     accelerator.log({"loss_for_" + fname: loss}, step=global_step)
              # Gather the losses across all processes for logging (if we use distributed training).
              avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
              train_loss += avg_loss.item() / args.gradient_accumulation_steps
 
              # Backpropagate
              accelerator.backward(loss)
              if accelerator.sync_gradients:
+                 accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
              optimizer.step()
              lr_scheduler.step()
              optimizer.zero_grad()

          # Checks if the accelerator has performed an optimization step behind the scenes
          if accelerator.sync_gradients:
              progress_bar.update(1)
              global_step += 1
              accelerator.log({"train_loss": train_loss}, step=global_step)
 
      if accelerator.is_main_process:
          if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
              # create pipeline
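+             # unwrap_model strips accelerate's distributed/compile wrappers so the pipeline receives the bare modules with their LoRA adapters attached.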
              pipeline = StableDiffusionXLPipeline.from_pretrained(
                  args.pretrained_model_name_or_path,
                  vae=vae,
+                 text_encoder=unwrap_model(text_encoder_one),
+                 text_encoder_2=unwrap_model(text_encoder_two),
+                 unet=unwrap_model(unet),
                  revision=args.revision,
                  variant=args.variant,
                  torch_dtype=weight_dtype,
              )

+             images = log_validation(pipeline, args, accelerator, epoch)

              del pipeline
+             torch.cuda.empty_cache()

+ # Save the LoRA layers
  accelerator.wait_for_everyone()
  if accelerator.is_main_process:
      unet = unwrap_model(unet)
+     unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
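+     # get_peft_model_state_dict extracts only the LoRA adapter weights; convert_state_dict_to_diffusers renames the PEFT keys to the diffusers layout.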
 
+     if args.train_text_encoder:
+         text_encoder_one = unwrap_model(text_encoder_one)
+         text_encoder_two = unwrap_model(text_encoder_two)
+
+         text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
+         text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
+     else:
+         text_encoder_lora_layers = None
+         text_encoder_2_lora_layers = None
+
+     StableDiffusionXLPipeline.save_lora_weights(
+         save_directory=args.output_dir,
+         unet_lora_layers=unet_lora_state_dict,
+         text_encoder_lora_layers=text_encoder_lora_layers,
+         text_encoder_2_lora_layers=text_encoder_2_lora_layers,
      )
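+     # save_lora_weights writes pytorch_lora_weights.safetensors into --output_dir, reloadable via pipeline.load_lora_weights().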
+
+     del unet
+     del text_encoder_one
+     del text_encoder_two
+     del text_encoder_lora_layers
+     del text_encoder_2_lora_layers
+     torch.cuda.empty_cache()
+
+     # Final inference
+     # Make sure vae.dtype is consistent with the unet.dtype
+     if args.mixed_precision == "fp16":
+         vae.to(weight_dtype)
+     # Load previous pipeline
      pipeline = StableDiffusionXLPipeline.from_pretrained(
          args.pretrained_model_name_or_path,
          vae=vae,
          revision=args.revision,
          variant=args.variant,
          torch_dtype=weight_dtype,
      )
+
+     # Load the trained LoRA weights
+     pipeline.load_lora_weights(args.output_dir)
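+     # load_lora_weights also accepts a Hub repo id, so the same call works against the pushed repo after training.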
 
      # run inference
      if args.validation_prompt and args.num_validation_images > 0:
+         images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)
 
      if args.push_to_hub:
          save_model_card(
+             repo_id,
              images=images,
              base_model=args.pretrained_model_name_or_path,
              dataset_name=args.dataset_name,
+             train_text_encoder=args.train_text_encoder,
              repo_folder=args.output_dir,
              vae_path=args.pretrained_vae_model_name_or_path,
          )