Spaces:
Runtime error
Runtime error
Add remove
Browse files
app.py
CHANGED
|
@@ -11,6 +11,8 @@ import subprocess
|
|
| 11 |
import copy
|
| 12 |
import time
|
| 13 |
import warnings
|
|
|
|
|
|
|
| 14 |
|
| 15 |
import torch
|
| 16 |
from torchvision.ops import box_convert
|
|
@@ -26,13 +28,18 @@ import groundingdino.datasets.transforms as T
|
|
| 26 |
# segment anything
|
| 27 |
from segment_anything import build_sam, SamPredictor
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
#stable diffusion
|
| 30 |
from diffusers import StableDiffusionInpaintPipeline
|
| 31 |
|
| 32 |
from huggingface_hub import hf_hub_download
|
| 33 |
|
| 34 |
-
if not os.path.exists('./
|
| 35 |
-
os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/
|
| 36 |
|
| 37 |
if not os.path.exists('./sam_vit_h_4b8939.pth'):
|
| 38 |
logger.info(f"get sam_vit_h_4b8939.pth...")
|
|
@@ -177,6 +184,63 @@ def mix_masks(imgs):
|
|
| 177 |
re_img = 1 - re_img
|
| 178 |
return Image.fromarray(np.uint8(255*re_img))
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
| 181 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
|
| 182 |
|
|
@@ -199,6 +263,8 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
| 199 |
# load image
|
| 200 |
image_pil, image_tensor = load_image_and_transform(input_image['image'])
|
| 201 |
|
|
|
|
|
|
|
| 202 |
# RUN GROUNDINGDINO: we skip DINO if we draw mask on the image
|
| 203 |
if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
|
| 204 |
pass
|
|
@@ -218,7 +284,6 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
| 218 |
}
|
| 219 |
|
| 220 |
# store and save DINO output
|
| 221 |
-
output_images = []
|
| 222 |
image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
|
| 223 |
image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
|
| 224 |
image_with_box.save(image_path)
|
|
@@ -300,7 +365,39 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
| 300 |
image_source_for_inpaint = image_pil.resize((512, 512))
|
| 301 |
image_mask_for_inpaint = mask_pil.resize((512, 512))
|
| 302 |
image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
| 306 |
output_images.append(image_inpainting)
|
|
@@ -330,6 +427,7 @@ def change_radio_display(task_type, mask_source_radio):
|
|
| 330 |
# model initialization
|
| 331 |
groundingDino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, groundingdino_device)
|
| 332 |
sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
|
|
|
|
| 333 |
|
| 334 |
# initialize stable-diffusion-inpainting
|
| 335 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
|
@@ -359,7 +457,7 @@ if __name__ == "__main__":
|
|
| 359 |
with gr.Row():
|
| 360 |
with gr.Column():
|
| 361 |
input_image = gr.Image(
|
| 362 |
-
source="upload", elem_id="image_upload", type="pil", tool="sketch", value="
|
| 363 |
task_type = gr.Radio(["segment", "inpainting", "remove"], value="segment",
|
| 364 |
label='Task type', visible=True)
|
| 365 |
|
|
@@ -368,7 +466,7 @@ if __name__ == "__main__":
|
|
| 368 |
visible=False)
|
| 369 |
|
| 370 |
text_prompt = gr.Textbox(label="Detection Prompt, seperating each name with dot '.', i.e.: bear.cat.dog.chair ]", \
|
| 371 |
-
value='
|
| 372 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
|
| 373 |
|
| 374 |
run_button = gr.Button(label="Run")
|
|
|
|
| 11 |
import copy
|
| 12 |
import time
|
| 13 |
import warnings
|
| 14 |
+
import io
|
| 15 |
+
import random
|
| 16 |
|
| 17 |
import torch
|
| 18 |
from torchvision.ops import box_convert
|
|
|
|
| 28 |
# segment anything
|
| 29 |
from segment_anything import build_sam, SamPredictor
|
| 30 |
|
| 31 |
+
# lama-cleaner
|
| 32 |
+
from lama_cleaner.model_manager import ModelManager
|
| 33 |
+
from lama_cleaner.schema import Config as lama_Config
|
| 34 |
+
from lama_cleaner.helper import load_img, numpy_to_bytes, resize_max_size
|
| 35 |
+
|
| 36 |
#stable diffusion
|
| 37 |
from diffusers import StableDiffusionInpaintPipeline
|
| 38 |
|
| 39 |
from huggingface_hub import hf_hub_download
|
| 40 |
|
| 41 |
+
if not os.path.exists('./demo2.jpg'):
|
| 42 |
+
os.system("wget https://github.com/IDEA-Research/Grounded-Segment-Anything/raw/main/assets/demo2.jpg")
|
| 43 |
|
| 44 |
if not os.path.exists('./sam_vit_h_4b8939.pth'):
|
| 45 |
logger.info(f"get sam_vit_h_4b8939.pth...")
|
|
|
|
| 184 |
re_img = 1 - re_img
|
| 185 |
return Image.fromarray(np.uint8(255*re_img))
|
| 186 |
|
| 187 |
+
def lama_cleaner_process(image, mask):
|
| 188 |
+
ori_image = image
|
| 189 |
+
if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
|
| 190 |
+
# rotate image
|
| 191 |
+
ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
|
| 192 |
+
image = ori_image
|
| 193 |
+
|
| 194 |
+
original_shape = ori_image.shape
|
| 195 |
+
interpolation = cv2.INTER_CUBIC
|
| 196 |
+
|
| 197 |
+
size_limit = 1080
|
| 198 |
+
if size_limit == "Original":
|
| 199 |
+
size_limit = max(image.shape)
|
| 200 |
+
else:
|
| 201 |
+
size_limit = int(size_limit)
|
| 202 |
+
|
| 203 |
+
config = lama_Config(
|
| 204 |
+
ldm_steps=25,
|
| 205 |
+
ldm_sampler='plms',
|
| 206 |
+
zits_wireframe=True,
|
| 207 |
+
hd_strategy='Original',
|
| 208 |
+
hd_strategy_crop_margin=196,
|
| 209 |
+
hd_strategy_crop_trigger_size=1280,
|
| 210 |
+
hd_strategy_resize_limit=2048,
|
| 211 |
+
prompt='',
|
| 212 |
+
use_croper=False,
|
| 213 |
+
croper_x=0,
|
| 214 |
+
croper_y=0,
|
| 215 |
+
croper_height=512,
|
| 216 |
+
croper_width=512,
|
| 217 |
+
sd_mask_blur=5,
|
| 218 |
+
sd_strength=0.75,
|
| 219 |
+
sd_steps=50,
|
| 220 |
+
sd_guidance_scale=7.5,
|
| 221 |
+
sd_sampler='ddim',
|
| 222 |
+
sd_seed=42,
|
| 223 |
+
cv2_flag='INPAINT_NS',
|
| 224 |
+
cv2_radius=5,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
if config.sd_seed == -1:
|
| 228 |
+
config.sd_seed = random.randint(1, 999999999)
|
| 229 |
+
|
| 230 |
+
# logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
|
| 231 |
+
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
| 232 |
+
# logger.info(f"Resized image shape_1_: {image.shape}")
|
| 233 |
+
|
| 234 |
+
# logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
|
| 235 |
+
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
| 236 |
+
# logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
|
| 237 |
+
|
| 238 |
+
res_np_img = lama_cleaner_model(image, mask, config)
|
| 239 |
+
torch.cuda.empty_cache()
|
| 240 |
+
|
| 241 |
+
image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
|
| 242 |
+
return image
|
| 243 |
+
|
| 244 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
| 245 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
|
| 246 |
|
|
|
|
| 263 |
# load image
|
| 264 |
image_pil, image_tensor = load_image_and_transform(input_image['image'])
|
| 265 |
|
| 266 |
+
output_images = []
|
| 267 |
+
output_images.append(input_image['image'])
|
| 268 |
# RUN GROUNDINGDINO: we skip DINO if we draw mask on the image
|
| 269 |
if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
|
| 270 |
pass
|
|
|
|
| 284 |
}
|
| 285 |
|
| 286 |
# store and save DINO output
|
|
|
|
| 287 |
image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
|
| 288 |
image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
|
| 289 |
image_with_box.save(image_path)
|
|
|
|
| 365 |
image_source_for_inpaint = image_pil.resize((512, 512))
|
| 366 |
image_mask_for_inpaint = mask_pil.resize((512, 512))
|
| 367 |
image_inpainting = sd_pipe(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
| 368 |
+
else:
|
| 369 |
+
# remove from mask
|
| 370 |
+
if mask_source_radio == mask_source_segment:
|
| 371 |
+
mask_imgs = []
|
| 372 |
+
masks_shape = masks_ori.shape
|
| 373 |
+
boxes_filt_ori_array = boxes_filt_ori.numpy()
|
| 374 |
+
if inpaint_mode == 'merge':
|
| 375 |
+
extend_shape_0 = masks_shape[0]
|
| 376 |
+
extend_shape_1 = masks_shape[1]
|
| 377 |
+
else:
|
| 378 |
+
extend_shape_0 = 1
|
| 379 |
+
extend_shape_1 = 1
|
| 380 |
+
for i in range(extend_shape_0):
|
| 381 |
+
for j in range(extend_shape_1):
|
| 382 |
+
mask = masks_ori[i][j].cpu().numpy()
|
| 383 |
+
mask_pil = Image.fromarray(mask)
|
| 384 |
+
|
| 385 |
+
if remove_mode == 'segment':
|
| 386 |
+
useRectangle = False
|
| 387 |
+
else:
|
| 388 |
+
useRectangle = True
|
| 389 |
+
|
| 390 |
+
try:
|
| 391 |
+
remove_mask_extend = int(remove_mask_extend)
|
| 392 |
+
except:
|
| 393 |
+
remove_mask_extend = 10
|
| 394 |
+
mask_pil_exp = mask_extend(copy.deepcopy(mask_pil).convert("RGB"),
|
| 395 |
+
box_convert(torch.tensor(boxes_filt_ori_array[i]), in_fmt="cxcywh", out_fmt="xyxy").numpy(),
|
| 396 |
+
extend_pixels=remove_mask_extend, useRectangle=useRectangle)
|
| 397 |
+
mask_imgs.append(mask_pil_exp)
|
| 398 |
+
mask_pil = mix_masks(mask_imgs)
|
| 399 |
+
output_images.append(mask_pil.convert("RGB"))
|
| 400 |
+
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
|
| 401 |
|
| 402 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
| 403 |
output_images.append(image_inpainting)
|
|
|
|
| 427 |
# model initialization
|
| 428 |
groundingDino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, groundingdino_device)
|
| 429 |
sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
|
| 430 |
+
lama_cleaner_model = ModelManager(name='lama',device='cpu')
|
| 431 |
|
| 432 |
# initialize stable-diffusion-inpainting
|
| 433 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
|
|
|
| 457 |
with gr.Row():
|
| 458 |
with gr.Column():
|
| 459 |
input_image = gr.Image(
|
| 460 |
+
source="upload", elem_id="image_upload", type="pil", tool="sketch", value="demo2.jpg", label="Upload")
|
| 461 |
task_type = gr.Radio(["segment", "inpainting", "remove"], value="segment",
|
| 462 |
label='Task type', visible=True)
|
| 463 |
|
|
|
|
| 466 |
visible=False)
|
| 467 |
|
| 468 |
text_prompt = gr.Textbox(label="Detection Prompt, seperating each name with dot '.', i.e.: bear.cat.dog.chair ]", \
|
| 469 |
+
value='dog', placeholder="Cannot be empty")
|
| 470 |
inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
|
| 471 |
|
| 472 |
run_button = gr.Button(label="Run")
|