Set `num_views` as an attribute of the attention processor to support `torch.compile`
Browse files- pipeline_imagedream.py +24 -11
pipeline_imagedream.py
CHANGED
|
@@ -76,7 +76,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
|
|
| 76 |
weight_name: Union[str, List[str]] = "ip-adapter-plus_imagedream.bin",
|
| 77 |
image_encoder_folder: Optional[str] = "image_encoder",
|
| 78 |
**kwargs,
|
| 79 |
-
):
|
| 80 |
super().load_ip_adapter(
|
| 81 |
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
| 82 |
subfolder=subfolder,
|
|
@@ -89,12 +89,17 @@ class ImageDreamPipeline(StableDiffusionPipeline):
|
|
| 89 |
if weight_name == "ip-adapter-plus_imagedream.bin":
|
| 90 |
setattr(self.image_encoder, "visual_projection", nn.Identity())
|
| 91 |
add_imagedream_attn_processor(self.unet)
|
|
|
|
| 92 |
logging.set_verbosity_error()
|
| 93 |
print(
|
| 94 |
"ImageDream Cross-Attention uses `num_views` kwarg, "
|
| 95 |
"and set logging verbosity to error."
|
| 96 |
)
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
def encode_image_to_latents(
|
| 99 |
self,
|
| 100 |
image: PipelineImageInput,
|
|
@@ -326,9 +331,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
|
|
| 326 |
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 327 |
).to(device=device, dtype=latents.dtype)
|
| 328 |
|
| 329 |
-
|
| 330 |
-
if self.cross_attention_kwargs is not None:
|
| 331 |
-
cross_attention_kwargs.update(self.cross_attention_kwargs)
|
| 332 |
|
| 333 |
# fmt: off
|
| 334 |
# 7. Denoising loop
|
|
@@ -352,7 +355,7 @@ class ImageDreamPipeline(StableDiffusionPipeline):
|
|
| 352 |
class_labels=camera,
|
| 353 |
encoder_hidden_states=prompt_embeds,
|
| 354 |
timestep_cond=timestep_cond,
|
| 355 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
| 356 |
added_cond_kwargs=added_cond_kwargs,
|
| 357 |
return_dict=False,
|
| 358 |
)[0]
|
|
@@ -508,7 +511,7 @@ def get_camera(
|
|
| 508 |
# fmt: on
|
| 509 |
|
| 510 |
|
| 511 |
-
def add_imagedream_attn_processor(unet: UNet2DConditionModel) ->
|
| 512 |
attn_procs = {}
|
| 513 |
for key, attn_processor in unet.attn_processors.items():
|
| 514 |
if "attn1" in key:
|
|
@@ -519,7 +522,18 @@ def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> nn.Module:
|
|
| 519 |
return unet
|
| 520 |
|
| 521 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
def __call__(
|
| 524 |
self,
|
| 525 |
attn: Attention,
|
|
@@ -527,11 +541,10 @@ class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
|
|
| 527 |
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 528 |
attention_mask: Optional[torch.Tensor] = None,
|
| 529 |
temb: Optional[torch.Tensor] = None,
|
| 530 |
-
num_views: int = 1,
|
| 531 |
*args,
|
| 532 |
**kwargs,
|
| 533 |
):
|
| 534 |
-
if num_views == 1:
|
| 535 |
return super().__call__(
|
| 536 |
attn=attn,
|
| 537 |
hidden_states=hidden_states,
|
|
@@ -544,11 +557,11 @@ class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
|
|
| 544 |
|
| 545 |
input_ndim = hidden_states.ndim
|
| 546 |
B = hidden_states.size(0)
|
| 547 |
-
if B % num_views:
|
| 548 |
raise ValueError(
|
| 549 |
-
f"`batch_size`(got {B}) must be a multiple of `num_views`(got {num_views})."
|
| 550 |
)
|
| 551 |
-
real_B = B // num_views
|
| 552 |
if input_ndim == 4:
|
| 553 |
H, W = hidden_states.shape[2:]
|
| 554 |
hidden_states = hidden_states.reshape(real_B, -1, H, W).transpose(1, 2)
|
|
|
|
| 76 |
weight_name: Union[str, List[str]] = "ip-adapter-plus_imagedream.bin",
|
| 77 |
image_encoder_folder: Optional[str] = "image_encoder",
|
| 78 |
**kwargs,
|
| 79 |
+
) -> None:
|
| 80 |
super().load_ip_adapter(
|
| 81 |
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
| 82 |
subfolder=subfolder,
|
|
|
|
| 89 |
if weight_name == "ip-adapter-plus_imagedream.bin":
|
| 90 |
setattr(self.image_encoder, "visual_projection", nn.Identity())
|
| 91 |
add_imagedream_attn_processor(self.unet)
|
| 92 |
+
set_num_views(self.unet, self.num_views + 1)
|
| 93 |
logging.set_verbosity_error()
|
| 94 |
print(
|
| 95 |
"ImageDream Cross-Attention uses `num_views` kwarg, "
|
| 96 |
"and set logging verbosity to error."
|
| 97 |
)
|
| 98 |
|
| 99 |
+
def unload_ip_adapter(self) -> None:
    """Unload the IP-Adapter and reset the attention processors' view count.

    Delegates the unloading itself to the parent class, then writes
    ``self.num_views`` onto the ImageDream attention processors of
    ``self.unet`` via ``set_num_views``.
    """
    super().unload_ip_adapter()
    # NOTE(review): `load_ip_adapter` sets the processors to
    # `self.num_views + 1` for the ImageDream adapter weights; this call
    # presumably undoes that extra view -- confirm against the full file.
    set_num_views(self.unet, self.num_views)
|
| 102 |
+
|
| 103 |
def encode_image_to_latents(
|
| 104 |
self,
|
| 105 |
image: PipelineImageInput,
|
|
|
|
| 331 |
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 332 |
).to(device=device, dtype=latents.dtype)
|
| 333 |
|
| 334 |
+
set_num_views(self.unet, num_views)
|
|
|
|
|
|
|
| 335 |
|
| 336 |
# fmt: off
|
| 337 |
# 7. Denoising loop
|
|
|
|
| 355 |
class_labels=camera,
|
| 356 |
encoder_hidden_states=prompt_embeds,
|
| 357 |
timestep_cond=timestep_cond,
|
| 358 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 359 |
added_cond_kwargs=added_cond_kwargs,
|
| 360 |
return_dict=False,
|
| 361 |
)[0]
|
|
|
|
| 511 |
# fmt: on
|
| 512 |
|
| 513 |
|
| 514 |
+
def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> UNet2DConditionModel:
|
| 515 |
attn_procs = {}
|
| 516 |
for key, attn_processor in unet.attn_processors.items():
|
| 517 |
if "attn1" in key:
|
|
|
|
| 522 |
return unet
|
| 523 |
|
| 524 |
|
| 525 |
+
def set_num_views(unet: UNet2DConditionModel, num_views: int) -> UNet2DConditionModel:
    """Set ``num_views`` on every ImageDream attention processor of ``unet``.

    The view count is stored as an attribute on each processor (it is read
    as ``self.num_views`` in ``ImageDreamAttnProcessor2_0.__call__``) rather
    than being passed as a per-call argument.

    Args:
        unet: Model whose ``attn_processors`` mapping is scanned.
        num_views: Value assigned to each matching processor's ``num_views``.

    Returns:
        The same ``unet`` object; its processors are updated in place.
    """
    # Only the processor objects are needed; the mapping keys are unused.
    for attn_processor in unet.attn_processors.values():
        # Processors of any other type are left untouched.
        if isinstance(attn_processor, ImageDreamAttnProcessor2_0):
            attn_processor.num_views = num_views
    return unet
|
| 530 |
+
|
| 531 |
+
|
| 532 |
class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
|
| 533 |
+
def __init__(self, num_views: int = 4):
    """Initialise the processor and remember the number of views.

    Args:
        num_views: Stored as ``self.num_views``. ``__call__`` reads it:
            a value of 1 delegates to the parent processor, otherwise
            the batch size must be a multiple of it.
    """
    super().__init__()
    # Kept as instance state so `set_num_views` can change it later
    # without altering the `__call__` signature.
    self.num_views = num_views
|
| 536 |
+
|
| 537 |
def __call__(
|
| 538 |
self,
|
| 539 |
attn: Attention,
|
|
|
|
| 541 |
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 542 |
attention_mask: Optional[torch.Tensor] = None,
|
| 543 |
temb: Optional[torch.Tensor] = None,
|
|
|
|
| 544 |
*args,
|
| 545 |
**kwargs,
|
| 546 |
):
|
| 547 |
+
if self.num_views == 1:
|
| 548 |
return super().__call__(
|
| 549 |
attn=attn,
|
| 550 |
hidden_states=hidden_states,
|
|
|
|
| 557 |
|
| 558 |
input_ndim = hidden_states.ndim
|
| 559 |
B = hidden_states.size(0)
|
| 560 |
+
if B % self.num_views:
|
| 561 |
raise ValueError(
|
| 562 |
+
f"`batch_size`(got {B}) must be a multiple of `num_views`(got {self.num_views})."
|
| 563 |
)
|
| 564 |
+
real_B = B // self.num_views
|
| 565 |
if input_ndim == 4:
|
| 566 |
H, W = hidden_states.shape[2:]
|
| 567 |
hidden_states = hidden_states.reshape(real_B, -1, H, W).transpose(1, 2)
|