KenjieDec commited on
Commit
b8dc8f2
·
verified ·
1 Parent(s): 97c98cc

Removed "Mask Only" option

Browse files

- Removed "Mask Only" option as it now shows both mask and "RemBG"-ed image

Files changed (1) hide show
  1. app.py +104 -100
app.py CHANGED
@@ -1,100 +1,104 @@
1
- import gradio as gr
2
- import os
3
- import cv2
4
- from rembg import new_session, remove
5
- from rembg.bg import download_models
6
-
7
- def inference(file, mask, model, x, y):
8
- session = new_session(model)
9
-
10
- output = remove(
11
- file,
12
- session=session,
13
- **{ "sam_prompt": [{"type": "point", "data": [x, y], "label": 1}] },
14
- only_mask=(mask == "Mask only")
15
- )
16
-
17
- return output
18
-
19
- title = "RemBG"
20
- description = "Gradio demo for **[RemBG](https://github.com/danielgatis/rembg)**. To use it, simply upload your image, select a model, click Process, and wait."
21
- badge = """
22
- <div style="position: fixed; left: 50%; text-align: center;">
23
- <a href="https://github.com/danielgatis/rembg" target="_blank" style="text-decoration: none;">
24
- <img src="https://img.shields.io/badge/RemBG-Github-blue" alt="RemBG Github" />
25
- </a>
26
- </div>
27
- """
28
- def get_coords(evt: gr.SelectData) -> tuple:
29
- return evt.index[0], evt.index[1]
30
-
31
- def show_coords(model: str):
32
- visible = model == "sam"
33
- return gr.update(visible=visible), gr.update(visible=visible), gr.update(visible=visible)
34
-
35
- download_models(tuple())
36
-
37
- with gr.Blocks() as app:
38
- gr.Markdown(f"# {title}")
39
- gr.Markdown(description)
40
-
41
- with gr.Row():
42
- inputs = gr.Image(type="numpy", label="Input Image")
43
- outputs = gr.Image(label="Output Image")
44
-
45
- with gr.Row():
46
- mask_option = gr.Radio(
47
- ["Default", "Mask only"],
48
- value="Default",
49
- label="Output Type"
50
- )
51
- model_selector = gr.Dropdown(
52
- [
53
- "u2net",
54
- "u2netp",
55
- "u2net_human_seg",
56
- "u2net_cloth_seg",
57
- "silueta",
58
- "isnet-general-use",
59
- "isnet-anime",
60
- "sam",
61
- "bria-rmbg",
62
- "birefnet-general",
63
- "birefnet-general-lite",
64
- "birefnet-portrait",
65
- "birefnet-dis",
66
- "birefnet-hrsod",
67
- "birefnet-cod",
68
- "birefnet-massive",
69
- ],
70
- value="isnet-general-use",
71
- label="Model Selection"
72
- )
73
-
74
- extra = gr.Markdown("## Click on the image to capture coordinates (for SAM model)", visible=False)
75
-
76
- x = gr.Number(label="Mouse X Coordinate", visible=False)
77
- y = gr.Number(label="Mouse Y Coordinate", visible=False)
78
-
79
- model_selector.change(show_coords, inputs=model_selector, outputs=[x, y, extra])
80
- inputs.select(get_coords, None, [x, y])
81
-
82
-
83
- gr.Button("Process Image").click(
84
- inference,
85
- inputs=[inputs, mask_option, model_selector, x, y],
86
- outputs=outputs
87
- )
88
-
89
- gr.Examples(
90
- examples=[
91
- ["lion.png", "Default", "u2net", None, None],
92
- ["girl.jpg", "Default", "u2net", None, None],
93
- ["anime-girl.jpg", "Default", "isnet-anime", None, None]
94
- ],
95
- inputs=[inputs, mask_option, model_selector, x, y],
96
- outputs=outputs
97
- )
98
- gr.HTML(badge)
99
-
100
- app.launch(share=True)
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ from rembg import new_session, remove
6
+ from rembg.bg import download_models
7
+
8
+ def inference(file, model, x, y):
9
+ session = new_session(model)
10
+
11
+ mask = remove(
12
+ file,
13
+ session=session,
14
+ **{ "sam_prompt": [{"type": "point", "data": [x, y], "label": 1}] },
15
+ only_mask=True
16
+ )
17
+
18
+ print("Image Done")
19
+ if mask.shape[:2] != file.shape[:2]:
20
+ mask = cv2.resize(mask, (file.shape[1], file.shape[0]), interpolation=cv2.INTER_LANCZOS4)
21
+
22
+ image = cv2.cvtColor(file, cv2.COLOR_BGR2BGRA)
23
+ image[:, :, 3] = mask
24
+
25
+ return (image, mask)
26
+
27
+ title = "RemBG"
28
+ description = "Gradio demo for **[RemBG](https://github.com/danielgatis/rembg)**. To use it, simply upload your image, select a model, click Process, and wait."
29
+ badge = """
30
+ <div style="position: fixed; left: 50%; text-align: center;">
31
+ <a href="https://github.com/danielgatis/rembg" target="_blank" style="text-decoration: none;">
32
+ <img src="https://img.shields.io/badge/RemBG-Github-blue" alt="RemBG Github" />
33
+ </a>
34
+ </div>
35
+ """
36
+ def get_coords(evt: gr.SelectData) -> tuple:
37
+ return evt.index[0], evt.index[1]
38
+
39
+ def show_coords(model: str):
40
+ visible = model == "sam"
41
+ return gr.update(visible=visible), gr.update(visible=visible), gr.update(visible=visible)
42
+
43
+ download_models(tuple())
44
+
45
+ with gr.Blocks() as app:
46
+ gr.Markdown(f"# {title}")
47
+ gr.Markdown(description)
48
+
49
+ with gr.Row():
50
+ inputs = gr.Image(type="numpy", label="Input Image")
51
+ with gr.Column():
52
+ output_image = gr.Image(label="Output Image")
53
+ output_mask = gr.Image(label="Output Mask")
54
+
55
+ model_selector = gr.Dropdown(
56
+ [
57
+ "u2net",
58
+ "u2netp",
59
+ "u2net_human_seg",
60
+ "u2net_cloth_seg",
61
+ "silueta",
62
+ "isnet-general-use",
63
+ "isnet-anime",
64
+ "sam",
65
+ "bria-rmbg",
66
+ "birefnet-general",
67
+ "birefnet-general-lite",
68
+ "birefnet-portrait",
69
+ "birefnet-dis",
70
+ "birefnet-hrsod",
71
+ "birefnet-cod",
72
+ "birefnet-massive",
73
+ ],
74
+ value="isnet-general-use",
75
+ label="Model Selection"
76
+ )
77
+
78
+ extra = gr.Markdown("## Click on the image to capture coordinates (for SAM model)", visible=False)
79
+
80
+ x = gr.Number(label="Mouse X Coordinate", visible=False)
81
+ y = gr.Number(label="Mouse Y Coordinate", visible=False)
82
+
83
+ model_selector.change(show_coords, inputs=model_selector, outputs=[x, y, extra])
84
+ inputs.select(get_coords, None, [x, y])
85
+
86
+
87
+ gr.Button("Process Image").click(
88
+ inference,
89
+ inputs=[inputs, model_selector, x, y],
90
+ outputs=(output_image, output_mask)
91
+ )
92
+
93
+ gr.Examples(
94
+ examples=[
95
+ ["lion.png", "u2net", 100, None, None],
96
+ ["girl.jpg", "u2net", 100, None, None],
97
+ ["anime-girl.jpg", "isnet-anime", 100, None, None]
98
+ ],
99
+ inputs=[inputs, model_selector, x, y],
100
+ outputs=(output_image, output_mask)
101
+ )
102
+ gr.HTML(badge)
103
+
104
+ app.launch(share=True)