Add bria-rembg
Browse files- app.py +100 -109
- rembg/_version.py +3 -3
- rembg/bg.py +21 -8
- rembg/commands/b_command.py +8 -7
- rembg/commands/d_command.py +5 -4
- rembg/commands/p_command.py +1 -2
- rembg/commands/s_command.py +7 -3
- rembg/session_factory.py +13 -9
- rembg/sessions/__init__.py +33 -35
- rembg/sessions/base.py +18 -18
- rembg/sessions/ben_custom.py +92 -0
- rembg/sessions/bria_rmbg.py +88 -0
- rembg/sessions/dis_custom.py +91 -0
- rembg/sessions/dis_general_use.py +1 -1
- rembg/sessions/sam.py +11 -14
- rembg/sessions/u2net.py +1 -1
- rembg/sessions/u2net_custom.py +3 -11
- requirements.txt +22 -21
app.py
CHANGED
|
@@ -1,109 +1,100 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
import os
|
| 3 |
-
import cv2
|
| 4 |
-
from rembg import new_session, remove
|
| 5 |
-
from rembg.
|
| 6 |
-
|
| 7 |
-
def inference(file, mask, model, x, y):
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
"
|
| 64 |
-
"
|
| 65 |
-
"
|
| 66 |
-
"
|
| 67 |
-
"
|
| 68 |
-
"
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
gr.
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
["girl.jpg", "Default", "u2net", None, None],
|
| 102 |
-
["anime-girl.jpg", "Default", "isnet-anime", None, None]
|
| 103 |
-
],
|
| 104 |
-
inputs=[inputs, mask_option, model_selector, x, y],
|
| 105 |
-
outputs=outputs
|
| 106 |
-
)
|
| 107 |
-
gr.HTML(badge)
|
| 108 |
-
|
| 109 |
-
app.launch(share=True)
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import cv2
|
| 4 |
+
from rembg import new_session, remove
|
| 5 |
+
from rembg.bg import download_models
|
| 6 |
+
|
| 7 |
+
def inference(file, mask, model, x, y):
|
| 8 |
+
session = new_session(model)
|
| 9 |
+
|
| 10 |
+
output = remove(
|
| 11 |
+
file,
|
| 12 |
+
session=session,
|
| 13 |
+
**{ "sam_prompt": [{"type": "point", "data": [x, y], "label": 1}] },
|
| 14 |
+
only_mask=(mask == "Mask only")
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
return output
|
| 18 |
+
|
| 19 |
+
title = "RemBG"
|
| 20 |
+
description = "Gradio demo for **[RemBG](https://github.com/danielgatis/rembg)**. To use it, simply upload your image, select a model, click Process, and wait."
|
| 21 |
+
badge = """
|
| 22 |
+
<div style="position: fixed; left: 50%; text-align: center;">
|
| 23 |
+
<a href="https://github.com/danielgatis/rembg" target="_blank" style="text-decoration: none;">
|
| 24 |
+
<img src="https://img.shields.io/badge/RemBG-Github-blue" alt="RemBG Github" />
|
| 25 |
+
</a>
|
| 26 |
+
</div>
|
| 27 |
+
"""
|
| 28 |
+
def get_coords(evt: gr.SelectData) -> tuple:
|
| 29 |
+
return evt.index[0], evt.index[1]
|
| 30 |
+
|
| 31 |
+
def show_coords(model: str):
|
| 32 |
+
visible = model == "sam"
|
| 33 |
+
return gr.update(visible=visible), gr.update(visible=visible), gr.update(visible=visible)
|
| 34 |
+
|
| 35 |
+
download_models(tuple())
|
| 36 |
+
|
| 37 |
+
with gr.Blocks() as app:
|
| 38 |
+
gr.Markdown(f"# {title}")
|
| 39 |
+
gr.Markdown(description)
|
| 40 |
+
|
| 41 |
+
with gr.Row():
|
| 42 |
+
inputs = gr.Image(type="numpy", label="Input Image")
|
| 43 |
+
outputs = gr.Image(label="Output Image")
|
| 44 |
+
|
| 45 |
+
with gr.Row():
|
| 46 |
+
mask_option = gr.Radio(
|
| 47 |
+
["Default", "Mask only"],
|
| 48 |
+
value="Default",
|
| 49 |
+
label="Output Type"
|
| 50 |
+
)
|
| 51 |
+
model_selector = gr.Dropdown(
|
| 52 |
+
[
|
| 53 |
+
"u2net",
|
| 54 |
+
"u2netp",
|
| 55 |
+
"u2net_human_seg",
|
| 56 |
+
"u2net_cloth_seg",
|
| 57 |
+
"silueta",
|
| 58 |
+
"isnet-general-use",
|
| 59 |
+
"isnet-anime",
|
| 60 |
+
"sam",
|
| 61 |
+
"bria-rmbg",
|
| 62 |
+
"birefnet-general",
|
| 63 |
+
"birefnet-general-lite",
|
| 64 |
+
"birefnet-portrait",
|
| 65 |
+
"birefnet-dis",
|
| 66 |
+
"birefnet-hrsod",
|
| 67 |
+
"birefnet-cod",
|
| 68 |
+
"birefnet-massive",
|
| 69 |
+
],
|
| 70 |
+
value="isnet-general-use",
|
| 71 |
+
label="Model Selection"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
extra = gr.Markdown("## Click on the image to capture coordinates (for SAM model)", visible=False)
|
| 75 |
+
|
| 76 |
+
x = gr.Number(label="Mouse X Coordinate", visible=False)
|
| 77 |
+
y = gr.Number(label="Mouse Y Coordinate", visible=False)
|
| 78 |
+
|
| 79 |
+
model_selector.change(show_coords, inputs=model_selector, outputs=[x, y, extra])
|
| 80 |
+
inputs.select(get_coords, None, [x, y])
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
gr.Button("Process Image").click(
|
| 84 |
+
inference,
|
| 85 |
+
inputs=[inputs, mask_option, model_selector, x, y],
|
| 86 |
+
outputs=outputs
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
gr.Examples(
|
| 90 |
+
examples=[
|
| 91 |
+
["lion.png", "Default", "u2net", None, None],
|
| 92 |
+
["girl.jpg", "Default", "u2net", None, None],
|
| 93 |
+
["anime-girl.jpg", "Default", "isnet-anime", None, None]
|
| 94 |
+
],
|
| 95 |
+
inputs=[inputs, mask_option, model_selector, x, y],
|
| 96 |
+
outputs=outputs
|
| 97 |
+
)
|
| 98 |
+
gr.HTML(badge)
|
| 99 |
+
|
| 100 |
+
app.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rembg/_version.py
CHANGED
|
@@ -23,9 +23,9 @@ def get_keywords():
|
|
| 23 |
# setup.py/versioneer.py will grep for the variable names, so they must
|
| 24 |
# each be defined on a line of their own. _version.py will just call
|
| 25 |
# get_keywords().
|
| 26 |
-
git_refnames = " (HEAD -> main)"
|
| 27 |
-
git_full = "
|
| 28 |
-
git_date = "
|
| 29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
| 30 |
return keywords
|
| 31 |
|
|
|
|
| 23 |
# setup.py/versioneer.py will grep for the variable names, so they must
|
| 24 |
# each be defined on a line of their own. _version.py will just call
|
| 25 |
# get_keywords().
|
| 26 |
+
git_refnames = " (HEAD -> main, tag: v2.0.69)"
|
| 27 |
+
git_full = "df72e3dea3f41e543a13991cb05b8a2659ee95c1"
|
| 28 |
+
git_date = "2025-12-04 18:05:12 -0300"
|
| 29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
| 30 |
return keywords
|
| 31 |
|
rembg/bg.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import io
|
|
|
|
| 2 |
from enum import Enum
|
| 3 |
from typing import Any, List, Optional, Tuple, Union, cast
|
| 4 |
|
|
@@ -20,7 +21,7 @@ from pymatting.util.util import stack_images
|
|
| 20 |
from scipy.ndimage import binary_erosion
|
| 21 |
|
| 22 |
from .session_factory import new_session
|
| 23 |
-
from .sessions import
|
| 24 |
from .sessions.base import BaseSession
|
| 25 |
|
| 26 |
ort.set_default_logger_severity(3)
|
|
@@ -175,9 +176,8 @@ def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> P
|
|
| 175 |
Returns:
|
| 176 |
PILImage: The modified image with the background color applied.
|
| 177 |
"""
|
| 178 |
-
|
| 179 |
-
colored_image = Image.
|
| 180 |
-
colored_image.paste(img, mask=img)
|
| 181 |
|
| 182 |
return colored_image
|
| 183 |
|
|
@@ -195,12 +195,25 @@ def fix_image_orientation(img: PILImage) -> PILImage:
|
|
| 195 |
return cast(PILImage, ImageOps.exif_transpose(img))
|
| 196 |
|
| 197 |
|
| 198 |
-
def download_models() -> None:
|
| 199 |
"""
|
| 200 |
Download models for image processing.
|
| 201 |
"""
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
|
| 206 |
def remove(
|
|
@@ -215,7 +228,7 @@ def remove(
|
|
| 215 |
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
| 216 |
force_return_bytes: bool = False,
|
| 217 |
*args: Optional[Any],
|
| 218 |
-
**kwargs: Optional[Any]
|
| 219 |
) -> Union[bytes, PILImage, np.ndarray]:
|
| 220 |
"""
|
| 221 |
Remove the background from an input image.
|
|
|
|
| 1 |
import io
|
| 2 |
+
import sys
|
| 3 |
from enum import Enum
|
| 4 |
from typing import Any, List, Optional, Tuple, Union, cast
|
| 5 |
|
|
|
|
| 21 |
from scipy.ndimage import binary_erosion
|
| 22 |
|
| 23 |
from .session_factory import new_session
|
| 24 |
+
from .sessions import sessions, sessions_names
|
| 25 |
from .sessions.base import BaseSession
|
| 26 |
|
| 27 |
ort.set_default_logger_severity(3)
|
|
|
|
| 176 |
Returns:
|
| 177 |
PILImage: The modified image with the background color applied.
|
| 178 |
"""
|
| 179 |
+
background = Image.new("RGBA", img.size, tuple(color))
|
| 180 |
+
colored_image = Image.alpha_composite(background, img)
|
|
|
|
| 181 |
|
| 182 |
return colored_image
|
| 183 |
|
|
|
|
| 195 |
return cast(PILImage, ImageOps.exif_transpose(img))
|
| 196 |
|
| 197 |
|
| 198 |
+
def download_models(models: tuple[str, ...]) -> None:
|
| 199 |
"""
|
| 200 |
Download models for image processing.
|
| 201 |
"""
|
| 202 |
+
if len(models) == 0:
|
| 203 |
+
print("No models specified, downloading all models")
|
| 204 |
+
models = tuple(sessions_names)
|
| 205 |
+
|
| 206 |
+
for model in models:
|
| 207 |
+
session = sessions.get(model)
|
| 208 |
+
if session is None:
|
| 209 |
+
print(f"Error: no model found: {model}")
|
| 210 |
+
sys.exit(1)
|
| 211 |
+
else:
|
| 212 |
+
print(f"Downloading model: {model}")
|
| 213 |
+
try:
|
| 214 |
+
session.download_models()
|
| 215 |
+
except Exception as e:
|
| 216 |
+
print(f"Error downloading model: {e}")
|
| 217 |
|
| 218 |
|
| 219 |
def remove(
|
|
|
|
| 228 |
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
| 229 |
force_return_bytes: bool = False,
|
| 230 |
*args: Optional[Any],
|
| 231 |
+
**kwargs: Optional[Any],
|
| 232 |
) -> Union[bytes, PILImage, np.ndarray]:
|
| 233 |
"""
|
| 234 |
Remove the background from an input image.
|
rembg/commands/b_command.py
CHANGED
|
@@ -6,7 +6,7 @@ import sys
|
|
| 6 |
from typing import IO
|
| 7 |
|
| 8 |
import click
|
| 9 |
-
|
| 10 |
|
| 11 |
from ..bg import remove
|
| 12 |
from ..session_factory import new_session
|
|
@@ -118,10 +118,11 @@ def b_command(
|
|
| 118 |
Returns:
|
| 119 |
None
|
| 120 |
"""
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
| 125 |
|
| 126 |
session = new_session(model, **kwargs)
|
| 127 |
bytes_per_img = image_width * image_height * 3
|
|
@@ -134,7 +135,7 @@ def b_command(
|
|
| 134 |
if not os.path.isdir(output_dir):
|
| 135 |
os.makedirs(output_dir, exist_ok=True)
|
| 136 |
|
| 137 |
-
def img_to_byte_array(img:
|
| 138 |
buff = io.BytesIO()
|
| 139 |
img.save(buff, format="PNG")
|
| 140 |
return buff.getvalue()
|
|
@@ -162,7 +163,7 @@ def b_command(
|
|
| 162 |
if not img_bytes:
|
| 163 |
break
|
| 164 |
|
| 165 |
-
img =
|
| 166 |
output = remove(img, session=session, **kwargs)
|
| 167 |
|
| 168 |
if output_specifier:
|
|
|
|
| 6 |
from typing import IO
|
| 7 |
|
| 8 |
import click
|
| 9 |
+
import PIL
|
| 10 |
|
| 11 |
from ..bg import remove
|
| 12 |
from ..session_factory import new_session
|
|
|
|
| 118 |
Returns:
|
| 119 |
None
|
| 120 |
"""
|
| 121 |
+
if extras:
|
| 122 |
+
try:
|
| 123 |
+
kwargs.update(json.loads(extras))
|
| 124 |
+
except Exception:
|
| 125 |
+
raise click.BadParameter("extras must be a valid JSON string")
|
| 126 |
|
| 127 |
session = new_session(model, **kwargs)
|
| 128 |
bytes_per_img = image_width * image_height * 3
|
|
|
|
| 135 |
if not os.path.isdir(output_dir):
|
| 136 |
os.makedirs(output_dir, exist_ok=True)
|
| 137 |
|
| 138 |
+
def img_to_byte_array(img: PIL.Image.Image) -> bytes:
|
| 139 |
buff = io.BytesIO()
|
| 140 |
img.save(buff, format="PNG")
|
| 141 |
return buff.getvalue()
|
|
|
|
| 163 |
if not img_bytes:
|
| 164 |
break
|
| 165 |
|
| 166 |
+
img = PIL.Image.frombytes("RGB", (image_width, image_height), img_bytes)
|
| 167 |
output = remove(img, session=session, **kwargs)
|
| 168 |
|
| 169 |
if output_specifier:
|
rembg/commands/d_command.py
CHANGED
|
@@ -5,10 +5,11 @@ from ..bg import download_models
|
|
| 5 |
|
| 6 |
@click.command( # type: ignore
|
| 7 |
name="d",
|
| 8 |
-
help="download
|
| 9 |
)
|
| 10 |
-
|
|
|
|
| 11 |
"""
|
| 12 |
-
Download
|
| 13 |
"""
|
| 14 |
-
download_models()
|
|
|
|
| 5 |
|
| 6 |
@click.command( # type: ignore
|
| 7 |
name="d",
|
| 8 |
+
help="download models",
|
| 9 |
)
|
| 10 |
+
@click.argument("models", nargs=-1)
|
| 11 |
+
def d_command(models: tuple[str, ...]) -> None:
|
| 12 |
"""
|
| 13 |
+
Download models
|
| 14 |
"""
|
| 15 |
+
download_models(models)
|
rembg/commands/p_command.py
CHANGED
|
@@ -185,8 +185,7 @@ def p_command(
|
|
| 185 |
print(e)
|
| 186 |
|
| 187 |
inputs = list(input.glob("**/*"))
|
| 188 |
-
|
| 189 |
-
inputs_tqdm = tqdm(inputs)
|
| 190 |
|
| 191 |
for each_input in inputs_tqdm:
|
| 192 |
if not each_input.is_dir():
|
|
|
|
| 185 |
print(e)
|
| 186 |
|
| 187 |
inputs = list(input.glob("**/*"))
|
| 188 |
+
inputs_tqdm = inputs if watch else tqdm(inputs)
|
|
|
|
| 189 |
|
| 190 |
for each_input in inputs_tqdm:
|
| 191 |
if not each_input.is_dir():
|
rembg/commands/s_command.py
CHANGED
|
@@ -197,12 +197,15 @@ def s_command(port: int, host: str, log_level: str, threads: int) -> None:
|
|
| 197 |
except Exception:
|
| 198 |
pass
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
return Response(
|
| 201 |
remove(
|
| 202 |
content,
|
| 203 |
-
session=
|
| 204 |
-
commons.model, new_session(commons.model, **kwargs)
|
| 205 |
-
),
|
| 206 |
alpha_matting=commons.a,
|
| 207 |
alpha_matting_foreground_threshold=commons.af,
|
| 208 |
alpha_matting_background_threshold=commons.ab,
|
|
@@ -306,6 +309,7 @@ def s_command(port: int, host: str, log_level: str, threads: int) -> None:
|
|
| 306 |
],
|
| 307 |
gr.components.Image(type="filepath", label="Output"),
|
| 308 |
concurrency_limit=3,
|
|
|
|
| 309 |
)
|
| 310 |
|
| 311 |
app = gr.mount_gradio_app(app, interface, path="/")
|
|
|
|
| 197 |
except Exception:
|
| 198 |
pass
|
| 199 |
|
| 200 |
+
session = sessions.get(commons.model)
|
| 201 |
+
if session is None:
|
| 202 |
+
session = new_session(commons.model, **kwargs)
|
| 203 |
+
sessions[commons.model] = session
|
| 204 |
+
|
| 205 |
return Response(
|
| 206 |
remove(
|
| 207 |
content,
|
| 208 |
+
session=session,
|
|
|
|
|
|
|
| 209 |
alpha_matting=commons.a,
|
| 210 |
alpha_matting_foreground_threshold=commons.af,
|
| 211 |
alpha_matting_background_threshold=commons.ab,
|
|
|
|
| 309 |
],
|
| 310 |
gr.components.Image(type="filepath", label="Output"),
|
| 311 |
concurrency_limit=3,
|
| 312 |
+
analytics_enabled=False,
|
| 313 |
)
|
| 314 |
|
| 315 |
app = gr.mount_gradio_app(app, interface, path="/")
|
rembg/session_factory.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import os
|
| 2 |
-
from typing import Type
|
| 3 |
|
| 4 |
import onnxruntime as ort
|
| 5 |
|
|
@@ -8,9 +8,7 @@ from .sessions.base import BaseSession
|
|
| 8 |
from .sessions.u2net import U2netSession
|
| 9 |
|
| 10 |
|
| 11 |
-
def new_session(
|
| 12 |
-
model_name: str = "u2net", providers=None, *args, **kwargs
|
| 13 |
-
) -> BaseSession:
|
| 14 |
"""
|
| 15 |
Create a new session object based on the specified model name.
|
| 16 |
|
|
@@ -21,24 +19,30 @@ def new_session(
|
|
| 21 |
|
| 22 |
Parameters:
|
| 23 |
model_name (str): The name of the model.
|
| 24 |
-
providers: The providers for the session.
|
| 25 |
*args: Additional positional arguments.
|
| 26 |
**kwargs: Additional keyword arguments.
|
| 27 |
|
|
|
|
|
|
|
|
|
|
| 28 |
Returns:
|
| 29 |
BaseSession: The created session object.
|
| 30 |
"""
|
| 31 |
-
session_class: Type[BaseSession] =
|
| 32 |
|
| 33 |
for sc in sessions_class:
|
| 34 |
if sc.name() == model_name:
|
| 35 |
session_class = sc
|
| 36 |
break
|
| 37 |
|
|
|
|
|
|
|
|
|
|
| 38 |
sess_opts = ort.SessionOptions()
|
| 39 |
|
| 40 |
if "OMP_NUM_THREADS" in os.environ:
|
| 41 |
-
|
| 42 |
-
sess_opts.
|
|
|
|
| 43 |
|
| 44 |
-
return session_class(model_name, sess_opts,
|
|
|
|
| 1 |
import os
|
| 2 |
+
from typing import Optional, Type
|
| 3 |
|
| 4 |
import onnxruntime as ort
|
| 5 |
|
|
|
|
| 8 |
from .sessions.u2net import U2netSession
|
| 9 |
|
| 10 |
|
| 11 |
+
def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
Create a new session object based on the specified model name.
|
| 14 |
|
|
|
|
| 19 |
|
| 20 |
Parameters:
|
| 21 |
model_name (str): The name of the model.
|
|
|
|
| 22 |
*args: Additional positional arguments.
|
| 23 |
**kwargs: Additional keyword arguments.
|
| 24 |
|
| 25 |
+
Raises:
|
| 26 |
+
ValueError: If no session class with the given `model_name` is found.
|
| 27 |
+
|
| 28 |
Returns:
|
| 29 |
BaseSession: The created session object.
|
| 30 |
"""
|
| 31 |
+
session_class: Optional[Type[BaseSession]] = None
|
| 32 |
|
| 33 |
for sc in sessions_class:
|
| 34 |
if sc.name() == model_name:
|
| 35 |
session_class = sc
|
| 36 |
break
|
| 37 |
|
| 38 |
+
if session_class is None:
|
| 39 |
+
raise ValueError(f"No session class found for model '{model_name}'")
|
| 40 |
+
|
| 41 |
sess_opts = ort.SessionOptions()
|
| 42 |
|
| 43 |
if "OMP_NUM_THREADS" in os.environ:
|
| 44 |
+
threads = int(os.environ["OMP_NUM_THREADS"])
|
| 45 |
+
sess_opts.inter_op_num_threads = threads
|
| 46 |
+
sess_opts.intra_op_num_threads = threads
|
| 47 |
|
| 48 |
+
return session_class(model_name, sess_opts, *args, **kwargs)
|
rembg/sessions/__init__.py
CHANGED
|
@@ -1,88 +1,86 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
from typing import List
|
| 4 |
|
| 5 |
from .base import BaseSession
|
| 6 |
|
| 7 |
-
|
| 8 |
-
sessions_names: List[str] = []
|
| 9 |
|
| 10 |
from .birefnet_general import BiRefNetSessionGeneral
|
| 11 |
|
| 12 |
-
|
| 13 |
-
sessions_names.append(BiRefNetSessionGeneral.name())
|
| 14 |
|
| 15 |
from .birefnet_general_lite import BiRefNetSessionGeneralLite
|
| 16 |
|
| 17 |
-
|
| 18 |
-
sessions_names.append(BiRefNetSessionGeneralLite.name())
|
| 19 |
|
| 20 |
from .birefnet_portrait import BiRefNetSessionPortrait
|
| 21 |
|
| 22 |
-
|
| 23 |
-
sessions_names.append(BiRefNetSessionPortrait.name())
|
| 24 |
|
| 25 |
from .birefnet_dis import BiRefNetSessionDIS
|
| 26 |
|
| 27 |
-
|
| 28 |
-
sessions_names.append(BiRefNetSessionDIS.name())
|
| 29 |
|
| 30 |
from .birefnet_hrsod import BiRefNetSessionHRSOD
|
| 31 |
|
| 32 |
-
|
| 33 |
-
sessions_names.append(BiRefNetSessionHRSOD.name())
|
| 34 |
|
| 35 |
from .birefnet_cod import BiRefNetSessionCOD
|
| 36 |
|
| 37 |
-
|
| 38 |
-
sessions_names.append(BiRefNetSessionCOD.name())
|
| 39 |
|
| 40 |
from .birefnet_massive import BiRefNetSessionMassive
|
| 41 |
|
| 42 |
-
|
| 43 |
-
sessions_names.append(BiRefNetSessionMassive.name())
|
| 44 |
|
| 45 |
from .dis_anime import DisSession
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
from .dis_general_use import DisSession as DisSessionGeneralUse
|
| 51 |
|
| 52 |
-
|
| 53 |
-
sessions_names.append(DisSessionGeneralUse.name())
|
| 54 |
|
| 55 |
from .sam import SamSession
|
| 56 |
|
| 57 |
-
|
| 58 |
-
sessions_names.append(SamSession.name())
|
| 59 |
|
| 60 |
from .silueta import SiluetaSession
|
| 61 |
|
| 62 |
-
|
| 63 |
-
sessions_names.append(SiluetaSession.name())
|
| 64 |
|
| 65 |
from .u2net_cloth_seg import Unet2ClothSession
|
| 66 |
|
| 67 |
-
|
| 68 |
-
sessions_names.append(Unet2ClothSession.name())
|
| 69 |
|
| 70 |
from .u2net_custom import U2netCustomSession
|
| 71 |
|
| 72 |
-
|
| 73 |
-
sessions_names.append(U2netCustomSession.name())
|
| 74 |
|
| 75 |
from .u2net_human_seg import U2netHumanSegSession
|
| 76 |
|
| 77 |
-
|
| 78 |
-
sessions_names.append(U2netHumanSegSession.name())
|
| 79 |
|
| 80 |
from .u2net import U2netSession
|
| 81 |
|
| 82 |
-
|
| 83 |
-
sessions_names.append(U2netSession.name())
|
| 84 |
|
| 85 |
from .u2netp import U2netpSession
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
from typing import Dict, List
|
| 4 |
|
| 5 |
from .base import BaseSession
|
| 6 |
|
| 7 |
+
sessions: Dict[str, type[BaseSession]] = {}
|
|
|
|
| 8 |
|
| 9 |
from .birefnet_general import BiRefNetSessionGeneral
|
| 10 |
|
| 11 |
+
sessions[BiRefNetSessionGeneral.name()] = BiRefNetSessionGeneral
|
|
|
|
| 12 |
|
| 13 |
from .birefnet_general_lite import BiRefNetSessionGeneralLite
|
| 14 |
|
| 15 |
+
sessions[BiRefNetSessionGeneralLite.name()] = BiRefNetSessionGeneralLite
|
|
|
|
| 16 |
|
| 17 |
from .birefnet_portrait import BiRefNetSessionPortrait
|
| 18 |
|
| 19 |
+
sessions[BiRefNetSessionPortrait.name()] = BiRefNetSessionPortrait
|
|
|
|
| 20 |
|
| 21 |
from .birefnet_dis import BiRefNetSessionDIS
|
| 22 |
|
| 23 |
+
sessions[BiRefNetSessionDIS.name()] = BiRefNetSessionDIS
|
|
|
|
| 24 |
|
| 25 |
from .birefnet_hrsod import BiRefNetSessionHRSOD
|
| 26 |
|
| 27 |
+
sessions[BiRefNetSessionHRSOD.name()] = BiRefNetSessionHRSOD
|
|
|
|
| 28 |
|
| 29 |
from .birefnet_cod import BiRefNetSessionCOD
|
| 30 |
|
| 31 |
+
sessions[BiRefNetSessionCOD.name()] = BiRefNetSessionCOD
|
|
|
|
| 32 |
|
| 33 |
from .birefnet_massive import BiRefNetSessionMassive
|
| 34 |
|
| 35 |
+
sessions[BiRefNetSessionMassive.name()] = BiRefNetSessionMassive
|
|
|
|
| 36 |
|
| 37 |
from .dis_anime import DisSession
|
| 38 |
|
| 39 |
+
sessions[DisSession.name()] = DisSession
|
| 40 |
+
|
| 41 |
+
from .dis_custom import DisCustomSession
|
| 42 |
+
|
| 43 |
+
sessions[DisCustomSession.name()] = DisCustomSession
|
| 44 |
|
| 45 |
from .dis_general_use import DisSession as DisSessionGeneralUse
|
| 46 |
|
| 47 |
+
sessions[DisSessionGeneralUse.name()] = DisSessionGeneralUse
|
|
|
|
| 48 |
|
| 49 |
from .sam import SamSession
|
| 50 |
|
| 51 |
+
sessions[SamSession.name()] = SamSession
|
|
|
|
| 52 |
|
| 53 |
from .silueta import SiluetaSession
|
| 54 |
|
| 55 |
+
sessions[SiluetaSession.name()] = SiluetaSession
|
|
|
|
| 56 |
|
| 57 |
from .u2net_cloth_seg import Unet2ClothSession
|
| 58 |
|
| 59 |
+
sessions[Unet2ClothSession.name()] = Unet2ClothSession
|
|
|
|
| 60 |
|
| 61 |
from .u2net_custom import U2netCustomSession
|
| 62 |
|
| 63 |
+
sessions[U2netCustomSession.name()] = U2netCustomSession
|
|
|
|
| 64 |
|
| 65 |
from .u2net_human_seg import U2netHumanSegSession
|
| 66 |
|
| 67 |
+
sessions[U2netHumanSegSession.name()] = U2netHumanSegSession
|
|
|
|
| 68 |
|
| 69 |
from .u2net import U2netSession
|
| 70 |
|
| 71 |
+
sessions[U2netSession.name()] = U2netSession
|
|
|
|
| 72 |
|
| 73 |
from .u2netp import U2netpSession
|
| 74 |
|
| 75 |
+
sessions[U2netpSession.name()] = U2netpSession
|
| 76 |
+
|
| 77 |
+
from .bria_rmbg import BriaRmBgSession
|
| 78 |
+
|
| 79 |
+
sessions[BriaRmBgSession.name()] = BriaRmBgSession
|
| 80 |
+
|
| 81 |
+
from .ben_custom import BenCustomSession
|
| 82 |
+
|
| 83 |
+
sessions[BenCustomSession.name()] = BenCustomSession
|
| 84 |
+
|
| 85 |
+
sessions_names = list(sessions.keys())
|
| 86 |
+
sessions_class = list(sessions.values())
|
rembg/sessions/base.py
CHANGED
|
@@ -10,31 +10,31 @@ from PIL.Image import Image as PILImage
|
|
| 10 |
class BaseSession:
|
| 11 |
"""This is a base class for managing a session with a machine learning model."""
|
| 12 |
|
| 13 |
-
def __init__(
|
| 14 |
-
self,
|
| 15 |
-
model_name: str,
|
| 16 |
-
sess_opts: ort.SessionOptions,
|
| 17 |
-
providers=None,
|
| 18 |
-
*args,
|
| 19 |
-
**kwargs
|
| 20 |
-
):
|
| 21 |
"""Initialize an instance of the BaseSession class."""
|
| 22 |
self.model_name = model_name
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
_providers = ort.get_available_providers()
|
| 27 |
-
if providers:
|
| 28 |
-
for provider in providers:
|
| 29 |
-
if provider in _providers:
|
| 30 |
-
self.providers.append(provider)
|
| 31 |
else:
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
self.inner_session = ort.InferenceSession(
|
| 35 |
str(self.__class__.download_models(*args, **kwargs)),
|
| 36 |
-
providers=self.providers,
|
| 37 |
sess_options=sess_opts,
|
|
|
|
| 38 |
)
|
| 39 |
|
| 40 |
def normalize(
|
|
@@ -49,7 +49,7 @@ class BaseSession:
|
|
| 49 |
im = img.convert("RGB").resize(size, Image.Resampling.LANCZOS)
|
| 50 |
|
| 51 |
im_ary = np.array(im)
|
| 52 |
-
im_ary = im_ary / np.max(im_ary)
|
| 53 |
|
| 54 |
tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
|
| 55 |
tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
|
|
|
|
| 10 |
class BaseSession:
|
| 11 |
"""This is a base class for managing a session with a machine learning model."""
|
| 12 |
|
| 13 |
+
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
"""Initialize an instance of the BaseSession class."""
|
| 15 |
self.model_name = model_name
|
| 16 |
|
| 17 |
+
if "providers" in kwargs and isinstance(kwargs["providers"], list):
|
| 18 |
+
providers = kwargs.pop("providers")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
else:
|
| 20 |
+
device_type = ort.get_device()
|
| 21 |
+
if (
|
| 22 |
+
device_type == "GPU"
|
| 23 |
+
and "CUDAExecutionProvider" in ort.get_available_providers()
|
| 24 |
+
):
|
| 25 |
+
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 26 |
+
elif (
|
| 27 |
+
device_type[0:3] == "GPU"
|
| 28 |
+
and "ROCMExecutionProvider" in ort.get_available_providers()
|
| 29 |
+
):
|
| 30 |
+
providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
|
| 31 |
+
else:
|
| 32 |
+
providers = ["CPUExecutionProvider"]
|
| 33 |
|
| 34 |
self.inner_session = ort.InferenceSession(
|
| 35 |
str(self.__class__.download_models(*args, **kwargs)),
|
|
|
|
| 36 |
sess_options=sess_opts,
|
| 37 |
+
providers=providers,
|
| 38 |
)
|
| 39 |
|
| 40 |
def normalize(
|
|
|
|
| 49 |
im = img.convert("RGB").resize(size, Image.Resampling.LANCZOS)
|
| 50 |
|
| 51 |
im_ary = np.array(im)
|
| 52 |
+
im_ary = im_ary / max(np.max(im_ary), 1e-6)
|
| 53 |
|
| 54 |
tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
|
| 55 |
tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
|
rembg/sessions/ben_custom.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import onnxruntime as ort
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from PIL.Image import Image as PILImage
|
| 8 |
+
|
| 9 |
+
from .base import BaseSession
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BenCustomSession(BaseSession):
|
| 13 |
+
"""This is a class representing a custom session for the Ben model."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
| 16 |
+
"""
|
| 17 |
+
Initialize a new BenCustomSession object.
|
| 18 |
+
|
| 19 |
+
Parameters:
|
| 20 |
+
model_name (str): The name of the model.
|
| 21 |
+
sess_opts: The session options.
|
| 22 |
+
*args: Additional positional arguments.
|
| 23 |
+
**kwargs: Additional keyword arguments.
|
| 24 |
+
"""
|
| 25 |
+
model_path = kwargs.get("model_path")
|
| 26 |
+
if model_path is None:
|
| 27 |
+
raise ValueError("model_path is required")
|
| 28 |
+
|
| 29 |
+
super().__init__(model_name, sess_opts, *args, **kwargs)
|
| 30 |
+
|
| 31 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 32 |
+
"""
|
| 33 |
+
Predicts the mask image for the input image.
|
| 34 |
+
|
| 35 |
+
This method takes a PILImage object as input and returns a list of PILImage objects as output. It performs several image processing operations to generate the mask image.
|
| 36 |
+
|
| 37 |
+
Parameters:
|
| 38 |
+
img (PILImage): The input image.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
List[PILImage]: A list of PILImage objects representing the generated mask image.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
ort_outs = self.inner_session.run(
|
| 45 |
+
None,
|
| 46 |
+
self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
pred = ort_outs[0][:, 0, :, :]
|
| 50 |
+
|
| 51 |
+
ma = np.max(pred)
|
| 52 |
+
mi = np.min(pred)
|
| 53 |
+
|
| 54 |
+
pred = (pred - mi) / (ma - mi)
|
| 55 |
+
pred = np.squeeze(pred)
|
| 56 |
+
|
| 57 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
| 58 |
+
mask = mask.resize(img.size, Image.Resampling.LANCZOS)
|
| 59 |
+
|
| 60 |
+
return [mask]
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def download_models(cls, *args, **kwargs):
|
| 64 |
+
"""
|
| 65 |
+
Download the model files.
|
| 66 |
+
|
| 67 |
+
Parameters:
|
| 68 |
+
*args: Additional positional arguments.
|
| 69 |
+
**kwargs: Additional keyword arguments.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
str: The absolute path to the model files.
|
| 73 |
+
"""
|
| 74 |
+
model_path = kwargs.get("model_path")
|
| 75 |
+
if model_path is None:
|
| 76 |
+
raise ValueError("model_path is required")
|
| 77 |
+
|
| 78 |
+
return os.path.abspath(os.path.expanduser(model_path))
|
| 79 |
+
|
| 80 |
+
@classmethod
|
| 81 |
+
def name(cls, *args, **kwargs):
|
| 82 |
+
"""
|
| 83 |
+
Get the name of the model.
|
| 84 |
+
|
| 85 |
+
Parameters:
|
| 86 |
+
*args: Additional positional arguments.
|
| 87 |
+
**kwargs: Additional keyword arguments.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
str: The name of the model.
|
| 91 |
+
"""
|
| 92 |
+
return "ben_custom"
|
rembg/sessions/bria_rmbg.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from typing import List

import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage

from .base import BaseSession


class BriaRmBgSession(BaseSession):
    """
    Session for the BRIA RMBG-2.0 background-removal model.

    Subclass of BaseSession; relies on the inherited `inner_session`
    (an onnxruntime InferenceSession) and `normalize` helper.
    """

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predict the foreground mask for the input image.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            List[PILImage]: A single-element list containing the grayscale
            ("L") mask, resized back to the input image's size.
        """
        # ImageNet mean/std; the model expects a 1024x1024 input.
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)
            ),
        )

        pred = ort_outs[0][:, 0, :, :]

        ma = np.max(pred)
        mi = np.min(pred)

        # Min-max normalize to [0, 1]. Guard the degenerate case of a
        # constant prediction map (ma == mi), which would otherwise divide
        # by zero and yield a NaN mask.
        if ma > mi:
            pred = (pred - mi) / (ma - mi)
        else:
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        # Clip before scaling to uint8 to avoid wrap-around from any
        # float rounding outside [0, 1] (mirrors the u2net session).
        mask = Image.fromarray((pred.clip(0, 1) * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Download the BRIA-RMBG 2.0 ONNX model file (with checksum
        verification unless disabled) into the u2net home directory.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "sha256:5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the registry name of this session.

        Parameters:
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            str: The name of the session ("bria-rmbg").
        """
        return "bria-rmbg"
rembg/sessions/dis_custom.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from typing import List

import numpy as np
import onnxruntime as ort
from PIL import Image
from PIL.Image import Image as PILImage

from .base import BaseSession


class DisCustomSession(BaseSession):
    """Custom session for a user-supplied DIS (IS-Net) ONNX model."""

    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
        """
        Initialize a new DisCustomSession object.

        Parameters:
            model_name (str): The name of the model.
            sess_opts (ort.SessionOptions): The session options.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments. Must include
                ``model_path`` pointing at the custom ONNX file.

        Raises:
            ValueError: If ``model_path`` is not provided.
        """
        model_path = kwargs.get("model_path")
        if model_path is None:
            raise ValueError("model_path is required")

        super().__init__(model_name, sess_opts, *args, **kwargs)

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predict the mask image for the input image.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            List[PILImage]: A single-element list containing the grayscale
            ("L") mask, resized back to the input image's size.
        """
        # DIS normalization: mean 0.5, std 1.0, 1024x1024 input.
        ort_outs = self.inner_session.run(
            None,
            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
        )

        pred = ort_outs[0][:, 0, :, :]

        ma = np.max(pred)
        mi = np.min(pred)

        # Min-max normalize to [0, 1]. Guard the degenerate case of a
        # constant prediction map (ma == mi), which would otherwise divide
        # by zero and yield a NaN mask.
        if ma > mi:
            pred = (pred - mi) / (ma - mi)
        else:
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Resolve the model file path (nothing is downloaded for a custom
        model; the caller supplies the file).

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments. Must include
                ``model_path``.

        Returns:
            str: The absolute, user-expanded path to the model file.

        Raises:
            ValueError: If ``model_path`` is not provided.
        """
        model_path = kwargs.get("model_path")
        if model_path is None:
            raise ValueError("model_path is required")

        return os.path.abspath(os.path.expanduser(model_path))

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the registry name of this session.

        Parameters:
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            str: The name of the model ("dis_custom").
        """
        return "dis_custom"
rembg/sessions/dis_general_use.py
CHANGED
|
@@ -24,7 +24,7 @@ class DisSession(BaseSession):
|
|
| 24 |
"""
|
| 25 |
ort_outs = self.inner_session.run(
|
| 26 |
None,
|
| 27 |
-
self.normalize(img, (0.
|
| 28 |
)
|
| 29 |
|
| 30 |
pred = ort_outs[0][:, 0, :, :]
|
|
|
|
| 24 |
"""
|
| 25 |
ort_outs = self.inner_session.run(
|
| 26 |
None,
|
| 27 |
+
self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
|
| 28 |
)
|
| 29 |
|
| 30 |
pred = ort_outs[0][:, 0, :, :]
|
rembg/sessions/sam.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
from copy import deepcopy
|
| 3 |
-
from typing import
|
| 4 |
|
| 5 |
import cv2
|
| 6 |
import numpy as np
|
|
@@ -87,7 +87,6 @@ class SamSession(BaseSession):
|
|
| 87 |
self,
|
| 88 |
model_name: str,
|
| 89 |
sess_opts: ort.SessionOptions,
|
| 90 |
-
providers=None,
|
| 91 |
*args,
|
| 92 |
**kwargs,
|
| 93 |
):
|
|
@@ -102,24 +101,13 @@ class SamSession(BaseSession):
|
|
| 102 |
"""
|
| 103 |
self.model_name = model_name
|
| 104 |
|
| 105 |
-
valid_providers = []
|
| 106 |
-
available_providers = ort.get_available_providers()
|
| 107 |
-
|
| 108 |
-
for provider in providers or []:
|
| 109 |
-
if provider in available_providers:
|
| 110 |
-
valid_providers.append(provider)
|
| 111 |
-
else:
|
| 112 |
-
valid_providers.extend(available_providers)
|
| 113 |
-
|
| 114 |
paths = self.__class__.download_models(*args, **kwargs)
|
| 115 |
self.encoder = ort.InferenceSession(
|
| 116 |
str(paths[0]),
|
| 117 |
-
providers=valid_providers,
|
| 118 |
sess_options=sess_opts,
|
| 119 |
)
|
| 120 |
self.decoder = ort.InferenceSession(
|
| 121 |
str(paths[1]),
|
| 122 |
-
providers=valid_providers,
|
| 123 |
sess_options=sess_opts,
|
| 124 |
)
|
| 125 |
|
|
@@ -142,7 +130,16 @@ class SamSession(BaseSession):
|
|
| 142 |
Returns:
|
| 143 |
List[PILImage]: A list of masks generated by the decoder.
|
| 144 |
"""
|
| 145 |
-
prompt = kwargs.get(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
schema = {
|
| 147 |
"type": "array",
|
| 148 |
"items": {
|
|
|
|
| 1 |
import os
|
| 2 |
from copy import deepcopy
|
| 3 |
+
from typing import List
|
| 4 |
|
| 5 |
import cv2
|
| 6 |
import numpy as np
|
|
|
|
| 87 |
self,
|
| 88 |
model_name: str,
|
| 89 |
sess_opts: ort.SessionOptions,
|
|
|
|
| 90 |
*args,
|
| 91 |
**kwargs,
|
| 92 |
):
|
|
|
|
| 101 |
"""
|
| 102 |
self.model_name = model_name
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
paths = self.__class__.download_models(*args, **kwargs)
|
| 105 |
self.encoder = ort.InferenceSession(
|
| 106 |
str(paths[0]),
|
|
|
|
| 107 |
sess_options=sess_opts,
|
| 108 |
)
|
| 109 |
self.decoder = ort.InferenceSession(
|
| 110 |
str(paths[1]),
|
|
|
|
| 111 |
sess_options=sess_opts,
|
| 112 |
)
|
| 113 |
|
|
|
|
| 130 |
Returns:
|
| 131 |
List[PILImage]: A list of masks generated by the decoder.
|
| 132 |
"""
|
| 133 |
+
prompt = kwargs.get(
|
| 134 |
+
"sam_prompt",
|
| 135 |
+
[
|
| 136 |
+
{
|
| 137 |
+
"type": "point",
|
| 138 |
+
"label": 1,
|
| 139 |
+
"data": [int(img.width / 2), int(img.height / 2)],
|
| 140 |
+
}
|
| 141 |
+
],
|
| 142 |
+
)
|
| 143 |
schema = {
|
| 144 |
"type": "array",
|
| 145 |
"items": {
|
rembg/sessions/u2net.py
CHANGED
|
@@ -41,7 +41,7 @@ class U2netSession(BaseSession):
|
|
| 41 |
pred = (pred - mi) / (ma - mi)
|
| 42 |
pred = np.squeeze(pred)
|
| 43 |
|
| 44 |
-
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
| 45 |
mask = mask.resize(img.size, Image.Resampling.LANCZOS)
|
| 46 |
|
| 47 |
return [mask]
|
|
|
|
| 41 |
pred = (pred - mi) / (ma - mi)
|
| 42 |
pred = np.squeeze(pred)
|
| 43 |
|
| 44 |
+
mask = Image.fromarray((pred.clip(0, 1) * 255).astype("uint8"), mode="L")
|
| 45 |
mask = mask.resize(img.size, Image.Resampling.LANCZOS)
|
| 46 |
|
| 47 |
return [mask]
|
rembg/sessions/u2net_custom.py
CHANGED
|
@@ -13,21 +13,13 @@ from .base import BaseSession
|
|
| 13 |
class U2netCustomSession(BaseSession):
|
| 14 |
"""This is a class representing a custom session for the U2net model."""
|
| 15 |
|
| 16 |
-
def __init__(
|
| 17 |
-
self,
|
| 18 |
-
model_name: str,
|
| 19 |
-
sess_opts: ort.SessionOptions,
|
| 20 |
-
providers=None,
|
| 21 |
-
*args,
|
| 22 |
-
**kwargs
|
| 23 |
-
):
|
| 24 |
"""
|
| 25 |
Initialize a new U2netCustomSession object.
|
| 26 |
|
| 27 |
Parameters:
|
| 28 |
model_name (str): The name of the model.
|
| 29 |
sess_opts (ort.SessionOptions): The session options.
|
| 30 |
-
providers: The providers.
|
| 31 |
*args: Additional positional arguments.
|
| 32 |
**kwargs: Additional keyword arguments.
|
| 33 |
|
|
@@ -38,7 +30,7 @@ class U2netCustomSession(BaseSession):
|
|
| 38 |
if model_path is None:
|
| 39 |
raise ValueError("model_path is required")
|
| 40 |
|
| 41 |
-
super().__init__(model_name, sess_opts,
|
| 42 |
|
| 43 |
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 44 |
"""
|
|
@@ -86,7 +78,7 @@ class U2netCustomSession(BaseSession):
|
|
| 86 |
"""
|
| 87 |
model_path = kwargs.get("model_path")
|
| 88 |
if model_path is None:
|
| 89 |
-
|
| 90 |
|
| 91 |
return os.path.abspath(os.path.expanduser(model_path))
|
| 92 |
|
|
|
|
| 13 |
class U2netCustomSession(BaseSession):
|
| 14 |
"""This is a class representing a custom session for the U2net model."""
|
| 15 |
|
| 16 |
+
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
"""
|
| 18 |
Initialize a new U2netCustomSession object.
|
| 19 |
|
| 20 |
Parameters:
|
| 21 |
model_name (str): The name of the model.
|
| 22 |
sess_opts (ort.SessionOptions): The session options.
|
|
|
|
| 23 |
*args: Additional positional arguments.
|
| 24 |
**kwargs: Additional keyword arguments.
|
| 25 |
|
|
|
|
| 30 |
if model_path is None:
|
| 31 |
raise ValueError("model_path is required")
|
| 32 |
|
| 33 |
+
super().__init__(model_name, sess_opts, *args, **kwargs)
|
| 34 |
|
| 35 |
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 36 |
"""
|
|
|
|
| 78 |
"""
|
| 79 |
model_path = kwargs.get("model_path")
|
| 80 |
if model_path is None:
|
| 81 |
+
raise ValueError("model_path is required")
|
| 82 |
|
| 83 |
return os.path.abspath(os.path.expanduser(model_path))
|
| 84 |
|
requirements.txt
CHANGED
|
@@ -1,21 +1,22 @@
|
|
| 1 |
-
pydantic==2.10.6
|
| 2 |
-
filetype==1.2.0
|
| 3 |
-
pooch==1.6.0
|
| 4 |
-
imagehash==4.3.1
|
| 5 |
-
numpy==1.23.5
|
| 6 |
-
onnxruntime
|
| 7 |
-
opencv-python-headless==4.6.0.66
|
| 8 |
-
pillow==9.3.0
|
| 9 |
-
pymatting==1.1.8
|
| 10 |
-
python-multipart==0.0.5
|
| 11 |
-
scikit-image==0.19.3
|
| 12 |
-
scipy==1.9.3
|
| 13 |
-
tqdm==4.64.1
|
| 14 |
-
uvicorn==0.20.0
|
| 15 |
-
watchdog==2.1.9
|
| 16 |
-
click==8.1.3
|
| 17 |
-
fastapi
|
| 18 |
-
aiohttp==3.8.3
|
| 19 |
-
asyncer==0.0.2
|
| 20 |
-
gradio
|
| 21 |
-
jsonschema==4.16.0
|
|
|
|
|
|
| 1 |
+
pydantic==2.10.6
|
| 2 |
+
filetype==1.2.0
|
| 3 |
+
pooch==1.6.0
|
| 4 |
+
imagehash==4.3.1
|
| 5 |
+
numpy==1.23.5
|
| 6 |
+
onnxruntime
|
| 7 |
+
opencv-python-headless==4.6.0.66
|
| 8 |
+
pillow==9.3.0
|
| 9 |
+
pymatting==1.1.8
|
| 10 |
+
python-multipart==0.0.5
|
| 11 |
+
scikit-image==0.19.3
|
| 12 |
+
scipy==1.9.3
|
| 13 |
+
tqdm==4.64.1
|
| 14 |
+
uvicorn==0.20.0
|
| 15 |
+
watchdog==2.1.9
|
| 16 |
+
click==8.1.3
|
| 17 |
+
fastapi
|
| 18 |
+
aiohttp==3.8.3
|
| 19 |
+
asyncer==0.0.2
|
| 20 |
+
gradio
|
| 21 |
+
jsonschema==4.16.0
|
| 22 |
+
huggingface-hub==0.34.3
|