Upload pipeline.py
Browse files- pipeline.py +65 -38
pipeline.py
CHANGED
|
@@ -20,50 +20,30 @@ import numpy as np
|
|
| 20 |
import PIL.Image
|
| 21 |
import torch
|
| 22 |
import torch.nn.functional as F
|
| 23 |
-
from transformers import (
|
| 24 |
-
|
| 25 |
-
CLIPTextModel,
|
| 26 |
-
CLIPTokenizer,
|
| 27 |
-
CLIPVisionModelWithProjection,
|
| 28 |
-
)
|
| 29 |
|
| 30 |
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 31 |
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 32 |
-
from diffusers.loaders import (
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
)
|
| 38 |
-
from diffusers.models import (
|
| 39 |
-
AutoencoderKL,
|
| 40 |
-
ControlNetModel,
|
| 41 |
-
ImageProjection,
|
| 42 |
-
UNet2DConditionModel,
|
| 43 |
-
)
|
| 44 |
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
| 45 |
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
| 46 |
-
from diffusers.pipelines.pipeline_utils import DiffusionPipeline,
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
from diffusers.pipelines.stable_diffusion.safety_checker import
|
| 51 |
-
StableDiffusionSafetyChecker
|
| 52 |
-
)
|
| 53 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 54 |
-
from diffusers.utils import (
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
scale_lora_layers,
|
| 60 |
-
unscale_lora_layers,
|
| 61 |
-
)
|
| 62 |
-
from diffusers.utils.torch_utils import (
|
| 63 |
-
is_compiled_module,
|
| 64 |
-
is_torch_version,
|
| 65 |
-
randn_tensor,
|
| 66 |
-
)
|
| 67 |
|
| 68 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 69 |
|
|
@@ -691,6 +671,7 @@ class StableDiffusionControlNetPipeline(
|
|
| 691 |
control_guidance_start=0.0,
|
| 692 |
control_guidance_end=1.0,
|
| 693 |
callback_on_step_end_tensor_inputs=None,
|
|
|
|
| 694 |
):
|
| 695 |
if callback_steps is not None and (
|
| 696 |
not isinstance(callback_steps, int) or callback_steps <= 0
|
|
@@ -853,6 +834,9 @@ class StableDiffusionControlNetPipeline(
|
|
| 853 |
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
| 854 |
)
|
| 855 |
|
|
|
|
|
|
|
|
|
|
| 856 |
def check_image(self, image, prompt, prompt_embeds):
|
| 857 |
image_is_pil = isinstance(image, PIL.Image.Image)
|
| 858 |
image_is_tensor = isinstance(image, torch.Tensor)
|
|
@@ -894,6 +878,16 @@ class StableDiffusionControlNetPipeline(
|
|
| 894 |
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
| 895 |
)
|
| 896 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 897 |
def prepare_image(
|
| 898 |
self,
|
| 899 |
image,
|
|
@@ -995,6 +989,20 @@ class StableDiffusionControlNetPipeline(
|
|
| 995 |
assert emb.shape == (w.shape[0], embedding_dim)
|
| 996 |
return emb
|
| 997 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 998 |
@property
|
| 999 |
def guidance_scale(self):
|
| 1000 |
return self._guidance_scale
|
|
@@ -1173,6 +1181,8 @@ class StableDiffusionControlNetPipeline(
|
|
| 1173 |
callback = kwargs.pop("callback", None)
|
| 1174 |
callback_steps = kwargs.pop("callback_steps", None)
|
| 1175 |
|
|
|
|
|
|
|
| 1176 |
if callback is not None:
|
| 1177 |
deprecate(
|
| 1178 |
"callback",
|
|
@@ -1233,6 +1243,7 @@ class StableDiffusionControlNetPipeline(
|
|
| 1233 |
control_guidance_start,
|
| 1234 |
control_guidance_end,
|
| 1235 |
callback_on_step_end_tensor_inputs,
|
|
|
|
| 1236 |
)
|
| 1237 |
|
| 1238 |
self._guidance_scale = guidance_scale
|
|
@@ -1439,6 +1450,7 @@ class StableDiffusionControlNetPipeline(
|
|
| 1439 |
controlnet_cond_scale = controlnet_cond_scale[0]
|
| 1440 |
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
| 1441 |
|
|
|
|
| 1442 |
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
| 1443 |
control_model_input,
|
| 1444 |
t,
|
|
@@ -1449,6 +1461,21 @@ class StableDiffusionControlNetPipeline(
|
|
| 1449 |
return_dict=False,
|
| 1450 |
)
|
| 1451 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1452 |
if guess_mode and self.do_classifier_free_guidance:
|
| 1453 |
# Inferred ControlNet only for the conditional batch.
|
| 1454 |
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
|
|
|
| 20 |
import PIL.Image
|
| 21 |
import torch
|
| 22 |
import torch.nn.functional as F
|
| 23 |
+
from transformers import (CLIPImageProcessor, CLIPTextModel, CLIPTokenizer,
|
| 24 |
+
CLIPVisionModelWithProjection)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 27 |
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 28 |
+
from diffusers.loaders import (FromSingleFileMixin, IPAdapterMixin,
|
| 29 |
+
StableDiffusionLoraLoaderMixin,
|
| 30 |
+
TextualInversionLoaderMixin)
|
| 31 |
+
from diffusers.models import (AutoencoderKL, ControlNetModel, ImageProjection,
|
| 32 |
+
UNet2DConditionModel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
| 34 |
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
| 35 |
+
from diffusers.pipelines.pipeline_utils import (DiffusionPipeline,
|
| 36 |
+
StableDiffusionMixin)
|
| 37 |
+
from diffusers.pipelines.stable_diffusion.pipeline_output import \
|
| 38 |
+
StableDiffusionPipelineOutput
|
| 39 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
| 40 |
+
StableDiffusionSafetyChecker
|
|
|
|
| 41 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 42 |
+
from diffusers.utils import (USE_PEFT_BACKEND, deprecate, logging,
|
| 43 |
+
replace_example_docstring, scale_lora_layers,
|
| 44 |
+
unscale_lora_layers)
|
| 45 |
+
from diffusers.utils.torch_utils import (is_compiled_module, is_torch_version,
|
| 46 |
+
randn_tensor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 49 |
|
|
|
|
| 671 |
control_guidance_start=0.0,
|
| 672 |
control_guidance_end=1.0,
|
| 673 |
callback_on_step_end_tensor_inputs=None,
|
| 674 |
+
effective_region_mask=None,
|
| 675 |
):
|
| 676 |
if callback_steps is not None and (
|
| 677 |
not isinstance(callback_steps, int) or callback_steps <= 0
|
|
|
|
| 834 |
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
|
| 835 |
)
|
| 836 |
|
| 837 |
+
if effective_region_mask is not None:
|
| 838 |
+
self.check_mask(effective_region_mask)
|
| 839 |
+
|
| 840 |
def check_image(self, image, prompt, prompt_embeds):
|
| 841 |
image_is_pil = isinstance(image, PIL.Image.Image)
|
| 842 |
image_is_tensor = isinstance(image, torch.Tensor)
|
|
|
|
| 878 |
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
| 879 |
)
|
| 880 |
|
| 881 |
+
def check_mask(self, mask):
    """Validate the type of an effective-region mask.

    Args:
        mask: Candidate mask object to validate.

    Raises:
        TypeError: If `mask` is not a PIL image, a numpy array or a torch
            tensor.
    """
    # Only the container type is checked here; shape, dtype and value range
    # are not inspected.
    accepted_types = (PIL.Image.Image, torch.Tensor, np.ndarray)
    if isinstance(mask, accepted_types):
        return

    raise TypeError(
        f"mask must be passed and be one of PIL image, numpy array, or torch tensor, but is {type(mask)}"
    )
|
| 890 |
+
|
| 891 |
def prepare_image(
|
| 892 |
self,
|
| 893 |
image,
|
|
|
|
| 989 |
assert emb.shape == (w.shape[0], embedding_dim)
|
| 990 |
return emb
|
| 991 |
|
| 992 |
+
def apply_effective_region_mask(
    self, effective_region_mask: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """Scale a residual feature map by a spatial mask resized to its resolution.

    Args:
        effective_region_mask: Multiplicative mask, or `None` to disable
            masking. Besides a torch tensor, the numpy arrays and PIL images
            admitted by `check_mask` are converted here. Tensor and array
            values are used as-is; a PIL image is read as 8-bit grayscale and
            scaled to [0, 1]. A 2D mask is treated as (H, W), a 3D mask as
            (batch, H, W) and a 4D mask as (batch, channels, H, W).
        out: Feature map of shape (batch, channels, H, W) to be masked.

    Returns:
        `out` itself when no mask is given, otherwise `out` multiplied by the
        bilinearly resized mask, keeping `out`'s dtype and device.

    Raises:
        ValueError: If the mask is not 2D, 3D or 4D, or `out` is not 4D.
    """
    if effective_region_mask is None:
        return out

    mask = effective_region_mask
    if isinstance(mask, np.ndarray):
        # Copy so that read-only or oddly strided arrays convert cleanly.
        mask = torch.from_numpy(np.array(mask))
    elif not isinstance(mask, torch.Tensor):
        # The only other type `check_mask` lets through is a PIL image:
        # read it as grayscale with white meaning "fully enabled".
        mask = torch.from_numpy(
            np.array(mask.convert("L"), dtype=np.float32) / 255.0
        )

    # Bilinear interpolation needs a 4D (batch, channels, H, W) input.
    if mask.ndim == 2:
        mask = mask[None, None]
    elif mask.ndim == 3:
        mask = mask[:, None]
    elif mask.ndim != 4:
        raise ValueError(
            f"effective_region_mask must be 2D, 3D or 4D, but is {mask.ndim}D"
        )

    # Bool / integer masks cannot be bilinearly interpolated.
    if not mask.is_floating_point():
        mask = mask.float()

    _, _, height, width = out.shape
    mask = F.interpolate(
        mask.to(out.device),
        size=(height, width),
        mode="bilinear",
    )
    # Cast after resizing so a float32 mask does not promote a half-precision
    # residual to float32.
    return out * mask.to(out.dtype)
|
| 1005 |
+
|
| 1006 |
@property
|
| 1007 |
def guidance_scale(self):
|
| 1008 |
return self._guidance_scale
|
|
|
|
| 1181 |
callback = kwargs.pop("callback", None)
|
| 1182 |
callback_steps = kwargs.pop("callback_steps", None)
|
| 1183 |
|
| 1184 |
+
effective_region_mask = kwargs.pop("effective_region_mask", None)
|
| 1185 |
+
|
| 1186 |
if callback is not None:
|
| 1187 |
deprecate(
|
| 1188 |
"callback",
|
|
|
|
| 1243 |
control_guidance_start,
|
| 1244 |
control_guidance_end,
|
| 1245 |
callback_on_step_end_tensor_inputs,
|
| 1246 |
+
effective_region_mask,
|
| 1247 |
)
|
| 1248 |
|
| 1249 |
self._guidance_scale = guidance_scale
|
|
|
|
| 1450 |
controlnet_cond_scale = controlnet_cond_scale[0]
|
| 1451 |
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
| 1452 |
|
| 1453 |
+
# Controlnet is returning the residuals to be added to SD here
|
| 1454 |
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
| 1455 |
control_model_input,
|
| 1456 |
t,
|
|
|
|
| 1461 |
return_dict=False,
|
| 1462 |
)
|
| 1463 |
|
| 1464 |
+
# Apply mask here
|
| 1465 |
+
# Note that downblocks are ordered from largest->smallest
|
| 1466 |
+
if effective_region_mask is not None:
|
| 1467 |
+
masked_down_block_res_samples = ()
|
| 1468 |
+
for down_block_res_sample in down_block_res_samples:
|
| 1469 |
+
down_block_res_sample = self.apply_effective_region_mask(
|
| 1470 |
+
effective_region_mask, down_block_res_sample
|
| 1471 |
+
)
|
| 1472 |
+
masked_down_block_res_samples = (
|
| 1473 |
+
masked_down_block_res_samples + (down_block_res_sample,)
|
| 1474 |
+
)
|
| 1475 |
+
mid_block_res_sample = self.apply_effective_region_mask(
|
| 1476 |
+
effective_region_mask, mid_block_res_sample
|
| 1477 |
+
)
|
| 1478 |
+
|
| 1479 |
if guess_mode and self.do_classifier_free_guidance:
|
| 1480 |
# Inferred ControlNet only for the conditional batch.
|
| 1481 |
# To apply the output of ControlNet to both the unconditional and conditional batches,
|