GulbaharAI committed on
Commit
7ca5b5b
·
verified ·
1 Parent(s): 45d7474

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +546 -0
app.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import os
4
+ from PIL import Image
5
+ import torch
6
+ from diffusers.utils import check_min_version
7
+ from pipeline_objectclear import ObjectClearPipeline
8
+ from tools.download_util import load_file_from_url
9
+ from tools.painter import mask_painter
10
+ import argparse
11
+ import numpy as np
12
+ import torchvision.transforms.functional as TF
13
+ from scipy.ndimage import convolve, zoom
14
+ from utils import resize_by_short_side
15
+
16
+ from tools.interact_tools import SamControler
17
+ from tools.misc import get_device
18
+ import json
19
+
20
+ check_min_version("0.30.2")
21
+
22
+
23
def parse_augment():
    """Parse the command-line options of the demo.

    Recognised flags are ``--device``, ``--sam_model_type`` (default
    ``"vit_h"``) and ``--port`` (default ``8000``).

    Returns:
        The parsed ``argparse.Namespace``. When no device was supplied,
        ``device`` is filled in with ``str(get_device())``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--device', type=str, default=None)
    cli.add_argument('--sam_model_type', type=str, default="vit_h")
    cli.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
    options = cli.parse_args()

    # No explicit device requested: fall back to the auto-detected one.
    if not options.device:
        options.device = str(get_device())

    return options
34
+
35
# convert points input to prompt state
def get_prompt(click_state, click_input):
    """Append newly clicked points to the click history and build a SAM prompt.

    Args:
        click_state: Two-element list ``[points, labels]``; it is extended
            in place with the new clicks.
        click_input: JSON string holding a list of ``[a, b, label]``
            triples; the first two entries of each triple are stored as the
            point, the third as its label.

    Returns:
        A prompt dict with the keys ``prompt_type``, ``input_point``,
        ``input_label`` and ``multimask_output``.
    """
    new_clicks = json.loads(click_input)
    point_list, label_list = click_state[0], click_state[1]
    for click in new_clicks:
        point_list.append(click[:2])
        label_list.append(click[2])
    click_state[0] = point_list
    click_state[1] = label_list

    return {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        # Kept as the string "True" exactly as the callers receive it today.
        "multimask_output": "True",
    }
52
+
53
# use sam to get the mask
@spaces.GPU
def sam_refine(image_state, point_prompt, click_state, evt:gr.SelectData):
    """Handle a click on the input image and refresh the SAM mask.

    Args:
        image_state: Per-session dict; its ``origin_image`` entry is fed to
            SAM and its ``mask``, ``logit`` and ``painted_image`` entries
            are overwritten with the new prediction.
        point_prompt: ``"Positive"`` labels the click 1, any other value 0.
        click_state: ``[points, labels]`` history, extended in place.
        evt: Gradio select event; ``evt.index[0]`` and ``evt.index[1]`` are
            used as the clicked coordinate.

    Returns:
        ``(painted_image, image_state, click_state)``.
    """
    click_label = 1 if point_prompt == "Positive" else 0
    coordinate = "[[{},{},{}]]".format(evt.index[0], evt.index[1], click_label)

    # prompt for sam model
    sam = model.samcontroler.sam_controler
    sam.reset_image()
    sam.set_image(image_state["origin_image"])
    prompt = get_prompt(click_state=click_state, click_input=coordinate)

    mask, logit, painted_image = model.first_frame_click(
        image=image_state["origin_image"],
        points=np.array(prompt["input_point"]),
        labels=np.array(prompt["input_label"]),
        multimask=prompt["multimask_output"],
    )
    image_state.update(mask=mask, logit=logit, painted_image=painted_image)

    return painted_image, image_state, click_state
77
+
78
+
79
def add_multi_mask(image_state, interactive_state, mask_dropdown):
    """Store the current mask under a new ``mask_NNN`` name and select it.

    Args:
        image_state: Per-session dict whose ``mask`` entry is stored.
        interactive_state: Dict with the ``masks`` and ``mask_names`` lists;
            both are appended to in place.
        mask_dropdown: Currently selected mask names; the new name is
            appended in place.

    Returns:
        ``(interactive_state, dropdown_update, preview_image, fresh_click_state)``
        where the click state is reset to ``[[], []]``.
    """
    interactive_state["masks"].append(image_state["mask"])

    # Names are 1-based and derived from the number of stored masks.
    new_name = "mask_{:03d}".format(len(interactive_state["masks"]))
    interactive_state["mask_names"].append(new_name)
    mask_dropdown.append(new_name)

    preview = show_mask(image_state, interactive_state, mask_dropdown)
    dropdown_update = gr.update(choices=interactive_state["mask_names"], value=mask_dropdown)

    return interactive_state, dropdown_update, preview, [[], []]
87
+
88
def clear_click(image_state, click_state):
    """Discard all recorded clicks and show the unpainted image again.

    Args:
        image_state: Per-session dict; only ``origin_image`` is read.
        click_state: Previous click history (ignored, replaced by a fresh one).

    Returns:
        ``(origin_image, [[], []])``.
    """
    return image_state["origin_image"], [[], []]
92
+
93
def remove_multi_mask(interactive_state, click_state, image_state):
    """Delete every stored mask and reset the clicks.

    Args:
        interactive_state: Dict whose ``mask_names`` and ``masks`` entries
            are replaced by empty lists.
        click_state: Previous click history (ignored, replaced by a fresh one).
        image_state: Per-session dict; only ``origin_image`` is read.

    Returns:
        ``(interactive_state, emptied_dropdown_update, origin_image, [[], []])``.
    """
    interactive_state.update(mask_names=[], masks=[])
    original = image_state["origin_image"]

    return interactive_state, gr.update(choices=[], value=[]), original, [[], []]
100
+
101
def show_mask(image_state, interactive_state, mask_dropdown):
    """Paint the selected masks on top of the original image.

    Args:
        image_state: Per-session dict; only ``origin_image`` is read.
        interactive_state: Dict whose ``masks`` list is indexed by the
            1-based number embedded in each selected name.
        mask_dropdown: Selected mask names (``mask_NNN``); sorted in place.

    Returns:
        The image with every selected mask painted on it, the untouched
        original image when nothing is selected, or ``None`` when no image
        has been loaded yet.
    """
    mask_dropdown.sort()

    frame = image_state["origin_image"]
    if frame is None:
        # Nothing to paint on; make the "no image yet" result explicit.
        return None

    for name in mask_dropdown:
        mask_index = int(name.split("_")[1]) - 1
        mask = interactive_state["masks"][mask_index]
        # Colour index is offset by 2, one distinct value per stored mask.
        frame = mask_painter(frame, mask.astype('uint8'), mask_color=mask_index + 2)

    return frame
111
+
112
@spaces.GPU
def upload_and_reset(image_input, interactive_state):
    """Start a fresh session for a newly provided image.

    Args:
        image_input: The uploaded image, forwarded to
            ``update_image_state_on_upload``.
        interactive_state: Dict whose ``mask_names`` and ``masks`` entries
            are replaced by empty lists.

    Returns:
        ``(image_state, image_info, image_input, interactive_state,
        click_state, dropdown_update)`` with an empty click state and an
        emptied mask dropdown.
    """
    interactive_state.update(mask_names=[], masks=[])

    image_state, image_info, image_input = update_image_state_on_upload(image_input)

    fresh_clicks = [[], []]
    emptied_dropdown = gr.update(choices=[], value=[])
    return (
        image_state,
        image_info,
        image_input,
        interactive_state,
        fresh_clicks,
        emptied_dropdown,
    )
129
+
130
def update_image_state_on_upload(image_input):
    """Build the per-image state and hand the image to SAM.

    Args:
        image_input: Uploaded image exposing ``.size`` and convertible with
            ``np.array`` (the input component is configured with
            ``type='pil'``).

    Returns:
        ``(image_state, image_info, image_input)`` where ``image_state``
        holds the image array, a copy used as the painted image, an
        all-zero ``uint8`` mask and a ``None`` logit.
    """
    size = image_input.size
    # Swap the two entries of ``.size`` so the mask matches the array layout.
    image_size = (size[1], size[0])

    pixels = np.array(image_input)

    image_state = {
        "origin_image": pixels,
        "painted_image": pixels.copy(),
        "mask": np.zeros(image_size, np.uint8),
        "logit": None,
    }

    image_info = f"Image Name: uploaded.png,\nImage Size: {image_size}"

    # Replace whatever image SAM was holding with the new one.
    sam = model.samcontroler.sam_controler
    sam.reset_image()
    sam.set_image(pixels)

    return image_state, image_info, image_input
150
+
151
+
152
+
153
# SAM generator
class MaskGenerator():
    """Thin wrapper that owns the ``SamControler`` used for click prompts."""

    def __init__(self, sam_checkpoint, args):
        """Create the SAM controller.

        Args:
            sam_checkpoint: Checkpoint passed straight to ``SamControler``.
            args: Options object; ``sam_model_type`` and ``device`` are read.
        """
        self.args = args
        self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)

    def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
        """Forward a click prompt to the controller.

        Returns:
            The ``(mask, logit, painted_image)`` triple produced by
            ``SamControler.first_frame_click``.
        """
        prediction = self.samcontroler.first_frame_click(image, points, labels, multimask)
        mask, logit, painted_image = prediction
        return mask, logit, painted_image
162
+
163
+
164
# command-line arguments, parsed by parse_augment()
args = parse_augment()

# Download URLs of the SAM checkpoints, keyed by the --sam_model_type value.
sam_checkpoint_url_dict = {
    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}
checkpoint_folder = os.path.join('/home/user/app/', 'pretrained_models')

# Fetch the checkpoint selected on the command line into checkpoint_folder.
sam_checkpoint = load_file_from_url(
    sam_checkpoint_url_dict[args.sam_model_type],
    checkpoint_folder,
)

# initialize sams
model = MaskGenerator(sam_checkpoint, args)

# Build pipeline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
    "jixin0101/ObjectClear",
    torch_dtype=torch.float16,
    variant='fp16',
    apply_attention_guided_fusion=True,
)

pipe.to(device)
187
+
188
def _merge_selected_masks(masks, mask_dropdown):
    """Combine the selected masks into one label map.

    Each selected mask is scaled by its 1-based number (taken from its
    ``mask_NNN`` name) and accumulated; after every addition the sum is
    clipped to the number of the mask just added. ``mask_dropdown`` is
    sorted in place and must contain at least one name.
    """
    mask_dropdown.sort()
    first_label = int(mask_dropdown[0].split("_")[1])
    merged = masks[first_label - 1] * first_label
    for name in mask_dropdown[1:]:
        label = int(name.split("_")[1])
        merged = np.clip(merged + masks[label - 1] * label, 0, label)
    return merged


@spaces.GPU
def process(image_state, interactive_state, mask_dropdown, guidance_scale, seed, num_inference_steps):
    """Run object removal on the current image with the selected mask(s).

    Args:
        image_state: Per-session dict; ``origin_image`` is the input image
            array. When stored masks are used, ``mask`` is overwritten with
            the merged label map; otherwise ``mask`` is used as the mask.
        interactive_state: Dict whose ``masks`` list holds the stored masks.
        mask_dropdown: Selected mask names; when masks are stored but
            nothing is selected, ``"mask_001"`` is used.
        guidance_scale: Forwarded to the pipeline.
        seed: Seed for the random generator (cast to ``int``).
        num_inference_steps: Forwarded to the pipeline (cast to ``int``).

    Returns:
        ``(result, (input, result))``: the removal result, plus the pair
        shown in the comparison slider; all images are resized back to the
        size of the original input.

    Raises:
        gr.Error: If no image has been loaded yet.
    """
    image_np = image_state["origin_image"]
    if image_np is None:
        # Report a readable message in the UI instead of an opaque PIL error.
        raise gr.Error("Please upload an image before starting ObjectClear.")

    # Use the same device the pipeline was moved to at import time.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    image = Image.fromarray(image_np)

    if interactive_state["masks"]:
        if not mask_dropdown:
            mask_dropdown = ["mask_001"]
        template_mask = _merge_selected_masks(interactive_state["masks"], mask_dropdown)
        image_state["mask"] = template_mask
    else:
        template_mask = image_state["mask"]

    # The merged map holds label values (1, 2, 3, ...). Binarise before
    # scaling: multiplying labels > 1 by 255 in uint8 would wrap around.
    binary_mask = (np.asarray(template_mask) > 0).astype(np.uint8) * 255
    mask = Image.fromarray(binary_mask)
    original_size = image.size

    image = image.convert("RGB")
    mask = mask.convert("RGB")

    image = resize_by_short_side(image, 512, resample=Image.BICUBIC)
    mask = resize_by_short_side(mask, 512, resample=Image.NEAREST)

    w, h = image.size

    result = pipe(
        prompt="remove the instance of object",
        image=image,
        mask_image=mask,
        generator=generator,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        height=h,
        width=w,
    )

    # Resize once and reuse the result for both output components.
    fused_full = result.images[0].resize(original_size)

    return fused_full, (image.resize(original_size), fused_full)
229
+
230
import base64

# Inline the logo as a base64 data URI so it can be embedded in gr.HTML.
with open("./Logo.png", "rb") as logo_file:
    img_bytes = logo_file.read()
img_b64 = base64.b64encode(img_bytes).decode()

html_img = f'''
<div style="display:flex; justify-content:center; align-items:center; width:100%;">
<img src="data:image/png;base64,{img_b64}" style="border:none; width:200px; height:auto;"/>
</div>
'''

# Download the tutorial video into the assets folder read by the UI below.
tutorial_url = "https://github.com/zjx0101/ObjectClear/releases/download/media/tutorial.mp4"
assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
load_file_from_url(tutorial_url, assets_path)
244
+
245
# Markdown/HTML shown under the logo at the top of the page.
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/zjx0101/ObjectClear' target='_blank'><b>ObjectClear: Complete Object Removal via Object-Effect Attention</b></a>.<br>
🔥 ObjectClear is an object removal model that can jointly eliminate the target object and its associated effects leveraging Object-Effect Attention, while preserving background consistency.<br>
🖼️ Try to drop your image, assign the target masks with a few clicks, and get the object removal results!<br>
*Note: All input images are temporarily resized (shorter side = 512 pixels) during inference to match the training resolution. Final outputs are restored to the original resolution.<br>*
"""

