back to original
Browse files

pipeline.py  +108 −112

pipeline.py  CHANGED
@@ -625,89 +625,89 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
 
         return timesteps, num_inference_steps - t_start
 
-    # def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
-    #     if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
-    #         raise ValueError(
-    #             f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
-    #         )
-
-    #     image = image.to(device=device, dtype=dtype)
-
-    #     batch_size = batch_size * num_images_per_prompt
-    #     if isinstance(generator, list) and len(generator) != batch_size:
-    #         raise ValueError(
-    #             f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-    #             f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-    #         )
-
-    #     if isinstance(generator, list):
-    #         init_latents = [
-    #             self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-    #         ]
-    #         init_latents = torch.cat(init_latents, dim=0)
-    #     else:
-    #         init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-    #     init_latents = self.vae.config.scaling_factor * init_latents
-
-    #     if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
-    #         raise ValueError(
-    #             f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
-    #         )
-    #     else:
-    #         init_latents = torch.cat([init_latents], dim=0)
-
-    #     shape = init_latents.shape
-    #     noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-
-    #     # get latents
-    #     init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
-    #     latents = init_latents
-
-    #     return latents
-
     def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
             raise ValueError(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
-
-        if isinstance(image, list):
-            image_tensors = []
-            for img in image:
-                img_tensor = prepare_image(img)
-                img_tensor = img_tensor.to(device=device, dtype=dtype)
-                image_tensors.append(img_tensor)
-            image = torch.stack(image_tensors, dim=0)
-        else:
-            image = prepare_image(image)
-            image = image.to(device=device, dtype=dtype)
+
+        image = image.to(device=device, dtype=dtype)
 
         batch_size = batch_size * num_images_per_prompt
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
         if isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(image.shape[0])
+                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents
 
+        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            init_latents = torch.cat([init_latents], dim=0)
+
         shape = init_latents.shape
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
         # get latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
         latents = init_latents
 
         return latents
 
+    # def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+    #     if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+    #         raise ValueError(
+    #             f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+    #         )
+
+    #     if isinstance(image, list):
+    #         image_tensors = []
+    #         for img in image:
+    #             img_tensor = prepare_image(img)
+    #             img_tensor = img_tensor.to(device=device, dtype=dtype)
+    #             image_tensors.append(img_tensor)
+    #         image = torch.stack(image_tensors, dim=0)
+    #     else:
+    #         image = prepare_image(image)
+    #         image = image.to(device=device, dtype=dtype)
+
+    #     batch_size = batch_size * num_images_per_prompt
+    #     if isinstance(generator, list) and len(generator) != batch_size:
+    #         raise ValueError(
+    #             f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+    #             f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+    #         )
+
+    #     if isinstance(generator, list):
+    #         init_latents = [
+    #             self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(image.shape[0])
+    #         ]
+    #         init_latents = torch.cat(init_latents, dim=0)
+    #     else:
+    #         init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+    #     init_latents = self.vae.config.scaling_factor * init_latents
+
+    #     shape = init_latents.shape
+    #     noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+    #     # get latents
+    #     init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+    #     latents = init_latents
+
+    #     return latents
 
     def _default_height_width(self, height, width, image):
         if isinstance(image, list):
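The restored `prepare_latents` is the standard img2img initialization: VAE-encode the init image, scale by the VAE's `scaling_factor`, then noise the result to the starting timestep. A minimal standalone sketch of that flow, assuming a diffusers `AutoencoderKL`-style `vae` and a scheduler exposing `add_noise` (the helper name `img2img_init_latents` is ours, not the pipeline's):

```python
import torch

def img2img_init_latents(vae, scheduler, image, timestep, generator=None):
    # `image` is a float tensor in [-1, 1] of shape (batch, 3, H, W),
    # already on the vae's device and dtype
    init_latents = vae.encode(image).latent_dist.sample(generator)
    init_latents = vae.config.scaling_factor * init_latents
    # noise the clean latents up to the timestep selected via `strength`
    noise = torch.randn(
        init_latents.shape, generator=generator,
        device=init_latents.device, dtype=init_latents.dtype,
    )
    return scheduler.add_noise(init_latents, noise, timestep)
```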
@@ -940,27 +940,27 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
-
-        # latents = self.prepare_latents(
-        #     image,
-        #     latent_timestep,
-        #     batch_size,
-        #     num_images_per_prompt,
-        #     prompt_embeds.dtype,
-        #     device,
-        #     generator,
-        # )
-
-        latents = [self.prepare_latents(
-            img,
+        # 6. Prepare latent variables
+        latents = self.prepare_latents(
+            image,
             latent_timestep,
             batch_size,
             num_images_per_prompt,
             prompt_embeds.dtype,
             device,
             generator,
-        ) for img in images]
-        latents = torch.cat(latents)
+        )
+
+        # latents = [self.prepare_latents(
+        #     img,
+        #     latent_timestep,
+        #     batch_size,
+        #     num_images_per_prompt,
+        #     prompt_embeds.dtype,
+        #     device,
+        #     generator,
+        # ) for img in images]
+        # latents = torch.cat(latents)
 
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
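`latent_timestep` above comes from `get_timesteps`, whose tail (`return timesteps, num_inference_steps - t_start`) is visible at the top of this diff. The body is not shown here, but in stock diffusers img2img pipelines the mapping from `strength` to the starting step looks like this (a sketch for orientation, not necessarily this file's exact code):

```python
def get_timesteps(scheduler, num_inference_steps, strength):
    # strength=1.0 keeps the full schedule (pure generation);
    # lower strength skips the earliest, noisiest steps so more of
    # the init image survives
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler.timesteps[t_start * scheduler.order :]
    return timesteps, num_inference_steps - t_start
```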
@@ -980,24 +980,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 # compute the percentage of total steps we are at
                 current_sampling_percent = i / len(timesteps)
 
-                # if (
-                #     current_sampling_percent < controlnet_guidance_start
-                #     or current_sampling_percent > controlnet_guidance_end
-                # ):
-                #     # do not apply the controlnet
-                #     down_block_res_samples = None
-                #     mid_block_res_sample = None
-                # else:
-                #     # apply the controlnet
-                #     down_block_res_samples, mid_block_res_sample = self.controlnet(
-                #         latent_model_input,
-                #         t,
-                #         encoder_hidden_states=prompt_embeds,
-                #         controlnet_cond=controlnet_conditioning_image,
-                #         conditioning_scale=controlnet_conditioning_scale,
-                #         return_dict=False,
-                #     )
-
                 if (
                     current_sampling_percent < controlnet_guidance_start
                     or current_sampling_percent > controlnet_guidance_end
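The restored branch gates the ControlNet on sampling progress: residuals are only computed while `current_sampling_percent` lies inside `[controlnet_guidance_start, controlnet_guidance_end]`; outside the window both residuals are `None` and the UNet runs without ControlNet conditioning. The test in isolation (a sketch; the predicate name is ours):

```python
def controlnet_is_active(step_index, num_timesteps, guidance_start, guidance_end):
    # fraction of the sampling loop completed, 0.0 at the first step
    current_sampling_percent = step_index / num_timesteps
    return guidance_start <= current_sampling_percent <= guidance_end
```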
@@ -1006,28 +988,42 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     down_block_res_samples = None
                     mid_block_res_sample = None
                 else:
-                    down_block_res_samples = []
-                    mid_block_res_samples = []
-                    for i in range(batch_size):
-                        # apply the controlnet
-                        down_block_res_sample, mid_block_res_sample = self.controlnet(
-                            latent_model_input[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
-                            t,
-                            encoder_hidden_states=prompt_embeds[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
-                            controlnet_cond=controlnet_conditioning_image[i],
-                            conditioning_scale=controlnet_conditioning_scale,
-                            return_dict=False,
-                        )
-
-                        down_block_res_samples.append(down_block_res_sample)
-                        mid_block_res_samples.append(mid_block_res_sample)
-
-                    down_block_res_samples = tuple(down_block_res_samples)
-                    mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)
-
-                    # down_block_res_samples = torch.cat(down_block_res_samples, dim=0)
-                    # mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)
+                    # apply the controlnet
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        controlnet_cond=controlnet_conditioning_image,
+                        conditioning_scale=controlnet_conditioning_scale,
+                        return_dict=False,
+                    )
+
+                # if (
+                #     current_sampling_percent < controlnet_guidance_start
+                #     or current_sampling_percent > controlnet_guidance_end
+                # ):
+                #     # do not apply the controlnet
+                #     down_block_res_samples = None
+                #     mid_block_res_sample = None
+                # else:
+                #     down_block_res_samples = []
+                #     mid_block_res_samples = []
+                #     for i in range(batch_size):
+                #         # apply the controlnet
+                #         down_block_res_sample, mid_block_res_sample = self.controlnet(
+                #             latent_model_input[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
+                #             t,
+                #             encoder_hidden_states=prompt_embeds[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
+                #             controlnet_cond=controlnet_conditioning_image[i],
+                #             conditioning_scale=controlnet_conditioning_scale,
+                #             return_dict=False,
+                #         )
+
+                #         down_block_res_samples.append(down_block_res_sample)
+                #         mid_block_res_samples.append(mid_block_res_sample)
+
+                #     down_block_res_samples = tuple(down_block_res_samples)
+                #     mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)
 
                 # predict the noise residual
                 noise_pred = self.unet(