End of training
Browse files- README.md +6 -0
- checkpoint-100/optimizer.bin +2 -2
- checkpoint-100/pytorch_lora_weights.safetensors +2 -2
- checkpoint-200/optimizer.bin +2 -2
- checkpoint-200/pytorch_lora_weights.safetensors +2 -2
- checkpoint-300/optimizer.bin +2 -2
- checkpoint-300/pytorch_lora_weights.safetensors +2 -2
- checkpoint-400/optimizer.bin +2 -2
- checkpoint-400/pytorch_lora_weights.safetensors +2 -2
- checkpoint-500/optimizer.bin +2 -2
- checkpoint-500/pytorch_lora_weights.safetensors +2 -2
- examples/dreambooth/train_dreambooth_lora.py +21 -3
- image_0.png +0 -0
- image_1.png +0 -0
- image_2.png +0 -0
- image_3.png +0 -0
- pytorch_lora_weights.safetensors +1 -1
README.md
CHANGED
|
@@ -8,6 +8,12 @@ tags:
|
|
| 8 |
- diffusers-training
|
| 9 |
- stable-diffusion
|
| 10 |
- stable-diffusion-diffusers
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
base_model: runwayml/stable-diffusion-v1-5
|
| 12 |
inference: true
|
| 13 |
instance_prompt: a photo of sks dog
|
|
|
|
| 8 |
- diffusers-training
|
| 9 |
- stable-diffusion
|
| 10 |
- stable-diffusion-diffusers
|
| 11 |
+
- text-to-image
|
| 12 |
+
- diffusers
|
| 13 |
+
- lora
|
| 14 |
+
- diffusers-training
|
| 15 |
+
- stable-diffusion
|
| 16 |
+
- stable-diffusion-diffusers
|
| 17 |
base_model: runwayml/stable-diffusion-v1-5
|
| 18 |
inference: true
|
| 19 |
instance_prompt: a photo of sks dog
|
checkpoint-100/optimizer.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3fef76dd1515c660279fab0218004aa69f28503f51d852ae91cca05e803d2c9c
|
| 3 |
+
size 6584954
|
checkpoint-100/pytorch_lora_weights.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:38a118bf5f1a6b3e2bf9961d6b0f877d4115480724b8385cc7d663a476bfa2a5
|
| 3 |
+
size 3226184
|
checkpoint-200/optimizer.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1deb8bd6df4ae7fcf08371f1db98545fbf342326ddf361ccf9159284eaae8676
|
| 3 |
+
size 6584954
|
checkpoint-200/pytorch_lora_weights.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58259087cf6305a421542c60428fba4d5a9303d76e728320b80903115f9bf3fc
|
| 3 |
+
size 3226184
|
checkpoint-300/optimizer.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9254592e65ed4498ffcd54ca7577d5a3c88ea0d3f8f238b8e932fbdf4e0d60e7
|
| 3 |
+
size 6584954
|
checkpoint-300/pytorch_lora_weights.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:683502e4639c5a2d8dd42312c31af8083c312a4f696ff41fed1bd7b005170f96
|
| 3 |
+
size 3226184
|
checkpoint-400/optimizer.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:059700890bc1c2a7db1642ed88058c52f9ad676ae797bf3e6034ceb529d90b76
|
| 3 |
+
size 6584954
|
checkpoint-400/pytorch_lora_weights.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8e8407fe57dc47e869f0909b843591c3b7343e91caa866acca3cf8321d0aea28
|
| 3 |
+
size 3226184
|
checkpoint-500/optimizer.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cfb252d5c6851d8f59de49f3d0f65dd422fe4dcf9bccf6fe95a6cb779e830aa6
|
| 3 |
+
size 6584954
|
checkpoint-500/pytorch_lora_weights.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:962469684f6d9dd5033894b8be075ca09ef8a6c439da9bb85c8316c9e53e8166
|
| 3 |
+
size 3226184
|
examples/dreambooth/train_dreambooth_lora.py
CHANGED
|
@@ -142,6 +142,7 @@ def log_validation(
|
|
| 142 |
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
| 143 |
|
| 144 |
pipeline = pipeline.to(accelerator.device)
|
|
|
|
| 145 |
pipeline.set_progress_bar_config(disable=True)
|
| 146 |
|
| 147 |
# run inference
|
|
@@ -151,7 +152,9 @@ def log_validation(
|
|
| 151 |
images = []
|
| 152 |
for _ in range(args.num_validation_images):
|
| 153 |
with torch.cuda.amp.autocast():
|
|
|
|
| 154 |
image = pipeline(**pipeline_args, generator=generator).images[0]
|
|
|
|
| 155 |
images.append(image)
|
| 156 |
else:
|
| 157 |
images = []
|
|
@@ -748,7 +751,7 @@ def main(args):
|
|
| 748 |
log_with=args.report_to,
|
| 749 |
project_config=accelerator_project_config,
|
| 750 |
)
|
| 751 |
-
|
| 752 |
# Disable AMP for MPS.
|
| 753 |
if torch.backends.mps.is_available():
|
| 754 |
accelerator.native_amp = False
|
|
@@ -784,8 +787,10 @@ def main(args):
|
|
| 784 |
if args.seed is not None:
|
| 785 |
set_seed(args.seed)
|
| 786 |
|
|
|
|
| 787 |
# Generate class images if prior preservation is enabled.
|
| 788 |
if args.with_prior_preservation:
|
|
|
|
| 789 |
class_images_dir = Path(args.class_data_dir)
|
| 790 |
if not class_images_dir.exists():
|
| 791 |
class_images_dir.mkdir(parents=True)
|
|
@@ -815,6 +820,8 @@ def main(args):
|
|
| 815 |
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
|
| 816 |
|
| 817 |
sample_dataloader = accelerator.prepare(sample_dataloader)
|
|
|
|
|
|
|
| 818 |
pipeline.to(accelerator.device)
|
| 819 |
|
| 820 |
for example in tqdm(
|
|
@@ -882,11 +889,14 @@ def main(args):
|
|
| 882 |
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
| 883 |
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 884 |
weight_dtype = torch.float32
|
|
|
|
|
|
|
|
|
|
| 885 |
if accelerator.mixed_precision == "fp16":
|
| 886 |
weight_dtype = torch.float16
|
| 887 |
elif accelerator.mixed_precision == "bf16":
|
| 888 |
weight_dtype = torch.bfloat16
|
| 889 |
-
|
| 890 |
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
| 891 |
unet.to(accelerator.device, dtype=weight_dtype)
|
| 892 |
if vae is not None:
|
|
@@ -1091,6 +1101,7 @@ def main(args):
|
|
| 1091 |
validation_prompt_negative_prompt_embeds = None
|
| 1092 |
pre_computed_class_prompt_encoder_hidden_states = None
|
| 1093 |
|
|
|
|
| 1094 |
# Dataset and DataLoaders creation:
|
| 1095 |
train_dataset = DreamBoothDataset(
|
| 1096 |
instance_data_root=args.instance_data_dir,
|
|
@@ -1204,7 +1215,9 @@ def main(args):
|
|
| 1204 |
)
|
| 1205 |
|
| 1206 |
for epoch in range(first_epoch, args.num_train_epochs):
|
|
|
|
| 1207 |
unet.train()
|
|
|
|
| 1208 |
if args.train_text_encoder:
|
| 1209 |
text_encoder.train()
|
| 1210 |
for step, batch in enumerate(train_dataloader):
|
|
@@ -1335,6 +1348,7 @@ def main(args):
|
|
| 1335 |
break
|
| 1336 |
|
| 1337 |
if accelerator.is_main_process:
|
|
|
|
| 1338 |
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
|
| 1339 |
# create pipeline
|
| 1340 |
pipeline = DiffusionPipeline.from_pretrained(
|
|
@@ -1345,6 +1359,8 @@ def main(args):
|
|
| 1345 |
variant=args.variant,
|
| 1346 |
torch_dtype=weight_dtype,
|
| 1347 |
)
|
|
|
|
|
|
|
| 1348 |
|
| 1349 |
if args.pre_compute_text_embeddings:
|
| 1350 |
pipeline_args = {
|
|
@@ -1353,7 +1369,7 @@ def main(args):
|
|
| 1353 |
}
|
| 1354 |
else:
|
| 1355 |
pipeline_args = {"prompt": args.validation_prompt}
|
| 1356 |
-
|
| 1357 |
images = log_validation(
|
| 1358 |
pipeline,
|
| 1359 |
args,
|
|
@@ -1391,6 +1407,8 @@ def main(args):
|
|
| 1391 |
# load attention processors
|
| 1392 |
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
|
| 1393 |
|
|
|
|
|
|
|
| 1394 |
# run inference
|
| 1395 |
images = []
|
| 1396 |
if args.validation_prompt and args.num_validation_images > 0:
|
|
|
|
| 142 |
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
| 143 |
|
| 144 |
pipeline = pipeline.to(accelerator.device)
|
| 145 |
+
pipeline.enable_attention_slicing()
|
| 146 |
pipeline.set_progress_bar_config(disable=True)
|
| 147 |
|
| 148 |
# run inference
|
|
|
|
| 152 |
images = []
|
| 153 |
for _ in range(args.num_validation_images):
|
| 154 |
with torch.cuda.amp.autocast():
|
| 155 |
+
print("Gen Image")
|
| 156 |
image = pipeline(**pipeline_args, generator=generator).images[0]
|
| 157 |
+
|
| 158 |
images.append(image)
|
| 159 |
else:
|
| 160 |
images = []
|
|
|
|
| 751 |
log_with=args.report_to,
|
| 752 |
project_config=accelerator_project_config,
|
| 753 |
)
|
| 754 |
+
print("Accelerator setup")
|
| 755 |
# Disable AMP for MPS.
|
| 756 |
if torch.backends.mps.is_available():
|
| 757 |
accelerator.native_amp = False
|
|
|
|
| 787 |
if args.seed is not None:
|
| 788 |
set_seed(args.seed)
|
| 789 |
|
| 790 |
+
print("Before prior preservation")
|
| 791 |
# Generate class images if prior preservation is enabled.
|
| 792 |
if args.with_prior_preservation:
|
| 793 |
+
print("In prior preservation")
|
| 794 |
class_images_dir = Path(args.class_data_dir)
|
| 795 |
if not class_images_dir.exists():
|
| 796 |
class_images_dir.mkdir(parents=True)
|
|
|
|
| 820 |
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
|
| 821 |
|
| 822 |
sample_dataloader = accelerator.prepare(sample_dataloader)
|
| 823 |
+
print("printing accelerator defined device")
|
| 824 |
+
print(accelerator.device)
|
| 825 |
pipeline.to(accelerator.device)
|
| 826 |
|
| 827 |
for example in tqdm(
|
|
|
|
| 889 |
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
| 890 |
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 891 |
weight_dtype = torch.float32
|
| 892 |
+
print("What is accelerator mixed precision")
|
| 893 |
+
print(accelerator.mixed_precision)
|
| 894 |
+
print(accelerator.device)
|
| 895 |
if accelerator.mixed_precision == "fp16":
|
| 896 |
weight_dtype = torch.float16
|
| 897 |
elif accelerator.mixed_precision == "bf16":
|
| 898 |
weight_dtype = torch.bfloat16
|
| 899 |
+
print(weight_dtype)
|
| 900 |
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
| 901 |
unet.to(accelerator.device, dtype=weight_dtype)
|
| 902 |
if vae is not None:
|
|
|
|
| 1101 |
validation_prompt_negative_prompt_embeds = None
|
| 1102 |
pre_computed_class_prompt_encoder_hidden_states = None
|
| 1103 |
|
| 1104 |
+
print("Getting dataset")
|
| 1105 |
# Dataset and DataLoaders creation:
|
| 1106 |
train_dataset = DreamBoothDataset(
|
| 1107 |
instance_data_root=args.instance_data_dir,
|
|
|
|
| 1215 |
)
|
| 1216 |
|
| 1217 |
for epoch in range(first_epoch, args.num_train_epochs):
|
| 1218 |
+
print("Unet train start")
|
| 1219 |
unet.train()
|
| 1220 |
+
print("Unet train done")
|
| 1221 |
if args.train_text_encoder:
|
| 1222 |
text_encoder.train()
|
| 1223 |
for step, batch in enumerate(train_dataloader):
|
|
|
|
| 1348 |
break
|
| 1349 |
|
| 1350 |
if accelerator.is_main_process:
|
| 1351 |
+
print("Accelerator is main process")
|
| 1352 |
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
|
| 1353 |
# create pipeline
|
| 1354 |
pipeline = DiffusionPipeline.from_pretrained(
|
|
|
|
| 1359 |
variant=args.variant,
|
| 1360 |
torch_dtype=weight_dtype,
|
| 1361 |
)
|
| 1362 |
+
pipeline = pipeline.to(accelerator.device)
|
| 1363 |
+
pipeline.enable_attention_slicing()
|
| 1364 |
|
| 1365 |
if args.pre_compute_text_embeddings:
|
| 1366 |
pipeline_args = {
|
|
|
|
| 1369 |
}
|
| 1370 |
else:
|
| 1371 |
pipeline_args = {"prompt": args.validation_prompt}
|
| 1372 |
+
print("Going for images")
|
| 1373 |
images = log_validation(
|
| 1374 |
pipeline,
|
| 1375 |
args,
|
|
|
|
| 1407 |
# load attention processors
|
| 1408 |
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
|
| 1409 |
|
| 1410 |
+
pipeline = pipeline.to(accelerator.device)
|
| 1411 |
+
pipeline.enable_attention_slicing()
|
| 1412 |
# run inference
|
| 1413 |
images = []
|
| 1414 |
if args.validation_prompt and args.num_validation_images > 0:
|
image_0.png
CHANGED
|
|
image_1.png
CHANGED
|
|
image_2.png
CHANGED
|
|
image_3.png
CHANGED
|
|
pytorch_lora_weights.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 3226184
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:962469684f6d9dd5033894b8be075ca09ef8a6c439da9bb85c8316c9e53e8166
|
| 3 |
size 3226184
|