# Markdown shown at the bottom of the page: citation, contact, acknowledgement.
article = r"""<h3>
<b>If ObjectClear is helpful, please help to star the <a href='https://github.com/zjx0101/ObjectClear' target='_blank'>Github Repo</a>. Thanks!</b></h3>
<hr>
📑 **Citation**
<br>
If our work is useful for your research, please consider citing:
```bibtex
@InProceedings{zhao2025ObjectClear,
title = {{ObjectClear}: Complete Object Removal via Object-Effect Attention},
author = {Zhao, Jixin and Zhou, Shangchen and Wang, Zhouxia and Yang, Peiqing and Loy, Chen Change},
booktitle = {arXiv preprint arXiv:2505.22636},
year = {2025}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to reach me out at <b>jixinzhao0101@gmail.com</b>.
<br>
👏 **Acknowledgement**
<br>
This demo is adapted from [MatAnyone](https://github.com/pq-yang/MatAnyone), and leveraging segmentation capabilities from [Segment Anything](https://github.com/facebookresearch/segment-anything). Thanks for their awesome works!
"""
274
+
275
# Stylesheet handed to gr.Blocks(css=...). It must contain plain CSS only:
# literal <style> tags inside it are invalid and make the rule that follows
# them unparsable, so none are included here.
custom_css = """
#input-image {
    aspect-ratio: 1 / 1;
    width: 100%;
    max-width: 100%;
    height: auto;
    display: flex;
    align-items: center;
    justify-content: center;
}
#input-image img {
    max-width: 100%;
    max-height: 100%;
    object-fit: contain;
    display: block;
}
#main-columns {
    gap: 60px;
}
#main-columns > .gr-column {
    flex: 1;
}
#compare-image {
    width: 100%;
    aspect-ratio: 1 / 1;
    display: flex;
    align-items: center;
    justify-content: center;
    margin: 0;
    padding: 0;
    max-width: 100%;
    box-sizing: border-box;
}
#compare-image svg.svelte-zyxd38 {
    position: absolute !important;
    top: 50% !important;
    left: 50% !important;
    transform: translate(-50%, -50%) !important;
}
#compare-image .icon.svelte-1oiin9d {
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
}
#compare-image {
    position: relative;
    overflow: hidden;
}
.new_button {background-color: #171717 !important; color: #ffffff !important; border: none !important;}
.new_button:hover {background-color: #4b4b4b !important;}
#start-button {
    background: linear-gradient(135deg, #2575fc 0%, #6a11cb 100%);
    color: white;
    border: none;
    padding: 12px 24px;
    font-size: 16px;
    font-weight: bold;
    border-radius: 12px;
    cursor: pointer;
    box-shadow: 0 0 12px rgba(100, 100, 255, 0.7);
    transition: all 0.3s ease;
}
#start-button:hover {
    transform: scale(1.05);
    box-shadow: 0 0 20px rgba(100, 100, 255, 1);
}
.button-wrapper {
    width: 30%;
    text-align: center;
}
.wide-button {
    width: 83% !important;
    background-color: black !important;
    color: white !important;
    border: none !important;
    padding: 8px 0 !important;
    font-size: 16px !important;
    display: inline-block;
    margin: 30px 0px 0px 50px ;
}
.wide-button:hover {
    background-color: #656262 !important;
}
"""
362
+
363
+
364
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost; component order is preserved, but confirm the layout
# (which widgets sit in which Row/Column) against the running app.
with gr.Blocks(css=custom_css) as demo:
    # ---- Header -----------------------------------------------------------
    gr.HTML(html_img)
    gr.Markdown(description)

    # ---- SAM click options and stored-mask selection ----------------------
    with gr.Group(elem_classes="gr-monochrome-group", visible=True):
        with gr.Row():
            with gr.Accordion('SAM Settings (click to expand)', open=False):
                with gr.Row():
                    point_prompt = gr.Radio(
                        choices=["Positive", "Negative"],
                        value="Positive",
                        label="Point Prompt",
                        info="Click to add positive or negative point for target mask",
                        interactive=True,
                        min_width=100,
                        scale=1,
                    )
                    mask_dropdown = gr.Dropdown(
                        multiselect=True,
                        value=[],
                        label="Mask Selection",
                        info="Choose 1~all mask(s) added in Step 2",
                    )

    with gr.Row(elem_id="main-columns"):
        # ---- Left column: input image, mask buttons, inference settings ---
        with gr.Column():

            # Per-session state: click history, stored masks, current image.
            click_state = gr.State([[],[]])

            interactive_state = gr.State({"mask_names": [], "masks": []})

            image_state = gr.State(
                {
                    "origin_image": None,
                    "painted_image": None,
                    "mask": None,
                    "logit": None,
                }
            )

            image_info = gr.Textbox(label="Image Info", visible=False)
            input_image = gr.Image(
                label='Input',
                type='pil',
                sources=["upload"],
                image_mode='RGB',
                interactive=True,
                elem_id="input-image",
            )

            with gr.Row(equal_height=True, elem_classes="mask_button_group"):
                clear_button_click = gr.Button(value="Clear Clicks", elem_classes="new_button", min_width=100)
                add_mask_button = gr.Button(value="Add Mask", elem_classes="new_button", min_width=100)
                remove_mask_button = gr.Button(value="Delete Mask", elem_classes="new_button", min_width=100)

            submit_button_component = gr.Button(
                value='Start ObjectClear', elem_id="start-button"
            )

            with gr.Accordion('ObjectClear Settings', open=True):
                guidance_scale = gr.Slider(
                    minimum=1, maximum=10, step=0.5, value=2.5,
                    label="Guidance Scale",
                    info="Higher = stronger removal; lower = better background preservation (default: 2.5)",
                )

                seed = gr.Slider(
                    minimum=0, maximum=1000000, step=1, value=300000,
                    label="Seed Value",
                    info="Different seeds can lead to noticeably different object removal results (default: 300000)",
                )

                num_inference_steps = gr.Slider(
                    minimum=1, maximum=40, step=1, value=20,
                    label="Num Inference Steps",
                    info="Higher values may improve quality but take longer (default: 20)",
                )

        # ---- Right column: result and before/after comparison -------------
        with gr.Column():
            # Shares the "input-image" elem_id so the same CSS sizing applies.
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Output', format="png", elem_id="input-image")

            output_compare_image_component = gr.ImageSlider(
                label="Comparison",
                type="pil",
                format='png',
                elem_id="compare-image",
            )

    # ---- Event wiring -----------------------------------------------------
    # Outputs refreshed whenever a new image is loaded (upload or example).
    reset_outputs = [
        image_state,
        image_info,
        input_image,
        interactive_state,
        click_state,
        mask_dropdown,
    ]
    # Outputs refreshed when the set of stored masks changes.
    mask_edit_outputs = [interactive_state, mask_dropdown, input_image, click_state]

    input_image.upload(
        fn=upload_and_reset,
        inputs=[input_image, interactive_state],
        outputs=reset_outputs,
    )

    # click select image to get mask using sam
    input_image.select(
        fn=sam_refine,
        inputs=[image_state, point_prompt, click_state],
        outputs=[input_image, image_state, click_state],
    )

    # add different mask
    add_mask_button.click(
        fn=add_multi_mask,
        inputs=[image_state, interactive_state, mask_dropdown],
        outputs=mask_edit_outputs,
    )

    remove_mask_button.click(
        fn=remove_multi_mask,
        inputs=[interactive_state, click_state, image_state],
        outputs=mask_edit_outputs,
    )

    # points clear
    clear_button_click.click(
        fn=clear_click,
        inputs=[image_state, click_state],
        outputs=[input_image, click_state],
    )

    submit_button_component.click(
        fn=process,
        inputs=[
            image_state,
            interactive_state,
            mask_dropdown,
            guidance_scale,
            seed,
            num_inference_steps,
        ],
        outputs=[output_image_component, output_compare_image_component],
    )

    # ---- Tutorial video ---------------------------------------------------
    with gr.Accordion("📕 Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"):
        with gr.Row():
            gr.Video(value="/home/user/app/hugging_face/assets/tutorial.mp4", elem_classes="video")

    gr.Markdown("---")
    gr.Markdown("## Examples")

    # ---- Example images (examples/test0.png ... examples/test9.png) -------
    examples_dir = os.path.join(os.path.dirname(__file__), "examples")
    example_images = [os.path.join(examples_dir, f"test{i}.png") for i in range(10)]

    # Each row supplies the image path only; no value is given for the state.
    examples_data = [[image_path, None] for image_path in example_images]

    examples = gr.Examples(
        examples=examples_data,
        inputs=[input_image, interactive_state],
        outputs=reset_outputs,
        fn=upload_and_reset,
        run_on_click=True,
        cache_examples=False,
        label="Click below to load example images",
    )

    gr.Markdown(article)

    def pre_update_input_image():
        """Clear the input image component when the page loads."""
        return gr.update(value=None)

    demo.load(
        fn=pre_update_input_image,
        inputs=[],
        outputs=[input_image],
    )


demo.launch(debug=True, show_error=True)