KenjieDec committed on
Commit
97c98cc
·
verified ·
1 Parent(s): 9ce2677

Add bria-rembg

Browse files
app.py CHANGED
@@ -1,109 +1,100 @@
1
- import gradio as gr
2
- import os
3
- import cv2
4
- from rembg import new_session, remove
5
- from rembg.sessions import sessions_class
6
-
7
- def inference(file, mask, model, x, y):
8
- im = cv2.imread(file, cv2.IMREAD_COLOR)
9
- input_path = "input.png"
10
- output_path = "output.png"
11
- cv2.imwrite(input_path, im)
12
-
13
- with open(input_path, 'rb') as i:
14
- with open(output_path, 'wb') as o:
15
- input = i.read()
16
- session = new_session(model)
17
-
18
- output = remove(
19
- input,
20
- session=session,
21
- **{ "sam_prompt": [{"type": "point", "data": [x, y], "label": 1}] },
22
- only_mask=(mask == "Mask only")
23
- )
24
- o.write(output)
25
-
26
- return output_path
27
-
28
- title = "RemBG"
29
- description = "Gradio demo for **[RemBG](https://github.com/danielgatis/rembg)**. To use it, simply upload your image, select a model, click Process, and wait."
30
- badge = """
31
- <div style="position: fixed; left: 50%; text-align: center;">
32
- <a href="https://github.com/danielgatis/rembg" target="_blank" style="text-decoration: none;">
33
- <img src="https://img.shields.io/badge/RemBG-Github-blue" alt="RemBG Github" />
34
- </a>
35
- </div>
36
- """
37
- def get_coords(evt: gr.SelectData) -> tuple:
38
- return evt.index[0], evt.index[1]
39
-
40
- def show_coords(model: str):
41
- visible = model == "sam"
42
- return gr.update(visible=visible), gr.update(visible=visible), gr.update(visible=visible)
43
-
44
- for session in sessions_class:
45
- session.download_models()
46
-
47
- with gr.Blocks() as app:
48
- gr.Markdown(f"# {title}")
49
- gr.Markdown(description)
50
-
51
- with gr.Row():
52
- inputs = gr.Image(type="filepath", label="Input Image")
53
- outputs = gr.Image(type="filepath", label="Output Image")
54
-
55
- with gr.Row():
56
- mask_option = gr.Radio(
57
- ["Default", "Mask only"],
58
- value="Default",
59
- label="Output Type"
60
- )
61
- model_selector = gr.Dropdown(
62
- [
63
- "u2net",
64
- "u2netp",
65
- "u2net_human_seg",
66
- "u2net_cloth_seg",
67
- "silueta",
68
- "isnet-general-use",
69
- "isnet-anime",
70
- "sam",
71
- "birefnet-general",
72
- "birefnet-general-lite",
73
- "birefnet-portrait",
74
- "birefnet-dis",
75
- "birefnet-hrsod",
76
- "birefnet-cod",
77
- "birefnet-massive"
78
- ],
79
- value="isnet-general-use",
80
- label="Model Selection"
81
- )
82
-
83
- extra = gr.Markdown("## Click on the image to capture coordinates (for SAM model)", visible=False)
84
-
85
- x = gr.Number(label="Mouse X Coordinate", visible=False)
86
- y = gr.Number(label="Mouse Y Coordinate", visible=False)
87
-
88
- model_selector.change(show_coords, inputs=model_selector, outputs=[x, y, extra])
89
- inputs.select(get_coords, None, [x, y])
90
-
91
-
92
- gr.Button("Process Image").click(
93
- inference,
94
- inputs=[inputs, mask_option, model_selector, x, y],
95
- outputs=outputs
96
- )
97
-
98
- gr.Examples(
99
- examples=[
100
- ["lion.png", "Default", "u2net", None, None],
101
- ["girl.jpg", "Default", "u2net", None, None],
102
- ["anime-girl.jpg", "Default", "isnet-anime", None, None]
103
- ],
104
- inputs=[inputs, mask_option, model_selector, x, y],
105
- outputs=outputs
106
- )
107
- gr.HTML(badge)
108
-
109
- app.launch(share=True)
 
1
+ import gradio as gr
2
+ import os
3
+ import cv2
4
+ from rembg import new_session, remove
5
+ from rembg.bg import download_models
6
+
7
def inference(file, mask, model, x, y):
    """Remove the background from an image with the selected rembg model.

    Parameters:
        file: The input image as delivered by the "Input Image" component.
        mask (str): Output type; "Mask only" returns just the mask.
        model (str): Name of the rembg model to run.
        x: X coordinate of the SAM prompt point (only used by the "sam" model).
        y: Y coordinate of the SAM prompt point (only used by the "sam" model).

    Raises:
        gr.Error: If no image was supplied, or if "sam" is selected and no
            point has been picked yet.

    Returns:
        The value returned by `remove` for the given input.
    """
    if file is None:
        raise gr.Error("Please upload an image first.")

    # Only the SAM model consumes a point prompt; the other models get no
    # prompt at all instead of one built from empty coordinates.
    extras = {}
    if model == "sam":
        if x is None or y is None:
            raise gr.Error(
                "Click on the input image to pick a point for the SAM model."
            )
        extras["sam_prompt"] = [{"type": "point", "data": [x, y], "label": 1}]

    # NOTE(review): a session is created on every call; consider caching one
    # session per model name if request latency matters.
    session = new_session(model)

    return remove(
        file,
        session=session,
        only_mask=(mask == "Mask only"),
        **extras,
    )
