set num_views as attr of attn_processor to support torch.compile
Browse files- pipeline_imagedream.py +1 -6
pipeline_imagedream.py
CHANGED
|
@@ -32,7 +32,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|
| 32 |
StableDiffusionSafetyChecker,
|
| 33 |
)
|
| 34 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 35 |
-
from diffusers.utils import deprecate
|
| 36 |
from transformers import (
|
| 37 |
CLIPImageProcessor,
|
| 38 |
CLIPTextModel,
|
|
@@ -90,11 +90,6 @@ class ImageDreamPipeline(StableDiffusionPipeline):
|
|
| 90 |
setattr(self.image_encoder, "visual_projection", nn.Identity())
|
| 91 |
add_imagedream_attn_processor(self.unet)
|
| 92 |
set_num_views(self.unet, self.num_views + 1)
|
| 93 |
-
logging.set_verbosity_error()
|
| 94 |
-
print(
|
| 95 |
-
"ImageDream Cross-Attention uses `num_views` kwarg, "
|
| 96 |
-
"and set logging verbosity to error."
|
| 97 |
-
)
|
| 98 |
|
| 99 |
def unload_ip_adapter(self) -> None:
|
| 100 |
super().unload_ip_adapter()
|
|
|
|
| 32 |
StableDiffusionSafetyChecker,
|
| 33 |
)
|
| 34 |
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 35 |
+
from diffusers.utils import deprecate
|
| 36 |
from transformers import (
|
| 37 |
CLIPImageProcessor,
|
| 38 |
CLIPTextModel,
|
|
|
|
| 90 |
setattr(self.image_encoder, "visual_projection", nn.Identity())
|
| 91 |
add_imagedream_attn_processor(self.unet)
|
| 92 |
set_num_views(self.unet, self.num_views + 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
def unload_ip_adapter(self) -> None:
|
| 95 |
super().unload_ip_adapter()
|