Spaces:
Running
on
Zero
Running
on
Zero
srkanth
#2
by
srikanthsri
- opened
- app.py +1 -1
- pipeline_objectclear.py +35 -54
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import spaces
|
| 3 |
import os
|
| 4 |
from PIL import Image
|
| 5 |
import torch
|
|
@@ -11,6 +10,7 @@ import argparse
|
|
| 11 |
import numpy as np
|
| 12 |
import torchvision.transforms.functional as TF
|
| 13 |
from scipy.ndimage import convolve, zoom
|
|
|
|
| 14 |
from utils import resize_by_short_side
|
| 15 |
|
| 16 |
from tools.interact_tools import SamControler
|
|
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
import os
|
| 3 |
from PIL import Image
|
| 4 |
import torch
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
import torchvision.transforms.functional as TF
|
| 12 |
from scipy.ndimage import convolve, zoom
|
| 13 |
+
import spaces
|
| 14 |
from utils import resize_by_short_side
|
| 15 |
|
| 16 |
from tools.interact_tools import SamControler
|
pipeline_objectclear.py
CHANGED
|
@@ -14,7 +14,6 @@
|
|
| 14 |
|
| 15 |
import inspect
|
| 16 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 17 |
-
import os
|
| 18 |
|
| 19 |
import numpy as np
|
| 20 |
import PIL.Image
|
|
@@ -335,7 +334,6 @@ def retrieve_timesteps(
|
|
| 335 |
class ObjectClearPipelineOutput(StableDiffusionXLPipelineOutput):
|
| 336 |
attns: Optional[List[PIL.Image.Image]] = None
|
| 337 |
|
| 338 |
-
|
| 339 |
class ObjectClearPipeline(
|
| 340 |
DiffusionPipeline,
|
| 341 |
StableDiffusionMixin,
|
|
@@ -430,7 +428,7 @@ class ObjectClearPipeline(
|
|
| 430 |
requires_aesthetics_score: bool = False,
|
| 431 |
force_zeros_for_empty_prompt: bool = True,
|
| 432 |
add_watermarker: Optional[bool] = None,
|
| 433 |
-
apply_attention_guided_fusion: bool =
|
| 434 |
):
|
| 435 |
super().__init__()
|
| 436 |
|
|
@@ -465,7 +463,9 @@ class ObjectClearPipeline(
|
|
| 465 |
|
| 466 |
if self.config.apply_attention_guided_fusion:
|
| 467 |
self.cross_attention_scores = {}
|
| 468 |
-
self.
|
|
|
|
|
|
|
| 469 |
|
| 470 |
|
| 471 |
@classmethod
|
|
@@ -486,17 +486,14 @@ class ObjectClearPipeline(
|
|
| 486 |
)
|
| 487 |
|
| 488 |
postfuse_module = PostfuseModule(embed_dim=2048, embed_dim_img=768)
|
| 489 |
-
sub_folder = "postfuse_module"
|
| 490 |
filename = "model.safetensors"
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
else:
|
| 499 |
-
safetensor_path = os.path.join(pretrained_model_name_or_path, sub_folder, filename)
|
| 500 |
state_dict_postfuse = load_file(safetensor_path)
|
| 501 |
postfuse_module.load_state_dict(state_dict_postfuse)
|
| 502 |
|
|
@@ -540,7 +537,7 @@ class ObjectClearPipeline(
|
|
| 540 |
|
| 541 |
return image_embeds, uncond_image_embeds
|
| 542 |
|
| 543 |
-
def unet_store_cross_attention_scores(self, unet, attention_scores
|
| 544 |
from diffusers.models.attention_processor import (
|
| 545 |
Attention,
|
| 546 |
AttnProcessor,
|
|
@@ -548,25 +545,34 @@ class ObjectClearPipeline(
|
|
| 548 |
)
|
| 549 |
import types
|
| 550 |
|
| 551 |
-
|
| 552 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
|
| 554 |
def make_new_get_attention_scores_fn(name):
|
| 555 |
def new_get_attention_scores(module, query, key, attention_mask=None):
|
| 556 |
attention_probs = module.old_get_attention_scores(
|
| 557 |
query, key, attention_mask
|
| 558 |
)
|
| 559 |
-
|
| 560 |
-
attention_scores[name] = attention_probs
|
| 561 |
return attention_probs
|
|
|
|
| 562 |
return new_get_attention_scores
|
| 563 |
|
| 564 |
for name, module in unet.named_modules():
|
| 565 |
-
if isinstance(module, Attention) and
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
"get_attention_scores": module.get_attention_scores
|
| 569 |
-
}
|
| 570 |
if isinstance(module.processor, AttnProcessor2_0):
|
| 571 |
module.set_processor(AttnProcessor())
|
| 572 |
module.old_get_attention_scores = module.get_attention_scores
|
|
@@ -575,19 +581,6 @@ class ObjectClearPipeline(
|
|
| 575 |
)
|
| 576 |
module.get_attention_scores = module.new_get_attention_scores
|
| 577 |
|
| 578 |
-
return unet, original_state
|
| 579 |
-
|
| 580 |
-
def unet_restore_attention_processor(self, unet, original_state):
|
| 581 |
-
from diffusers.models.attention_processor import Attention
|
| 582 |
-
|
| 583 |
-
for name, module in unet.named_modules():
|
| 584 |
-
if isinstance(module, Attention) and "attn2" in name and name in original_state:
|
| 585 |
-
module.get_attention_scores = original_state[name]["get_attention_scores"]
|
| 586 |
-
module.set_processor(original_state[name]["processor"])
|
| 587 |
-
if hasattr(module, "old_get_attention_scores"):
|
| 588 |
-
delattr(module, "old_get_attention_scores")
|
| 589 |
-
if hasattr(module, "new_get_attention_scores"):
|
| 590 |
-
delattr(module, "new_get_attention_scores")
|
| 591 |
return unet
|
| 592 |
|
| 593 |
def resize_attn_map_divide2(self, attn_map, mask, fuse_index):
|
|
@@ -1433,7 +1426,7 @@ class ObjectClearPipeline(
|
|
| 1433 |
on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
|
| 1434 |
resizing to the original image size for inpainting. This is useful when the masked area is small while
|
| 1435 |
the image is large and contain information irrelevant for inpainting, such as background.
|
| 1436 |
-
strength (`float`, *optional*, defaults to 0
|
| 1437 |
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
|
| 1438 |
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
|
| 1439 |
`strength`. The number of denoising steps depends on the amount of noise initially added. When
|
|
@@ -1878,12 +1871,6 @@ class ObjectClearPipeline(
|
|
| 1878 |
for i, t in enumerate(timesteps):
|
| 1879 |
if self.interrupt:
|
| 1880 |
continue
|
| 1881 |
-
# Inject cross-attention storage logic at the last timestep
|
| 1882 |
-
if i == len(timesteps) - 1 and self.config.apply_attention_guided_fusion:
|
| 1883 |
-
self.unet, self.original_state = self.unet_store_cross_attention_scores(
|
| 1884 |
-
self.unet,
|
| 1885 |
-
self.cross_attention_scores
|
| 1886 |
-
)
|
| 1887 |
# expand the latents if we are doing classifier free guidance
|
| 1888 |
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 1889 |
|
|
@@ -1937,8 +1924,8 @@ class ObjectClearPipeline(
|
|
| 1937 |
)
|
| 1938 |
|
| 1939 |
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
| 1940 |
-
|
| 1941 |
-
if i == len(timesteps) - 1
|
| 1942 |
attn_key, attn_map = next(iter(self.cross_attention_scores.items()))
|
| 1943 |
attn_map = self.resize_attn_map_divide2(attn_map, mask, fuse_index)
|
| 1944 |
init_latents_proper = image_latents
|
|
@@ -1947,13 +1934,7 @@ class ObjectClearPipeline(
|
|
| 1947 |
else:
|
| 1948 |
init_mask = attn_map
|
| 1949 |
attn_map = init_mask
|
| 1950 |
-
|
| 1951 |
-
self.unet = self.unet_restore_attention_processor(
|
| 1952 |
-
self.unet,
|
| 1953 |
-
self.original_state
|
| 1954 |
-
)
|
| 1955 |
-
|
| 1956 |
-
self.clear_cross_attention_scores(self.cross_attention_scores)
|
| 1957 |
|
| 1958 |
if num_channels_unet == 4:
|
| 1959 |
init_latents_proper = image_latents
|
|
@@ -2076,4 +2057,4 @@ class ObjectClearPipeline(
|
|
| 2076 |
else:
|
| 2077 |
if not return_dict:
|
| 2078 |
return (image,)
|
| 2079 |
-
return ObjectClearPipelineOutput(images=image)
|
|
|
|
| 14 |
|
| 15 |
import inspect
|
| 16 |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
import PIL.Image
|
|
|
|
| 334 |
class ObjectClearPipelineOutput(StableDiffusionXLPipelineOutput):
|
| 335 |
attns: Optional[List[PIL.Image.Image]] = None
|
| 336 |
|
|
|
|
| 337 |
class ObjectClearPipeline(
|
| 338 |
DiffusionPipeline,
|
| 339 |
StableDiffusionMixin,
|
|
|
|
| 428 |
requires_aesthetics_score: bool = False,
|
| 429 |
force_zeros_for_empty_prompt: bool = True,
|
| 430 |
add_watermarker: Optional[bool] = None,
|
| 431 |
+
apply_attention_guided_fusion: bool = False,
|
| 432 |
):
|
| 433 |
super().__init__()
|
| 434 |
|
|
|
|
| 463 |
|
| 464 |
if self.config.apply_attention_guided_fusion:
|
| 465 |
self.cross_attention_scores = {}
|
| 466 |
+
self.unet = self.unet_store_cross_attention_scores(
|
| 467 |
+
self.unet, self.cross_attention_scores
|
| 468 |
+
)
|
| 469 |
|
| 470 |
|
| 471 |
@classmethod
|
|
|
|
| 486 |
)
|
| 487 |
|
| 488 |
postfuse_module = PostfuseModule(embed_dim=2048, embed_dim_img=768)
|
|
|
|
| 489 |
filename = "model.safetensors"
|
| 490 |
+
|
| 491 |
+
safetensor_path = hf_hub_download(
|
| 492 |
+
repo_id="jixin0101/ObjectClear",
|
| 493 |
+
filename=filename,
|
| 494 |
+
subfolder="postfuse_module",
|
| 495 |
+
cache_dir=cache_dir
|
| 496 |
+
)
|
|
|
|
|
|
|
| 497 |
state_dict_postfuse = load_file(safetensor_path)
|
| 498 |
postfuse_module.load_state_dict(state_dict_postfuse)
|
| 499 |
|
|
|
|
| 537 |
|
| 538 |
return image_embeds, uncond_image_embeds
|
| 539 |
|
| 540 |
+
def unet_store_cross_attention_scores(self, unet, attention_scores):
|
| 541 |
from diffusers.models.attention_processor import (
|
| 542 |
Attention,
|
| 543 |
AttnProcessor,
|
|
|
|
| 545 |
)
|
| 546 |
import types
|
| 547 |
|
| 548 |
+
UNET_LAYER_NAMES = [
|
| 549 |
+
"down_blocks.0",
|
| 550 |
+
"down_blocks.1",
|
| 551 |
+
"down_blocks.2",
|
| 552 |
+
"mid_block",
|
| 553 |
+
"up_blocks.1",
|
| 554 |
+
"up_blocks.2",
|
| 555 |
+
"up_blocks.3",
|
| 556 |
+
]
|
| 557 |
+
|
| 558 |
+
start_layer = 0
|
| 559 |
+
end_layer = 2
|
| 560 |
+
applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer]
|
| 561 |
|
| 562 |
def make_new_get_attention_scores_fn(name):
|
| 563 |
def new_get_attention_scores(module, query, key, attention_mask=None):
|
| 564 |
attention_probs = module.old_get_attention_scores(
|
| 565 |
query, key, attention_mask
|
| 566 |
)
|
| 567 |
+
attention_scores[name] = attention_probs
|
|
|
|
| 568 |
return attention_probs
|
| 569 |
+
|
| 570 |
return new_get_attention_scores
|
| 571 |
|
| 572 |
for name, module in unet.named_modules():
|
| 573 |
+
if isinstance(module, Attention) and "attn2" in name:
|
| 574 |
+
if not any(layer in name for layer in applicable_layers):
|
| 575 |
+
continue
|
|
|
|
|
|
|
| 576 |
if isinstance(module.processor, AttnProcessor2_0):
|
| 577 |
module.set_processor(AttnProcessor())
|
| 578 |
module.old_get_attention_scores = module.get_attention_scores
|
|
|
|
| 581 |
)
|
| 582 |
module.get_attention_scores = module.new_get_attention_scores
|
| 583 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
return unet
|
| 585 |
|
| 586 |
def resize_attn_map_divide2(self, attn_map, mask, fuse_index):
|
|
|
|
| 1426 |
on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
|
| 1427 |
resizing to the original image size for inpainting. This is useful when the masked area is small while
|
| 1428 |
the image is large and contain information irrelevant for inpainting, such as background.
|
| 1429 |
+
strength (`float`, *optional*, defaults to 1.0):
|
| 1430 |
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
|
| 1431 |
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
|
| 1432 |
`strength`. The number of denoising steps depends on the amount of noise initially added. When
|
|
|
|
| 1871 |
for i, t in enumerate(timesteps):
|
| 1872 |
if self.interrupt:
|
| 1873 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1874 |
# expand the latents if we are doing classifier free guidance
|
| 1875 |
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 1876 |
|
|
|
|
| 1924 |
)
|
| 1925 |
|
| 1926 |
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
| 1927 |
+
|
| 1928 |
+
if i == len(timesteps) - 1:
|
| 1929 |
attn_key, attn_map = next(iter(self.cross_attention_scores.items()))
|
| 1930 |
attn_map = self.resize_attn_map_divide2(attn_map, mask, fuse_index)
|
| 1931 |
init_latents_proper = image_latents
|
|
|
|
| 1934 |
else:
|
| 1935 |
init_mask = attn_map
|
| 1936 |
attn_map = init_mask
|
| 1937 |
+
self.clear_cross_attention_scores(self.cross_attention_scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1938 |
|
| 1939 |
if num_channels_unet == 4:
|
| 1940 |
init_latents_proper = image_latents
|
|
|
|
| 2057 |
else:
|
| 2058 |
if not return_dict:
|
| 2059 |
return (image,)
|
| 2060 |
+
return ObjectClearPipelineOutput(images=image)
|
requirements.txt
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
accelerate
|
| 2 |
-
torch==2.
|
| 3 |
torchvision
|
| 4 |
numpy==1.26.4
|
| 5 |
opencv-python
|
|
|
|
| 1 |
accelerate
|
| 2 |
+
torch==2.2.0
|
| 3 |
torchvision
|
| 4 |
numpy==1.26.4
|
| 5 |
opencv-python
|