18
+
19
+ title = "RemBG"
20
+ description = "Gradio demo for **[RemBG](https://github.com/danielgatis/rembg)**. To use it, simply upload your image, select a model, click Process, and wait."
21
+ badge = """
22
+ <div style="position: fixed; left: 50%; text-align: center;">
23
+ <a href="https://github.com/danielgatis/rembg" target="_blank" style="text-decoration: none;">
24
+ <img src="https://img.shields.io/badge/RemBG-Github-blue" alt="RemBG Github" />
25
+ </a>
26
+ </div>
27
+ """
28
def get_coords(evt: gr.SelectData) -> tuple:
    """Return the first two entries of the select event's index as a pair."""
    first, second = evt.index[0], evt.index[1]
    return first, second
30
+
31
def show_coords(model: str):
    """Build three visibility updates that are shown only for the "sam" model."""
    is_sam = model == "sam"
    # One update per output component wired to this callback.
    return tuple(gr.update(visible=is_sam) for _ in range(3))
34
+
35
+ download_models(tuple())
36
+
37
+ with gr.Blocks() as app:
38
+ gr.Markdown(f"# {title}")
39
+ gr.Markdown(description)
40
+
41
+ with gr.Row():
42
+ inputs = gr.Image(type="numpy", label="Input Image")
43
+ outputs = gr.Image(label="Output Image")
44
+
45
+ with gr.Row():
46
+ mask_option = gr.Radio(
47
+ ["Default", "Mask only"],
48
+ value="Default",
49
+ label="Output Type"
50
+ )
51
+ model_selector = gr.Dropdown(
52
+ [
53
+ "u2net",
54
+ "u2netp",
55
+ "u2net_human_seg",
56
+ "u2net_cloth_seg",
57
+ "silueta",
58
+ "isnet-general-use",
59
+ "isnet-anime",
60
+ "sam",
61
+ "bria-rmbg",
62
+ "birefnet-general",
63
+ "birefnet-general-lite",
64
+ "birefnet-portrait",
65
+ "birefnet-dis",
66
+ "birefnet-hrsod",
67
+ "birefnet-cod",
68
+ "birefnet-massive",
69
+ ],
70
+ value="isnet-general-use",
71
+ label="Model Selection"
72
+ )
73
+
74
+ extra = gr.Markdown("## Click on the image to capture coordinates (for SAM model)", visible=False)
75
+
76
+ x = gr.Number(label="Mouse X Coordinate", visible=False)
77
+ y = gr.Number(label="Mouse Y Coordinate", visible=False)
78
+
79
+ model_selector.change(show_coords, inputs=model_selector, outputs=[x, y, extra])
80
+ inputs.select(get_coords, None, [x, y])
81
+
82
+
83
+ gr.Button("Process Image").click(
84
+ inference,
85
+ inputs=[inputs, mask_option, model_selector, x, y],
86
+ outputs=outputs
87
+ )
88
+
89
+ gr.Examples(
90
+ examples=[
91
+ ["lion.png", "Default", "u2net", None, None],
92
+ ["girl.jpg", "Default", "u2net", None, None],
93
+ ["anime-girl.jpg", "Default", "isnet-anime", None, None]
94
+ ],
95
+ inputs=[inputs, mask_option, model_selector, x, y],
96
+ outputs=outputs
97
+ )
98
+ gr.HTML(badge)
99
+
100
+ app.launch(share=True)
 
 
 
 
 
 
 
 
 
rembg/_version.py CHANGED
@@ -23,9 +23,9 @@ def get_keywords():
23
  # setup.py/versioneer.py will grep for the variable names, so they must
24
  # each be defined on a line of their own. _version.py will just call
25
  # get_keywords().
26
- git_refnames = " (HEAD -> main)"
27
- git_full = "e740a9681ea32f5c34adce52aa7cc0b4b85bbb11"
28
- git_date = "2024-11-20 09:41:13 -0300"
29
  keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
  return keywords
31
 
 
23
  # setup.py/versioneer.py will grep for the variable names, so they must
24
  # each be defined on a line of their own. _version.py will just call
25
  # get_keywords().
26
+ git_refnames = " (HEAD -> main, tag: v2.0.69)"
27
+ git_full = "df72e3dea3f41e543a13991cb05b8a2659ee95c1"
28
+ git_date = "2025-12-04 18:05:12 -0300"
29
  keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
  return keywords
31
 
rembg/bg.py CHANGED
@@ -1,4 +1,5 @@
1
  import io
 
2
  from enum import Enum
3
  from typing import Any, List, Optional, Tuple, Union, cast
4
 
@@ -20,7 +21,7 @@ from pymatting.util.util import stack_images
20
  from scipy.ndimage import binary_erosion
21
 
22
  from .session_factory import new_session
23
- from .sessions import sessions_class
24
  from .sessions.base import BaseSession
25
 
26
  ort.set_default_logger_severity(3)
@@ -175,9 +176,8 @@ def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> P
175
  Returns:
