Try to fix ControlNet batch processing
Browse files- pipeline.py +36 -9
pipeline.py
CHANGED
|
@@ -980,6 +980,24 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
|
|
| 980 |
# compute the percentage of total steps we are at
|
| 981 |
current_sampling_percent = i / len(timesteps)
|
| 982 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 983 |
if (
|
| 984 |
current_sampling_percent < controlnet_guidance_start
|
| 985 |
or current_sampling_percent > controlnet_guidance_end
|
|
@@ -988,15 +1006,24 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
|
|
| 988 |
down_block_res_samples = None
|
| 989 |
mid_block_res_sample = None
|
| 990 |
else:
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1000 |
|
| 1001 |
# predict the noise residual
|
| 1002 |
noise_pred = self.unet(
|
|
|
|
| 980 |
# compute the percentage of total steps we are at
|
| 981 |
current_sampling_percent = i / len(timesteps)
|
| 982 |
|
| 983 |
+
# if (
|
| 984 |
+
# current_sampling_percent < controlnet_guidance_start
|
| 985 |
+
# or current_sampling_percent > controlnet_guidance_end
|
| 986 |
+
# ):
|
| 987 |
+
# # do not apply the controlnet
|
| 988 |
+
# down_block_res_samples = None
|
| 989 |
+
# mid_block_res_sample = None
|
| 990 |
+
# else:
|
| 991 |
+
# # apply the controlnet
|
| 992 |
+
# down_block_res_samples, mid_block_res_sample = self.controlnet(
|
| 993 |
+
# latent_model_input,
|
| 994 |
+
# t,
|
| 995 |
+
# encoder_hidden_states=prompt_embeds,
|
| 996 |
+
# controlnet_cond=controlnet_conditioning_image,
|
| 997 |
+
# conditioning_scale=controlnet_conditioning_scale,
|
| 998 |
+
# return_dict=False,
|
| 999 |
+
# )
|
| 1000 |
+
|
| 1001 |
if (
|
| 1002 |
current_sampling_percent < controlnet_guidance_start
|
| 1003 |
or current_sampling_percent > controlnet_guidance_end
|
|
|
|
| 1006 |
down_block_res_samples = None
|
| 1007 |
mid_block_res_sample = None
|
| 1008 |
else:
|
| 1009 |
+
down_block_res_samples = []
|
| 1010 |
+
mid_block_res_samples = []
|
| 1011 |
+
for i in range(batch_size):
|
| 1012 |
+
# apply the controlnet
|
| 1013 |
+
down_block_res_sample, mid_block_res_sample = self.controlnet(
|
| 1014 |
+
latent_model_input[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
|
| 1015 |
+
t,
|
| 1016 |
+
encoder_hidden_states=prompt_embeds,
|
| 1017 |
+
controlnet_cond=controlnet_conditioning_image[i],
|
| 1018 |
+
conditioning_scale=controlnet_conditioning_scale,
|
| 1019 |
+
return_dict=False,
|
| 1020 |
+
)
|
| 1021 |
+
down_block_res_samples.append(down_block_res_sample)
|
| 1022 |
+
mid_block_res_samples.append(mid_block_res_sample)
|
| 1023 |
+
|
| 1024 |
+
down_block_res_samples = torch.cat(down_block_res_samples, dim=0)
|
| 1025 |
+
mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)
|
| 1026 |
+
|
| 1027 |
|
| 1028 |
# predict the noise residual
|
| 1029 |
noise_pred = self.unet(
|