176
  PILImage: The modified image with the background color applied.
177
  """
178
- r, g, b, a = color
179
- colored_image = Image.new("RGBA", img.size, (r, g, b, a))
180
- colored_image.paste(img, mask=img)
181
 
182
  return colored_image
183
 
@@ -195,12 +195,25 @@ def fix_image_orientation(img: PILImage) -> PILImage:
195
  return cast(PILImage, ImageOps.exif_transpose(img))
196
 
197
 
198
- def download_models() -> None:
199
  """
200
  Download models for image processing.
201
  """
202
- for session in sessions_class:
203
- session.download_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
 
206
  def remove(
@@ -215,7 +228,7 @@ def remove(
215
  bgcolor: Optional[Tuple[int, int, int, int]] = None,
216
  force_return_bytes: bool = False,
217
  *args: Optional[Any],
218
- **kwargs: Optional[Any]
219
  ) -> Union[bytes, PILImage, np.ndarray]:
220
  """
221
  Remove the background from an input image.
 
1
  import io
2
+ import sys
3
  from enum import Enum
4
  from typing import Any, List, Optional, Tuple, Union, cast
5
 
 
21
  from scipy.ndimage import binary_erosion
22
 
23
  from .session_factory import new_session
24
+ from .sessions import sessions, sessions_names
25
  from .sessions.base import BaseSession
26
 
27
  ort.set_default_logger_severity(3)
 
176
  Returns:
177
  PILImage: The modified image with the background color applied.
178
  """
179
+ background = Image.new("RGBA", img.size, tuple(color))
180
+ colored_image = Image.alpha_composite(background, img)
 
181
 
182
  return colored_image
183
 
 
195
  return cast(PILImage, ImageOps.exif_transpose(img))
196
 
197
 
198
def download_models(models: tuple[str, ...]) -> None:
    """
    Download models for image processing.

    Parameters:
        models (tuple[str, ...]): Names of the models to download. When empty,
            every registered model is downloaded.

    Every requested name is validated before any download starts, so a typo
    is reported immediately instead of after earlier models have been
    fetched. An unknown name terminates the process with exit status 1.
    A failure while downloading one model is reported on stderr and the
    remaining models are still attempted.
    """
    if not models:
        print("No models specified, downloading all models")
        models = tuple(sessions_names)

    # Resolve all names up front so nothing is downloaded for an invalid request.
    resolved = []
    for model in models:
        session = sessions.get(model)
        if session is None:
            print(f"Error: no model found: {model}", file=sys.stderr)
            sys.exit(1)
        resolved.append((model, session))

    for model, session in resolved:
        print(f"Downloading model: {model}")
        try:
            session.download_models()
        except Exception as e:
            # Deliberately best-effort: one failing model must not stop the rest.
            print(f"Error downloading model: {e}", file=sys.stderr)
217
 
218
 
219
  def remove(
 
228
  bgcolor: Optional[Tuple[int, int, int, int]] = None,
229
  force_return_bytes: bool = False,
230
  *args: Optional[Any],
231
+ **kwargs: Optional[Any],
232
  ) -> Union[bytes, PILImage, np.ndarray]:
233
  """
234
  Remove the background from an input image.
rembg/commands/b_command.py CHANGED
@@ -6,7 +6,7 @@ import sys
6
  from typing import IO
7
 
8
  import click
9
- from PIL.Image import Image as PILImage
10
 
11
  from ..bg import remove
12
  from ..session_factory import new_session
@@ -118,10 +118,11 @@ def b_command(
118
  Returns:
119
  None
120
  """
121
- try:
122
- kwargs.update(json.loads(extras))
123
- except Exception:
124
- pass
 
125
 
126
  session = new_session(model, **kwargs)
127
  bytes_per_img = image_width * image_height * 3
@@ -134,7 +135,7 @@ def b_command(
134
  if not os.path.isdir(output_dir):
135
  os.makedirs(output_dir, exist_ok=True)
136
 
137
- def img_to_byte_array(img: PILImage) -> bytes:
138
  buff = io.BytesIO()
139
  img.save(buff, format="PNG")
140
  return buff.getvalue()
@@ -162,7 +163,7 @@ def b_command(
162
  if not img_bytes:
163
  break
164
 
165
- img = PILImage.frombytes("RGB", (image_width, image_height), img_bytes)
166
  output = remove(img, session=session, **kwargs)
167
 
168
  if output_specifier:
 
6
  from typing import IO
7
 
8
  import click
9
+ import PIL
10
 
11
  from ..bg import remove
12
  from ..session_factory import new_session
 
118
  Returns:
119
  None
120
  """
121
+ if extras:
122
+ try:
123
+ kwargs.update(json.loads(extras))
124
+ except Exception:
125
+ raise click.BadParameter("extras must be a valid JSON string")
126
 
127
  session = new_session(model, **kwargs)
128
  bytes_per_img = image_width * image_height * 3
 
135
  if not os.path.isdir(output_dir):
136
  os.makedirs(output_dir, exist_ok=True)
137
 
138
+ def img_to_byte_array(img: PIL.Image.Image) -> bytes:
139
  buff = io.BytesIO()
140
  img.save(buff, format="PNG")
141
  return buff.getvalue()
 
163
  if not img_bytes:
164
  break
165
 
166
+ img = PIL.Image.frombytes("RGB", (image_width, image_height), img_bytes)
167
  output = remove(img, session=session, **kwargs)
168
 
169
  if output_specifier:
rembg/commands/d_command.py CHANGED
@@ -5,10 +5,11 @@ from ..bg import download_models
5
 
6
  @click.command( # type: ignore
7
  name="d",
8
- help="download all models",
9
  )
10
- def d_command(*args, **kwargs) -> None:
 
11
  """
12
- Download all models
13
  """
14
- download_models()
 
5
 
6
  @click.command( # type: ignore
7
  name="d",
8
+ help="download models",
9
  )
10
+ @click.argument("models", nargs=-1)
11
+ def d_command(models: tuple[str, ...]) -> None:
12
  """
13
+ Download models
14
  """
15
+ download_models(models)
rembg/commands/p_command.py CHANGED
@@ -185,8 +185,7 @@ def p_command(
185
  print(e)
186
 
187
  inputs = list(input.glob("**/*"))
188
- if not watch:
189
- inputs_tqdm = tqdm(inputs)
190
 
191
  for each_input in inputs_tqdm:
192
  if not each_input.is_dir():
 
185
  print(e)
186
 
187
  inputs = list(input.glob("**/*"))
188
+ inputs_tqdm = inputs if watch else tqdm(inputs)
 
189
 
190
  for each_input in inputs_tqdm:
191
  if not each_input.is_dir():
rembg/commands/s_command.py CHANGED
@@ -197,12 +197,15 @@ def s_command(port: int, host: str, log_level: str, threads: int) -> None:
197
  except Exception:
198
  pass
199
 
 
 
 
 
 
200
  return Response(
201
  remove(
202
  content,
203
- session=sessions.setdefault(
204
- commons.model, new_session(commons.model, **kwargs)
205
- ),
206
  alpha_matting=commons.a,
207
  alpha_matting_foreground_threshold=commons.af,
208
  alpha_matting_background_threshold=commons.ab,
@@ -306,6 +309,7 @@ def s_command(port: int, host: str, log_level: str, threads: int) -> None:
306
  ],
307
  gr.components.Image(type="filepath", label="Output"),
308
  concurrency_limit=3,
 
309
  )
310
 
311
  app = gr.mount_gradio_app(app, interface, path="/")
 
197
  except Exception:
198
  pass
199
 
200
+ session = sessions.get(commons.model)
201
+ if session is None:
202
+ session = new_session(commons.model, **kwargs)
203
+ sessions[commons.model] = session
204
+
205
  return Response(
206
  remove(
207
  content,
208
+ session=session,
 
 
209
  alpha_matting=commons.a,
210
  alpha_matting_foreground_threshold=commons.af,
211
  alpha_matting_background_threshold=commons.ab,
 
309
  ],
310
  gr.components.Image(type="filepath", label="Output"),
311
  concurrency_limit=3,
312
+ analytics_enabled=False,
313
  )
314
 
315
  app = gr.mount_gradio_app(app, interface, path="/")
rembg/session_factory.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import Type
3
 
4
  import onnxruntime as ort
5
 
@@ -8,9 +8,7 @@ from .sessions.base import BaseSession
8
  from .sessions.u2net import U2netSession
9
 
10
 
11
- def new_session(
12
- model_name: str = "u2net", providers=None, *args, **kwargs
13
- ) -> BaseSession:
14
  """
15
  Create a new session object based on the specified model name.
16
 
@@ -21,24 +19,30 @@ def new_session(
21
 
22
  Parameters:
23
  model_name (str): The name of the model.
24
- providers: The providers for the session.
25
  *args: Additional positional arguments.
26
  **kwargs: Additional keyword arguments.
27
 
 
 
 
28
  Returns:
29
  BaseSession: The created session object.
30
  """
31
- session_class: Type[BaseSession] = U2netSession
32
 
33
  for sc in sessions_class:
34
  if sc.name() == model_name:
35
  session_class = sc
36
  break
37
 
 
 
 
38
  sess_opts = ort.SessionOptions()
39
 
40
  if "OMP_NUM_THREADS" in os.environ:
41
- sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
42
- sess_opts.intra_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
 
43
 
44
- return session_class(model_name, sess_opts, providers, *args, **kwargs)
 
1
  import os
2
+ from typing import Optional, Type
3
 
4
  import onnxruntime as ort
5
 
 
8
  from .sessions.u2net import U2netSession
9
 
10
 
11
+ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
 
 
12
  """
13
  Create a new session object based on the specified model name.
14
 
 
19
 
20
  Parameters:
21
  model_name (str): The name of the model.
 
22
  *args: Additional positional arguments.
23
  **kwargs: Additional keyword arguments.
24
 
25
+ Raises:
26
+ ValueError: If no session class with the given `model_name` is found.
27
+
28
  Returns:
29
  BaseSession: The created session object.
30
  """
31
+ session_class: Optional[Type[BaseSession]] = None
32
 
33
  for sc in sessions_class:
34
  if sc.name() == model_name:
35
  session_class = sc
36
  break
37
 
38
+ if session_class is None:
39
+ raise ValueError(f"No session class found for model '{model_name}'")
40
+
41
  sess_opts = ort.SessionOptions()
42
 
43
  if "OMP_NUM_THREADS" in os.environ:
44
+ threads = int(os.environ["OMP_NUM_THREADS"])
45
+ sess_opts.inter_op_num_threads = threads
46
+ sess_opts.intra_op_num_threads = threads
47
 
48
+ return session_class(model_name, sess_opts, *args, **kwargs)
rembg/sessions/__init__.py CHANGED
@@ -1,88 +1,86 @@
1
  from __future__ import annotations
2
 
3
- from typing import List
4
 
5
  from .base import BaseSession
6
 
7
- sessions_class: List[type[BaseSession]] = []
8
- sessions_names: List[str] = []
9
 
10
  from .birefnet_general import BiRefNetSessionGeneral
11
 
12
- sessions_class.append(BiRefNetSessionGeneral)
13
- sessions_names.append(BiRefNetSessionGeneral.name())
14
 
15
  from .birefnet_general_lite import BiRefNetSessionGeneralLite
16
 
17
- sessions_class.append(BiRefNetSessionGeneralLite)
18
- sessions_names.append(BiRefNetSessionGeneralLite.name())
19
 
20
  from .birefnet_portrait import BiRefNetSessionPortrait
21
 
22
- sessions_class.append(BiRefNetSessionPortrait)
23
- sessions_names.append(BiRefNetSessionPortrait.name())
24
 
25
  from .birefnet_dis import BiRefNetSessionDIS
26
 
27
- sessions_class.append(BiRefNetSessionDIS)
28
- sessions_names.append(BiRefNetSessionDIS.name())
29
 
30
  from .birefnet_hrsod import BiRefNetSessionHRSOD
31
 
32
- sessions_class.append(BiRefNetSessionHRSOD)
33
- sessions_names.append(BiRefNetSessionHRSOD.name())
34
 
35
  from .birefnet_cod import BiRefNetSessionCOD
36
 
37
- sessions_class.append(BiRefNetSessionCOD)
38
- sessions_names.append(BiRefNetSessionCOD.name())
39
 
40
  from .birefnet_massive import BiRefNetSessionMassive
41
 
42
- sessions_class.append(BiRefNetSessionMassive)
43
- sessions_names.append(BiRefNetSessionMassive.name())
44
 
45
  from .dis_anime import DisSession
46
 
47
- sessions_class.append(DisSession)
48
- sessions_names.append(DisSession.name())
 
 
 
49
 
50
  from .dis_general_use import DisSession as DisSessionGeneralUse
51
 
52
- sessions_class.append(DisSessionGeneralUse)
53
- sessions_names.append(DisSessionGeneralUse.name())
54
 
55
  from .sam import SamSession
56
 
57
- sessions_class.append(SamSession)
58
- sessions_names.append(SamSession.name())
59
 
60
  from .silueta import SiluetaSession
61
 
62
- sessions_class.append(SiluetaSession)
63
- sessions_names.append(SiluetaSession.name())
64
 
65
  from .u2net_cloth_seg import Unet2ClothSession
66
 
67
- sessions_class.append(Unet2ClothSession)
68
- sessions_names.append(Unet2ClothSession.name())
69
 
70
  from .u2net_custom import U2netCustomSession
71
 
72
- sessions_class.append(U2netCustomSession)
73
- sessions_names.append(U2netCustomSession.name())
74
 
75
  from .u2net_human_seg import U2netHumanSegSession
76
 
77
- sessions_class.append(U2netHumanSegSession)
78
- sessions_names.append(U2netHumanSegSession.name())
79
 
80
  from .u2net import U2netSession
81
 
82
- sessions_class.append(U2netSession)
83
- sessions_names.append(U2netSession.name())
84
 
85
  from .u2netp import U2netpSession
86
 
87
- sessions_class.append(U2netpSession)
88
- sessions_names.append(U2netpSession.name())
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ from typing import Dict, List
4
 
5
  from .base import BaseSession
6
 
7
# Registry mapping each model name to the session class that implements it.
sessions: Dict[str, type[BaseSession]] = {}

from .birefnet_general import BiRefNetSessionGeneral
from .birefnet_general_lite import BiRefNetSessionGeneralLite
from .birefnet_portrait import BiRefNetSessionPortrait
from .birefnet_dis import BiRefNetSessionDIS
from .birefnet_hrsod import BiRefNetSessionHRSOD
from .birefnet_cod import BiRefNetSessionCOD
from .birefnet_massive import BiRefNetSessionMassive
from .dis_anime import DisSession
from .dis_custom import DisCustomSession
from .dis_general_use import DisSession as DisSessionGeneralUse
from .sam import SamSession
from .silueta import SiluetaSession
from .u2net_cloth_seg import Unet2ClothSession
from .u2net_custom import U2netCustomSession
from .u2net_human_seg import U2netHumanSegSession
from .u2net import U2netSession
from .u2netp import U2netpSession
from .bria_rmbg import BriaRmBgSession
from .ben_custom import BenCustomSession

# Registration order is preserved: it determines the order of
# `sessions_names` and `sessions_class` below.
sessions.update(
    (session_cls.name(), session_cls)
    for session_cls in (
        BiRefNetSessionGeneral,
        BiRefNetSessionGeneralLite,
        BiRefNetSessionPortrait,
        BiRefNetSessionDIS,
        BiRefNetSessionHRSOD,
        BiRefNetSessionCOD,
        BiRefNetSessionMassive,
        DisSession,
        DisCustomSession,
        DisSessionGeneralUse,
        SamSession,
        SiluetaSession,
        Unet2ClothSession,
        U2netCustomSession,
        U2netHumanSegSession,
        U2netSession,
        U2netpSession,
        BriaRmBgSession,
        BenCustomSession,
    )
)

sessions_names = list(sessions.keys())
sessions_class = list(sessions.values())
rembg/sessions/base.py CHANGED
@@ -10,31 +10,31 @@ from PIL.Image import Image as PILImage
10
  class BaseSession:
11
  """This is a base class for managing a session with a machine learning model."""
12
 
13
- def __init__(
14
- self,
15
- model_name: str,
16
- sess_opts: ort.SessionOptions,
17
- providers=None,
18
- *args,
19
- **kwargs
20
- ):
21
  """Initialize an instance of the BaseSession class."""
22
  self.model_name = model_name
23
 
24
- self.providers = []
25
-
26
- _providers = ort.get_available_providers()
27
- if providers:
28
- for provider in providers:
29
- if provider in _providers:
30
- self.providers.append(provider)
31
  else:
32
- self.providers.extend(_providers)
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  self.inner_session = ort.InferenceSession(
35
  str(self.__class__.download_models(*args, **kwargs)),
36
- providers=self.providers,
37
  sess_options=sess_opts,
 
38
  )
39
 
40
  def normalize(
@@ -49,7 +49,7 @@ class BaseSession:
49
  im = img.convert("RGB").resize(size, Image.Resampling.LANCZOS)
50
 
51
  im_ary = np.array(im)
52
- im_ary = im_ary / np.max(im_ary)
53
 
54
  tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
55
  tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
 
10
  class BaseSession:
11
  """This is a base class for managing a session with a machine learning model."""
12
 
13
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
    """Initialize an instance of the BaseSession class.

    Parameters:
        model_name (str): The name of the model.
        sess_opts (ort.SessionOptions): Options for the inference session.
        *args: Forwarded to `download_models`.
        **kwargs: Forwarded to `download_models`. An optional `providers`
            entry (a list or tuple of provider names, or a single name)
            selects the execution providers and is not forwarded.

    When no usable `providers` value is given, CUDA is preferred, then ROCm,
    with CPU as the fallback, based on what the installed runtime reports.
    """
    self.model_name = model_name

    # Always remove the key so it is never forwarded to `download_models`,
    # whatever type the caller used.
    requested = kwargs.pop("providers", None)
    if isinstance(requested, str):
        requested = [requested]

    if isinstance(requested, (list, tuple)) and requested:
        providers = list(requested)
    else:
        # No explicit selection (or an empty one): auto-detect.
        device_type = ort.get_device()
        available = ort.get_available_providers()
        if device_type == "GPU" and "CUDAExecutionProvider" in available:
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        elif device_type[0:3] == "GPU" and "ROCMExecutionProvider" in available:
            providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

    self.inner_session = ort.InferenceSession(
        str(self.__class__.download_models(*args, **kwargs)),
        sess_options=sess_opts,
        providers=providers,
    )
39
 
40
  def normalize(
 
49
  im = img.convert("RGB").resize(size, Image.Resampling.LANCZOS)
50
 
51
  im_ary = np.array(im)
52
+ im_ary = im_ary / max(np.max(im_ary), 1e-6)
53
 
54
  tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
55
  tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
rembg/sessions/ben_custom.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
class BenCustomSession(BaseSession):
    """This is a class representing a custom session for the Ben model."""

    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
        """
        Initialize a new BenCustomSession object.

        Parameters:
            model_name (str): The name of the model.
            sess_opts: The session options.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments; `model_path` is required.

        Raises:
            ValueError: If `model_path` is not provided.
        """
        # Fail early with a clear message before the base class tries to load.
        if kwargs.get("model_path") is None:
            raise ValueError("model_path is required")

        super().__init__(model_name, sess_opts, *args, **kwargs)

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predicts the mask image for the input image.

        The image is normalized and run through the inner session; channel 0
        of the first output is min-max scaled, converted to an 8-bit
        grayscale image and resized back to the size of the input.

        Parameters:
            img (PILImage): The input image.

        Returns:
            List[PILImage]: A single-element list holding the mask image.
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
        )

        pred = ort_outs[0][:, 0, :, :]

        lowest = np.min(pred)
        spread = np.max(pred) - lowest
        if spread > 0:
            pred = (pred - lowest) / spread
        else:
            # A constant prediction would divide by zero and yield NaNs;
            # produce an empty mask instead.
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Resolve the location of the model file; nothing is downloaded.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments; `model_path` is required.

        Raises:
            ValueError: If `model_path` is not provided.

        Returns:
            str: The absolute, user-expanded path to the model file.
        """
        model_path = kwargs.get("model_path")
        if model_path is None:
            raise ValueError("model_path is required")

        return os.path.abspath(os.path.expanduser(model_path))

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Get the name of the model.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the model.
        """
        return "ben_custom"
rembg/sessions/bria_rmbg.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
class BriaRmBgSession(BaseSession):
    """
    This class represents a Bria-rmbg-2.0 session, which is a subclass of BaseSession.
    """

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predicts the output masks for the input image using the inner session.

        The image is normalized and run through the inner session; channel 0
        of the first output is min-max scaled, converted to an 8-bit
        grayscale image and resized back to the size of the input.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            List[PILImage]: The list of output masks.
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)
            ),
        )

        pred = ort_outs[0][:, 0, :, :]

        lowest = np.min(pred)
        spread = np.max(pred) - lowest
        if spread > 0:
            pred = (pred - lowest) / spread
        else:
            # A constant prediction would divide by zero and yield NaNs;
            # produce an empty mask instead.
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the BRIA-RMBG 2.0 model file from a specific URL and saves it.

        The file is checked against a SHA-256 hash unless
        `checksum_disabled` reports that checksums are turned off.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        # Resolve the target directory once; it is used for both the download
        # destination and the returned path.
        home = cls.u2net_home(*args, **kwargs)

        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "sha256:5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx",
            known_hash,
            fname=fname,
            path=home,
            progressbar=True,
        )

        return os.path.join(home, fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the Bria-rmbg session.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the session.
        """
        return "bria-rmbg"
rembg/sessions/dis_custom.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
class DisCustomSession(BaseSession):
    """This is a class representing a custom session for the Dis model."""

    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
        """
        Initialize a new DisCustomSession object.

        Parameters:
            model_name (str): The name of the model.
            sess_opts: The session options.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments; `model_path` is required.

        Raises:
            ValueError: If `model_path` is not provided.
        """
        # Fail early with a clear message before the base class tries to load.
        if kwargs.get("model_path") is None:
            raise ValueError("model_path is required")

        super().__init__(model_name, sess_opts, *args, **kwargs)

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predicts the mask image for the input image.

        The image is normalized and run through the inner session; channel 0
        of the first output is min-max scaled, converted to an 8-bit
        grayscale image and resized back to the size of the input.

        Parameters:
            img (PILImage): The input image.

        Returns:
            List[PILImage]: A single-element list holding the mask image.
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
        )

        pred = ort_outs[0][:, 0, :, :]

        lowest = np.min(pred)
        spread = np.max(pred) - lowest
        if spread > 0:
            pred = (pred - lowest) / spread
        else:
            # A constant prediction would divide by zero and yield NaNs;
            # produce an empty mask instead.
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Resolve the location of the model file; nothing is downloaded.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments; `model_path` is required.

        Raises:
            ValueError: If `model_path` is not provided.

        Returns:
            str: The absolute, user-expanded path to the model file.
        """
        model_path = kwargs.get("model_path")
        if model_path is None:
            raise ValueError("model_path is required")

        return os.path.abspath(os.path.expanduser(model_path))

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Get the name of the model.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the model.
        """
        return "dis_custom"
rembg/sessions/dis_general_use.py CHANGED
@@ -24,7 +24,7 @@ class DisSession(BaseSession):
24
  """
25
  ort_outs = self.inner_session.run(
26
  None,
27
- self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
28
  )
29
 
30
  pred = ort_outs[0][:, 0, :, :]
 
24
  """
25
  ort_outs = self.inner_session.run(
26
  None,
27
+ self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
28
  )
29
 
30
  pred = ort_outs[0][:, 0, :, :]
rembg/sessions/sam.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  from copy import deepcopy
3
- from typing import Dict, List, Tuple
4
 
5
  import cv2
6
  import numpy as np
@@ -87,7 +87,6 @@ class SamSession(BaseSession):
87
  self,
88
  model_name: str,
89
  sess_opts: ort.SessionOptions,
90
- providers=None,
91
  *args,
92
  **kwargs,
93
  ):
@@ -102,24 +101,13 @@ class SamSession(BaseSession):
102
  """
103
  self.model_name = model_name
104
 
105
- valid_providers = []
106
- available_providers = ort.get_available_providers()
107
-
108
- for provider in providers or []:
109
- if provider in available_providers:
110
- valid_providers.append(provider)
111
- else:
112
- valid_providers.extend(available_providers)
113
-
114
  paths = self.__class__.download_models(*args, **kwargs)
115
  self.encoder = ort.InferenceSession(
116
  str(paths[0]),
117
- providers=valid_providers,
118
  sess_options=sess_opts,
119
  )
120
  self.decoder = ort.InferenceSession(
121
  str(paths[1]),
122
- providers=valid_providers,
123
  sess_options=sess_opts,
124
  )
125
 
@@ -142,7 +130,16 @@ class SamSession(BaseSession):
142
  Returns:
143
  List[PILImage]: A list of masks generated by the decoder.
144
  """
145
- prompt = kwargs.get("sam_prompt", "{}")
 
 
 
 
 
 
 
 
 
146
  schema = {
147
  "type": "array",
148
  "items": {
 
1
  import os
2
  from copy import deepcopy
3
+ from typing import List
4
 
5
  import cv2
6
  import numpy as np
 
87
  self,
88
  model_name: str,
89
  sess_opts: ort.SessionOptions,
 
90
  *args,
91
  **kwargs,
92
  ):
 
101
  """
102
  self.model_name = model_name
103
 
 
 
 
 
 
 
 
 
 
104
  paths = self.__class__.download_models(*args, **kwargs)
105
  self.encoder = ort.InferenceSession(
106
  str(paths[0]),
 
107
  sess_options=sess_opts,
108
  )
109
  self.decoder = ort.InferenceSession(
110
  str(paths[1]),
 
111
  sess_options=sess_opts,
112
  )
113
 
 
130
  Returns:
131
  List[PILImage]: A list of masks generated by the decoder.
132
  """
133
+ prompt = kwargs.get(
134
+ "sam_prompt",
135
+ [
136
+ {
137
+ "type": "point",
138
+ "label": 1,
139
+ "data": [int(img.width / 2), int(img.height / 2)],
140
+ }
141
+ ],
142
+ )
143
  schema = {
144
  "type": "array",
145
  "items": {
rembg/sessions/u2net.py CHANGED
@@ -41,7 +41,7 @@ class U2netSession(BaseSession):
41
  pred = (pred - mi) / (ma - mi)
42
  pred = np.squeeze(pred)
43
 
44
- mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
45
  mask = mask.resize(img.size, Image.Resampling.LANCZOS)
46
 
47
  return [mask]
 
41
  pred = (pred - mi) / (ma - mi)
42
  pred = np.squeeze(pred)
43
 
44
+ mask = Image.fromarray((pred.clip(0, 1) * 255).astype("uint8"), mode="L")
45
  mask = mask.resize(img.size, Image.Resampling.LANCZOS)
46
 
47
  return [mask]
rembg/sessions/u2net_custom.py CHANGED
@@ -13,21 +13,13 @@ from .base import BaseSession
13
  class U2netCustomSession(BaseSession):
14
  """This is a class representing a custom session for the U2net model."""
15
 
16
- def __init__(
17
- self,
18
- model_name: str,
19
- sess_opts: ort.SessionOptions,
20
- providers=None,
21
- *args,
22
- **kwargs
23
- ):
24
  """
25
  Initialize a new U2netCustomSession object.
26
 
27
  Parameters:
28
  model_name (str): The name of the model.
29
  sess_opts (ort.SessionOptions): The session options.
30
- providers: The providers.
31
  *args: Additional positional arguments.
32
  **kwargs: Additional keyword arguments.
33
 
@@ -38,7 +30,7 @@ class U2netCustomSession(BaseSession):
38
  if model_path is None:
39
  raise ValueError("model_path is required")
40
 
41
- super().__init__(model_name, sess_opts, providers, *args, **kwargs)
42
 
43
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
44
  """
@@ -86,7 +78,7 @@ class U2netCustomSession(BaseSession):
86
  """
87
  model_path = kwargs.get("model_path")
88
  if model_path is None:
89
- return
90
 
91
  return os.path.abspath(os.path.expanduser(model_path))
92
 
 
13
  class U2netCustomSession(BaseSession):
14
  """This is a class representing a custom session for the U2net model."""
15
 
16
+ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
 
 
 
 
 
 
 
17
  """
18
  Initialize a new U2netCustomSession object.
19
 
20
  Parameters:
21
  model_name (str): The name of the model.
22
  sess_opts (ort.SessionOptions): The session options.
 
23
  *args: Additional positional arguments.
24
  **kwargs: Additional keyword arguments.
25
 
 
30
  if model_path is None:
31
  raise ValueError("model_path is required")
32
 
33
+ super().__init__(model_name, sess_opts, *args, **kwargs)
34
 
35
  def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
36
  """
 
78
  """
79
  model_path = kwargs.get("model_path")
80
  if model_path is None:
81
+ raise ValueError("model_path is required")
82
 
83
  return os.path.abspath(os.path.expanduser(model_path))
84
 
requirements.txt CHANGED
@@ -1,21 +1,22 @@
1
- pydantic==2.10.6
2
- filetype==1.2.0
3
- pooch==1.6.0
4
- imagehash==4.3.1
5
- numpy==1.23.5
6
- onnxruntime==1.13.1
7
- opencv-python-headless==4.6.0.66
8
- pillow==9.3.0
9
- pymatting==1.1.8
10
- python-multipart==0.0.5
11
- scikit-image==0.19.3
12
- scipy==1.9.3
13
- tqdm==4.64.1
14
- uvicorn==0.20.0
15
- watchdog==2.1.9
16
- click==8.1.3
17
- fastapi
18
- aiohttp==3.8.3
19
- asyncer==0.0.2
20
- gradio==3.0.20
21
- jsonschema==4.16.0
 
 
1
+ pydantic==2.10.6
2
+ filetype==1.2.0
3
+ pooch==1.6.0
4
+ imagehash==4.3.1
5
+ numpy==1.23.5
6
+ onnxruntime
7
+ opencv-python-headless==4.6.0.66
8
+ pillow==9.3.0
9
+ pymatting==1.1.8
10
+ python-multipart==0.0.5
11
+ scikit-image==0.19.3
12
+ scipy==1.9.3
13
+ tqdm==4.64.1
14
+ uvicorn==0.20.0
15
+ watchdog==2.1.9
16
+ click==8.1.3
17
+ fastapi
18
+ aiohttp==3.8.3
19
+ asyncer==0.0.2
20
+ gradio
21
+ jsonschema==4.16.0
22
+ huggingface-hub==0.34.3