nik1-kaj commited on
Commit
6d637a4
·
1 Parent(s): be6c3ba

Upload 21 files

Browse files
rembg/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from . import _version
2
+
3
+ __version__ = _version.get_versions()["version"]
4
+
5
+ from .bg import remove
6
+ from .session_factory import new_session
rembg/_version.py ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file helps to compute a version number in source trees obtained from
2
+ # git-archive tarball (such as those provided by githubs download-from-tag
3
+ # feature). Distribution tarballs (built by setup.py sdist) and build
4
+ # directories (produced by setup.py build) will contain a much shorter file
5
+ # that just contains the computed version number.
6
+
7
+ # This file is released into the public domain. Generated by
8
+ # versioneer-0.21 (https://github.com/python-versioneer/python-versioneer)
9
+
10
+ """Git implementation of _version.py."""
11
+
12
+ import errno
13
+ import os
14
+ import re
15
+ import subprocess
16
+ import sys
17
+ from typing import Callable, Dict
18
+
19
+
20
+ def get_keywords():
21
+ """Get the keywords needed to look up the version information."""
22
+ # these strings will be replaced by git during git-archive.
23
+ # setup.py/versioneer.py will grep for the variable names, so they must
24
+ # each be defined on a line of their own. _version.py will just call
25
+ # get_keywords().
26
+ git_refnames = " (HEAD -> main)"
27
+ git_full = "e47b2a0ed405a5a30f42bacb142b107f7a4b6536"
28
+ git_date = "2023-04-26 20:40:21 -0300"
29
+ keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30
+ return keywords
31
+
32
+
33
+ class VersioneerConfig:
34
+ """Container for Versioneer configuration parameters."""
35
+
36
+
37
+ def get_config():
38
+ """Create, populate and return the VersioneerConfig() object."""
39
+ # these strings are filled in when 'setup.py versioneer' creates
40
+ # _version.py
41
+ cfg = VersioneerConfig()
42
+ cfg.VCS = "git"
43
+ cfg.style = "pep440"
44
+ cfg.tag_prefix = "v"
45
+ cfg.parentdir_prefix = "rembg-"
46
+ cfg.versionfile_source = "rembg/_version.py"
47
+ cfg.verbose = False
48
+ return cfg
49
+
50
+
51
+ class NotThisMethod(Exception):
52
+ """Exception raised if a method is not valid for the current scenario."""
53
+
54
+
55
+ LONG_VERSION_PY: Dict[str, str] = {}
56
+ HANDLERS: Dict[str, Dict[str, Callable]] = {}
57
+
58
+
59
+ def register_vcs_handler(vcs, method): # decorator
60
+ """Create decorator to mark a method as the handler of a VCS."""
61
+
62
+ def decorate(f):
63
+ """Store f in HANDLERS[vcs][method]."""
64
+ if vcs not in HANDLERS:
65
+ HANDLERS[vcs] = {}
66
+ HANDLERS[vcs][method] = f
67
+ return f
68
+
69
+ return decorate
70
+
71
+
72
+ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
73
+ """Call the given command(s)."""
74
+ assert isinstance(commands, list)
75
+ process = None
76
+ for command in commands:
77
+ try:
78
+ dispcmd = str([command] + args)
79
+ # remember shell=False, so use git.cmd on windows, not just git
80
+ process = subprocess.Popen(
81
+ [command] + args,
82
+ cwd=cwd,
83
+ env=env,
84
+ stdout=subprocess.PIPE,
85
+ stderr=(subprocess.PIPE if hide_stderr else None),
86
+ )
87
+ break
88
+ except OSError:
89
+ e = sys.exc_info()[1]
90
+ if e.errno == errno.ENOENT:
91
+ continue
92
+ if verbose:
93
+ print("unable to run %s" % dispcmd)
94
+ print(e)
95
+ return None, None
96
+ else:
97
+ if verbose:
98
+ print("unable to find command, tried %s" % (commands,))
99
+ return None, None
100
+ stdout = process.communicate()[0].strip().decode()
101
+ if process.returncode != 0:
102
+ if verbose:
103
+ print("unable to run %s (error)" % dispcmd)
104
+ print("stdout was %s" % stdout)
105
+ return None, process.returncode
106
+ return stdout, process.returncode
107
+
108
+
109
+ def versions_from_parentdir(parentdir_prefix, root, verbose):
110
+ """Try to determine the version from the parent directory name.
111
+
112
+ Source tarballs conventionally unpack into a directory that includes both
113
+ the project name and a version string. We will also support searching up
114
+ two directory levels for an appropriately named parent directory
115
+ """
116
+ rootdirs = []
117
+
118
+ for _ in range(3):
119
+ dirname = os.path.basename(root)
120
+ if dirname.startswith(parentdir_prefix):
121
+ return {
122
+ "version": dirname[len(parentdir_prefix) :],
123
+ "full-revisionid": None,
124
+ "dirty": False,
125
+ "error": None,
126
+ "date": None,
127
+ }
128
+ rootdirs.append(root)
129
+ root = os.path.dirname(root) # up a level
130
+
131
+ if verbose:
132
+ print(
133
+ "Tried directories %s but none started with prefix %s"
134
+ % (str(rootdirs), parentdir_prefix)
135
+ )
136
+ raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
137
+
138
+
139
+ @register_vcs_handler("git", "get_keywords")
140
+ def git_get_keywords(versionfile_abs):
141
+ """Extract version information from the given file."""
142
+ # the code embedded in _version.py can just fetch the value of these
143
+ # keywords. When used from setup.py, we don't want to import _version.py,
144
+ # so we do it with a regexp instead. This function is not used from
145
+ # _version.py.
146
+ keywords = {}
147
+ try:
148
+ with open(versionfile_abs, "r") as fobj:
149
+ for line in fobj:
150
+ if line.strip().startswith("git_refnames ="):
151
+ mo = re.search(r'=\s*"(.*)"', line)
152
+ if mo:
153
+ keywords["refnames"] = mo.group(1)
154
+ if line.strip().startswith("git_full ="):
155
+ mo = re.search(r'=\s*"(.*)"', line)
156
+ if mo:
157
+ keywords["full"] = mo.group(1)
158
+ if line.strip().startswith("git_date ="):
159
+ mo = re.search(r'=\s*"(.*)"', line)
160
+ if mo:
161
+ keywords["date"] = mo.group(1)
162
+ except OSError:
163
+ pass
164
+ return keywords
165
+
166
+
167
+ @register_vcs_handler("git", "keywords")
168
+ def git_versions_from_keywords(keywords, tag_prefix, verbose):
169
+ """Get version information from git keywords."""
170
+ if "refnames" not in keywords:
171
+ raise NotThisMethod("Short version file found")
172
+ date = keywords.get("date")
173
+ if date is not None:
174
+ # Use only the last line. Previous lines may contain GPG signature
175
+ # information.
176
+ date = date.splitlines()[-1]
177
+
178
+ # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
179
+ # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
180
+ # -like" string, which we must then edit to make compliant), because
181
+ # it's been around since git-1.5.3, and it's too difficult to
182
+ # discover which version we're using, or to work around using an
183
+ # older one.
184
+ date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
185
+ refnames = keywords["refnames"].strip()
186
+ if refnames.startswith("$Format"):
187
+ if verbose:
188
+ print("keywords are unexpanded, not using")
189
+ raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
190
+ refs = {r.strip() for r in refnames.strip("()").split(",")}
191
+ # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
192
+ # just "foo-1.0". If we see a "tag: " prefix, prefer those.
193
+ TAG = "tag: "
194
+ tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
195
+ if not tags:
196
+ # Either we're using git < 1.8.3, or there really are no tags. We use
197
+ # a heuristic: assume all version tags have a digit. The old git %d
198
+ # expansion behaves like git log --decorate=short and strips out the
199
+ # refs/heads/ and refs/tags/ prefixes that would let us distinguish
200
+ # between branches and tags. By ignoring refnames without digits, we
201
+ # filter out many common branch names like "release" and
202
+ # "stabilization", as well as "HEAD" and "master".
203
+ tags = {r for r in refs if re.search(r"\d", r)}
204
+ if verbose:
205
+ print("discarding '%s', no digits" % ",".join(refs - tags))
206
+ if verbose:
207
+ print("likely tags: %s" % ",".join(sorted(tags)))
208
+ for ref in sorted(tags):
209
+ # sorting will prefer e.g. "2.0" over "2.0rc1"
210
+ if ref.startswith(tag_prefix):
211
+ r = ref[len(tag_prefix) :]
212
+ # Filter out refs that exactly match prefix or that don't start
213
+ # with a number once the prefix is stripped (mostly a concern
214
+ # when prefix is '')
215
+ if not re.match(r"\d", r):
216
+ continue
217
+ if verbose:
218
+ print("picking %s" % r)
219
+ return {
220
+ "version": r,
221
+ "full-revisionid": keywords["full"].strip(),
222
+ "dirty": False,
223
+ "error": None,
224
+ "date": date,
225
+ }
226
+ # no suitable tags, so version is "0+unknown", but full hex is still there
227
+ if verbose:
228
+ print("no suitable tags, using unknown + full revision id")
229
+ return {
230
+ "version": "0+unknown",
231
+ "full-revisionid": keywords["full"].strip(),
232
+ "dirty": False,
233
+ "error": "no suitable tags",
234
+ "date": None,
235
+ }
236
+
237
+
238
+ @register_vcs_handler("git", "pieces_from_vcs")
239
+ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
240
+ """Get version from 'git describe' in the root of the source tree.
241
+
242
+ This only gets called if the git-archive 'subst' keywords were *not*
243
+ expanded, and _version.py hasn't already been rewritten with a short
244
+ version string, meaning we're inside a checked out source tree.
245
+ """
246
+ GITS = ["git"]
247
+ TAG_PREFIX_REGEX = "*"
248
+ if sys.platform == "win32":
249
+ GITS = ["git.cmd", "git.exe"]
250
+ TAG_PREFIX_REGEX = r"\*"
251
+
252
+ _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
253
+ if rc != 0:
254
+ if verbose:
255
+ print("Directory %s not under git control" % root)
256
+ raise NotThisMethod("'git rev-parse --git-dir' returned error")
257
+
258
+ # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
259
+ # if there isn't one, this yields HEX[-dirty] (no NUM)
260
+ describe_out, rc = runner(
261
+ GITS,
262
+ [
263
+ "describe",
264
+ "--tags",
265
+ "--dirty",
266
+ "--always",
267
+ "--long",
268
+ "--match",
269
+ "%s%s" % (tag_prefix, TAG_PREFIX_REGEX),
270
+ ],
271
+ cwd=root,
272
+ )
273
+ # --long was added in git-1.5.5
274
+ if describe_out is None:
275
+ raise NotThisMethod("'git describe' failed")
276
+ describe_out = describe_out.strip()
277
+ full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
278
+ if full_out is None:
279
+ raise NotThisMethod("'git rev-parse' failed")
280
+ full_out = full_out.strip()
281
+
282
+ pieces = {}
283
+ pieces["long"] = full_out
284
+ pieces["short"] = full_out[:7] # maybe improved later
285
+ pieces["error"] = None
286
+
287
+ branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root)
288
+ # --abbrev-ref was added in git-1.6.3
289
+ if rc != 0 or branch_name is None:
290
+ raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
291
+ branch_name = branch_name.strip()
292
+
293
+ if branch_name == "HEAD":
294
+ # If we aren't exactly on a branch, pick a branch which represents
295
+ # the current commit. If all else fails, we are on a branchless
296
+ # commit.
297
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
298
+ # --contains was added in git-1.5.4
299
+ if rc != 0 or branches is None:
300
+ raise NotThisMethod("'git branch --contains' returned error")
301
+ branches = branches.split("\n")
302
+
303
+ # Remove the first line if we're running detached
304
+ if "(" in branches[0]:
305
+ branches.pop(0)
306
+
307
+ # Strip off the leading "* " from the list of branches.
308
+ branches = [branch[2:] for branch in branches]
309
+ if "master" in branches:
310
+ branch_name = "master"
311
+ elif not branches:
312
+ branch_name = None
313
+ else:
314
+ # Pick the first branch that is returned. Good or bad.
315
+ branch_name = branches[0]
316
+
317
+ pieces["branch"] = branch_name
318
+
319
+ # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
320
+ # TAG might have hyphens.
321
+ git_describe = describe_out
322
+
323
+ # look for -dirty suffix
324
+ dirty = git_describe.endswith("-dirty")
325
+ pieces["dirty"] = dirty
326
+ if dirty:
327
+ git_describe = git_describe[: git_describe.rindex("-dirty")]
328
+
329
+ # now we have TAG-NUM-gHEX or HEX
330
+
331
+ if "-" in git_describe:
332
+ # TAG-NUM-gHEX
333
+ mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
334
+ if not mo:
335
+ # unparsable. Maybe git-describe is misbehaving?
336
+ pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
337
+ return pieces
338
+
339
+ # tag
340
+ full_tag = mo.group(1)
341
+ if not full_tag.startswith(tag_prefix):
342
+ if verbose:
343
+ fmt = "tag '%s' doesn't start with prefix '%s'"
344
+ print(fmt % (full_tag, tag_prefix))
345
+ pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
346
+ full_tag,
347
+ tag_prefix,
348
+ )
349
+ return pieces
350
+ pieces["closest-tag"] = full_tag[len(tag_prefix) :]
351
+
352
+ # distance: number of commits since tag
353
+ pieces["distance"] = int(mo.group(2))
354
+
355
+ # commit: short hex revision ID
356
+ pieces["short"] = mo.group(3)
357
+
358
+ else:
359
+ # HEX: no tags
360
+ pieces["closest-tag"] = None
361
+ count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
362
+ pieces["distance"] = int(count_out) # total number of commits
363
+
364
+ # commit date: see ISO-8601 comment in git_versions_from_keywords()
365
+ date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
366
+ # Use only the last line. Previous lines may contain GPG signature
367
+ # information.
368
+ date = date.splitlines()[-1]
369
+ pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
370
+
371
+ return pieces
372
+
373
+
374
+ def plus_or_dot(pieces):
375
+ """Return a + if we don't already have one, else return a ."""
376
+ if "+" in pieces.get("closest-tag", ""):
377
+ return "."
378
+ return "+"
379
+
380
+
381
+ def render_pep440(pieces):
382
+ """Build up version string, with post-release "local version identifier".
383
+
384
+ Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
385
+ get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
386
+
387
+ Exceptions:
388
+ 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
389
+ """
390
+ if pieces["closest-tag"]:
391
+ rendered = pieces["closest-tag"]
392
+ if pieces["distance"] or pieces["dirty"]:
393
+ rendered += plus_or_dot(pieces)
394
+ rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
395
+ if pieces["dirty"]:
396
+ rendered += ".dirty"
397
+ else:
398
+ # exception #1
399
+ rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
400
+ if pieces["dirty"]:
401
+ rendered += ".dirty"
402
+ return rendered
403
+
404
+
405
+ def render_pep440_branch(pieces):
406
+ """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
407
+
408
+ The ".dev0" means not master branch. Note that .dev0 sorts backwards
409
+ (a feature branch will appear "older" than the master branch).
410
+
411
+ Exceptions:
412
+ 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
413
+ """
414
+ if pieces["closest-tag"]:
415
+ rendered = pieces["closest-tag"]
416
+ if pieces["distance"] or pieces["dirty"]:
417
+ if pieces["branch"] != "master":
418
+ rendered += ".dev0"
419
+ rendered += plus_or_dot(pieces)
420
+ rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
421
+ if pieces["dirty"]:
422
+ rendered += ".dirty"
423
+ else:
424
+ # exception #1
425
+ rendered = "0"
426
+ if pieces["branch"] != "master":
427
+ rendered += ".dev0"
428
+ rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
429
+ if pieces["dirty"]:
430
+ rendered += ".dirty"
431
+ return rendered
432
+
433
+
434
+ def pep440_split_post(ver):
435
+ """Split pep440 version string at the post-release segment.
436
+
437
+ Returns the release segments before the post-release and the
438
+ post-release version number (or -1 if no post-release segment is present).
439
+ """
440
+ vc = str.split(ver, ".post")
441
+ return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
442
+
443
+
444
+ def render_pep440_pre(pieces):
445
+ """TAG[.postN.devDISTANCE] -- No -dirty.
446
+
447
+ Exceptions:
448
+ 1: no tags. 0.post0.devDISTANCE
449
+ """
450
+ if pieces["closest-tag"]:
451
+ if pieces["distance"]:
452
+ # update the post release segment
453
+ tag_version, post_version = pep440_split_post(pieces["closest-tag"])
454
+ rendered = tag_version
455
+ if post_version is not None:
456
+ rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
457
+ else:
458
+ rendered += ".post0.dev%d" % (pieces["distance"])
459
+ else:
460
+ # no commits, use the tag as the version
461
+ rendered = pieces["closest-tag"]
462
+ else:
463
+ # exception #1
464
+ rendered = "0.post0.dev%d" % pieces["distance"]
465
+ return rendered
466
+
467
+
468
+ def render_pep440_post(pieces):
469
+ """TAG[.postDISTANCE[.dev0]+gHEX] .
470
+
471
+ The ".dev0" means dirty. Note that .dev0 sorts backwards
472
+ (a dirty tree will appear "older" than the corresponding clean one),
473
+ but you shouldn't be releasing software with -dirty anyways.
474
+
475
+ Exceptions:
476
+ 1: no tags. 0.postDISTANCE[.dev0]
477
+ """
478
+ if pieces["closest-tag"]:
479
+ rendered = pieces["closest-tag"]
480
+ if pieces["distance"] or pieces["dirty"]:
481
+ rendered += ".post%d" % pieces["distance"]
482
+ if pieces["dirty"]:
483
+ rendered += ".dev0"
484
+ rendered += plus_or_dot(pieces)
485
+ rendered += "g%s" % pieces["short"]
486
+ else:
487
+ # exception #1
488
+ rendered = "0.post%d" % pieces["distance"]
489
+ if pieces["dirty"]:
490
+ rendered += ".dev0"
491
+ rendered += "+g%s" % pieces["short"]
492
+ return rendered
493
+
494
+
495
+ def render_pep440_post_branch(pieces):
496
+ """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
497
+
498
+ The ".dev0" means not master branch.
499
+
500
+ Exceptions:
501
+ 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
502
+ """
503
+ if pieces["closest-tag"]:
504
+ rendered = pieces["closest-tag"]
505
+ if pieces["distance"] or pieces["dirty"]:
506
+ rendered += ".post%d" % pieces["distance"]
507
+ if pieces["branch"] != "master":
508
+ rendered += ".dev0"
509
+ rendered += plus_or_dot(pieces)
510
+ rendered += "g%s" % pieces["short"]
511
+ if pieces["dirty"]:
512
+ rendered += ".dirty"
513
+ else:
514
+ # exception #1
515
+ rendered = "0.post%d" % pieces["distance"]
516
+ if pieces["branch"] != "master":
517
+ rendered += ".dev0"
518
+ rendered += "+g%s" % pieces["short"]
519
+ if pieces["dirty"]:
520
+ rendered += ".dirty"
521
+ return rendered
522
+
523
+
524
+ def render_pep440_old(pieces):
525
+ """TAG[.postDISTANCE[.dev0]] .
526
+
527
+ The ".dev0" means dirty.
528
+
529
+ Exceptions:
530
+ 1: no tags. 0.postDISTANCE[.dev0]
531
+ """
532
+ if pieces["closest-tag"]:
533
+ rendered = pieces["closest-tag"]
534
+ if pieces["distance"] or pieces["dirty"]:
535
+ rendered += ".post%d" % pieces["distance"]
536
+ if pieces["dirty"]:
537
+ rendered += ".dev0"
538
+ else:
539
+ # exception #1
540
+ rendered = "0.post%d" % pieces["distance"]
541
+ if pieces["dirty"]:
542
+ rendered += ".dev0"
543
+ return rendered
544
+
545
+
546
+ def render_git_describe(pieces):
547
+ """TAG[-DISTANCE-gHEX][-dirty].
548
+
549
+ Like 'git describe --tags --dirty --always'.
550
+
551
+ Exceptions:
552
+ 1: no tags. HEX[-dirty] (note: no 'g' prefix)
553
+ """
554
+ if pieces["closest-tag"]:
555
+ rendered = pieces["closest-tag"]
556
+ if pieces["distance"]:
557
+ rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
558
+ else:
559
+ # exception #1
560
+ rendered = pieces["short"]
561
+ if pieces["dirty"]:
562
+ rendered += "-dirty"
563
+ return rendered
564
+
565
+
566
+ def render_git_describe_long(pieces):
567
+ """TAG-DISTANCE-gHEX[-dirty].
568
+
569
+ Like 'git describe --tags --dirty --always -long'.
570
+ The distance/hash is unconditional.
571
+
572
+ Exceptions:
573
+ 1: no tags. HEX[-dirty] (note: no 'g' prefix)
574
+ """
575
+ if pieces["closest-tag"]:
576
+ rendered = pieces["closest-tag"]
577
+ rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
578
+ else:
579
+ # exception #1
580
+ rendered = pieces["short"]
581
+ if pieces["dirty"]:
582
+ rendered += "-dirty"
583
+ return rendered
584
+
585
+
586
+ def render(pieces, style):
587
+ """Render the given version pieces into the requested style."""
588
+ if pieces["error"]:
589
+ return {
590
+ "version": "unknown",
591
+ "full-revisionid": pieces.get("long"),
592
+ "dirty": None,
593
+ "error": pieces["error"],
594
+ "date": None,
595
+ }
596
+
597
+ if not style or style == "default":
598
+ style = "pep440" # the default
599
+
600
+ if style == "pep440":
601
+ rendered = render_pep440(pieces)
602
+ elif style == "pep440-branch":
603
+ rendered = render_pep440_branch(pieces)
604
+ elif style == "pep440-pre":
605
+ rendered = render_pep440_pre(pieces)
606
+ elif style == "pep440-post":
607
+ rendered = render_pep440_post(pieces)
608
+ elif style == "pep440-post-branch":
609
+ rendered = render_pep440_post_branch(pieces)
610
+ elif style == "pep440-old":
611
+ rendered = render_pep440_old(pieces)
612
+ elif style == "git-describe":
613
+ rendered = render_git_describe(pieces)
614
+ elif style == "git-describe-long":
615
+ rendered = render_git_describe_long(pieces)
616
+ else:
617
+ raise ValueError("unknown style '%s'" % style)
618
+
619
+ return {
620
+ "version": rendered,
621
+ "full-revisionid": pieces["long"],
622
+ "dirty": pieces["dirty"],
623
+ "error": None,
624
+ "date": pieces.get("date"),
625
+ }
626
+
627
+
628
+ def get_versions():
629
+ """Get version information or return default if unable to do so."""
630
+ # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
631
+ # __file__, we can work backwards from there to the root. Some
632
+ # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
633
+ # case we can only use expanded keywords.
634
+
635
+ cfg = get_config()
636
+ verbose = cfg.verbose
637
+
638
+ try:
639
+ return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
640
+ except NotThisMethod:
641
+ pass
642
+
643
+ try:
644
+ root = os.path.realpath(__file__)
645
+ # versionfile_source is the relative path from the top of the source
646
+ # tree (where the .git directory might live) to this file. Invert
647
+ # this to find the root from __file__.
648
+ for _ in cfg.versionfile_source.split("/"):
649
+ root = os.path.dirname(root)
650
+ except NameError:
651
+ return {
652
+ "version": "0+unknown",
653
+ "full-revisionid": None,
654
+ "dirty": None,
655
+ "error": "unable to find root of source tree",
656
+ "date": None,
657
+ }
658
+
659
+ try:
660
+ pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
661
+ return render(pieces, cfg.style)
662
+ except NotThisMethod:
663
+ pass
664
+
665
+ try:
666
+ if cfg.parentdir_prefix:
667
+ return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
668
+ except NotThisMethod:
669
+ pass
670
+
671
+ return {
672
+ "version": "0+unknown",
673
+ "full-revisionid": None,
674
+ "dirty": None,
675
+ "error": "unable to compute version",
676
+ "date": None,
677
+ }
rembg/bg.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from enum import Enum
3
+ from typing import Any, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ from cv2 import (
7
+ BORDER_DEFAULT,
8
+ MORPH_ELLIPSE,
9
+ MORPH_OPEN,
10
+ GaussianBlur,
11
+ getStructuringElement,
12
+ morphologyEx,
13
+ )
14
+ from PIL import Image
15
+ from PIL.Image import Image as PILImage
16
+ from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
17
+ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
18
+ from pymatting.util.util import stack_images
19
+ from scipy.ndimage import binary_erosion
20
+
21
+ from .session_factory import new_session
22
+ from .sessions.base import BaseSession
23
+
24
+ kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
25
+
26
+
27
+ class ReturnType(Enum):
28
+ BYTES = 0
29
+ PILLOW = 1
30
+ NDARRAY = 2
31
+
32
+
33
+ def alpha_matting_cutout(
34
+ img: PILImage,
35
+ mask: PILImage,
36
+ foreground_threshold: int,
37
+ background_threshold: int,
38
+ erode_structure_size: int,
39
+ ) -> PILImage:
40
+ if img.mode == "RGBA" or img.mode == "CMYK":
41
+ img = img.convert("RGB")
42
+
43
+ img = np.asarray(img)
44
+ mask = np.asarray(mask)
45
+
46
+ is_foreground = mask > foreground_threshold
47
+ is_background = mask < background_threshold
48
+
49
+ structure = None
50
+ if erode_structure_size > 0:
51
+ structure = np.ones(
52
+ (erode_structure_size, erode_structure_size), dtype=np.uint8
53
+ )
54
+
55
+ is_foreground = binary_erosion(is_foreground, structure=structure)
56
+ is_background = binary_erosion(is_background, structure=structure, border_value=1)
57
+
58
+ trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
59
+ trimap[is_foreground] = 255
60
+ trimap[is_background] = 0
61
+
62
+ img_normalized = img / 255.0
63
+ trimap_normalized = trimap / 255.0
64
+
65
+ alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
66
+ foreground = estimate_foreground_ml(img_normalized, alpha)
67
+ cutout = stack_images(foreground, alpha)
68
+
69
+ cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
70
+ cutout = Image.fromarray(cutout)
71
+
72
+ return cutout
73
+
74
+
75
+ def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
76
+ empty = Image.new("RGBA", (img.size), 0)
77
+ cutout = Image.composite(img, empty, mask)
78
+ return cutout
79
+
80
+
81
+ def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
82
+ pivot = imgs.pop(0)
83
+ for im in imgs:
84
+ pivot = get_concat_v(pivot, im)
85
+ return pivot
86
+
87
+
88
+ def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
89
+ dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
90
+ dst.paste(img1, (0, 0))
91
+ dst.paste(img2, (0, img1.height))
92
+ return dst
93
+
94
+
95
+ def post_process(mask: np.ndarray) -> np.ndarray:
96
+ """
97
+ Post Process the mask for a smooth boundary by applying Morphological Operations
98
+ Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757
99
+ args:
100
+ mask: Binary Numpy Mask
101
+ """
102
+ mask = morphologyEx(mask, MORPH_OPEN, kernel)
103
+ mask = GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT)
104
+ mask = np.where(mask < 127, 0, 255).astype(np.uint8) # convert again to binary
105
+ return mask
106
+
107
+
108
+ def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
109
+ r, g, b, a = color
110
+ colored_image = Image.new("RGBA", img.size, (r, g, b, a))
111
+ colored_image.paste(img, mask=img)
112
+
113
+ return colored_image
114
+
115
+
116
+ def remove(
117
+ data: Union[bytes, PILImage, np.ndarray],
118
+ alpha_matting: bool = False,
119
+ alpha_matting_foreground_threshold: int = 240,
120
+ alpha_matting_background_threshold: int = 10,
121
+ alpha_matting_erode_size: int = 10,
122
+ session: Optional[BaseSession] = None,
123
+ only_mask: bool = False,
124
+ post_process_mask: bool = False,
125
+ bgcolor: Optional[Tuple[int, int, int, int]] = None,
126
+ *args: Optional[Any],
127
+ **kwargs: Optional[Any]
128
+ ) -> Union[bytes, PILImage, np.ndarray]:
129
+ if isinstance(data, PILImage):
130
+ return_type = ReturnType.PILLOW
131
+ img = data
132
+ elif isinstance(data, bytes):
133
+ return_type = ReturnType.BYTES
134
+ img = Image.open(io.BytesIO(data))
135
+ elif isinstance(data, np.ndarray):
136
+ return_type = ReturnType.NDARRAY
137
+ img = Image.fromarray(data)
138
+ else:
139
+ raise ValueError("Input type {} is not supported.".format(type(data)))
140
+
141
+ if session is None:
142
+ session = new_session("u2net", *args, **kwargs)
143
+
144
+ masks = session.predict(img, *args, **kwargs)
145
+ cutouts = []
146
+
147
+ for mask in masks:
148
+ if post_process_mask:
149
+ mask = Image.fromarray(post_process(np.array(mask)))
150
+
151
+ if only_mask:
152
+ cutout = mask
153
+
154
+ elif alpha_matting:
155
+ try:
156
+ cutout = alpha_matting_cutout(
157
+ img,
158
+ mask,
159
+ alpha_matting_foreground_threshold,
160
+ alpha_matting_background_threshold,
161
+ alpha_matting_erode_size,
162
+ )
163
+ except ValueError:
164
+ cutout = naive_cutout(img, mask)
165
+
166
+ else:
167
+ cutout = naive_cutout(img, mask)
168
+
169
+ cutouts.append(cutout)
170
+
171
+ cutout = img
172
+ if len(cutouts) > 0:
173
+ cutout = get_concat_v_multi(cutouts)
174
+
175
+ if bgcolor is not None and not only_mask:
176
+ cutout = apply_background_color(cutout, bgcolor)
177
+
178
+ if ReturnType.PILLOW == return_type:
179
+ return cutout
180
+
181
+ if ReturnType.NDARRAY == return_type:
182
+ return np.asarray(cutout)
183
+
184
+ bio = io.BytesIO()
185
+ cutout.save(bio, "PNG")
186
+ bio.seek(0)
187
+
188
+ return bio.read()
rembg/cli.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+
3
+ from . import _version
4
+ from .commands import command_functions
5
+
6
+
7
+ @click.group()
8
+ @click.version_option(version=_version.get_versions()["version"])
9
+ def main() -> None:
10
+ pass
11
+
12
+
13
+ for command in command_functions:
14
+ main.add_command(command)
rembg/commands/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from importlib import import_module
2
+ from pathlib import Path
3
+ from pkgutil import iter_modules
4
+
5
+ command_functions = []
6
+
7
+ package_dir = Path(__file__).resolve().parent
8
+ for _b, module_name, _p in iter_modules([str(package_dir)]):
9
+ module = import_module(f"{__name__}.{module_name}")
10
+ for attribute_name in dir(module):
11
+ attribute = getattr(module, attribute_name)
12
+ if attribute_name.endswith("_command"):
13
+ command_functions.append(attribute)
rembg/commands/i_command.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+ from typing import IO
4
+
5
+ import click
6
+
7
+ from ..bg import remove
8
+ from ..session_factory import new_session
9
+ from ..sessions import sessions_names
10
+
11
+
12
+ @click.command(
13
+ name="i",
14
+ help="for a file as input",
15
+ )
16
+ @click.option(
17
+ "-m",
18
+ "--model",
19
+ default="u2net",
20
+ type=click.Choice(sessions_names),
21
+ show_default=True,
22
+ show_choices=True,
23
+ help="model name",
24
+ )
25
+ @click.option(
26
+ "-a",
27
+ "--alpha-matting",
28
+ is_flag=True,
29
+ show_default=True,
30
+ help="use alpha matting",
31
+ )
32
+ @click.option(
33
+ "-af",
34
+ "--alpha-matting-foreground-threshold",
35
+ default=240,
36
+ type=int,
37
+ show_default=True,
38
+ help="trimap fg threshold",
39
+ )
40
+ @click.option(
41
+ "-ab",
42
+ "--alpha-matting-background-threshold",
43
+ default=10,
44
+ type=int,
45
+ show_default=True,
46
+ help="trimap bg threshold",
47
+ )
48
+ @click.option(
49
+ "-ae",
50
+ "--alpha-matting-erode-size",
51
+ default=10,
52
+ type=int,
53
+ show_default=True,
54
+ help="erode size",
55
+ )
56
+ @click.option(
57
+ "-om",
58
+ "--only-mask",
59
+ is_flag=True,
60
+ show_default=True,
61
+ help="output only the mask",
62
+ )
63
+ @click.option(
64
+ "-ppm",
65
+ "--post-process-mask",
66
+ is_flag=True,
67
+ show_default=True,
68
+ help="post process the mask",
69
+ )
70
+ @click.option(
71
+ "-bgc",
72
+ "--bgcolor",
73
+ default=None,
74
+ type=(int, int, int, int),
75
+ nargs=4,
76
+ help="Background color (R G B A) to replace the removed background with",
77
+ )
78
+ @click.option("-x", "--extras", type=str)
79
+ @click.argument(
80
+ "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
81
+ )
82
+ @click.argument(
83
+ "output",
84
+ default=(None if sys.stdin.isatty() else "-"),
85
+ type=click.File("wb", lazy=True),
86
+ )
87
+ def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
88
+ try:
89
+ kwargs.update(json.loads(extras))
90
+ except Exception:
91
+ pass
92
+
93
+ output.write(remove(input.read(), session=new_session(model), **kwargs))
rembg/commands/p_command.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pathlib
3
+ import time
4
+ from typing import cast
5
+
6
+ import click
7
+ import filetype
8
+ from tqdm import tqdm
9
+ from watchdog.events import FileSystemEvent, FileSystemEventHandler
10
+ from watchdog.observers import Observer
11
+
12
+ from ..bg import remove
13
+ from ..session_factory import new_session
14
+ from ..sessions import sessions_names
15
+
16
+
17
+ @click.command(
18
+ name="p",
19
+ help="for a folder as input",
20
+ )
21
+ @click.option(
22
+ "-m",
23
+ "--model",
24
+ default="u2net",
25
+ type=click.Choice(sessions_names),
26
+ show_default=True,
27
+ show_choices=True,
28
+ help="model name",
29
+ )
30
+ @click.option(
31
+ "-a",
32
+ "--alpha-matting",
33
+ is_flag=True,
34
+ show_default=True,
35
+ help="use alpha matting",
36
+ )
37
+ @click.option(
38
+ "-af",
39
+ "--alpha-matting-foreground-threshold",
40
+ default=240,
41
+ type=int,
42
+ show_default=True,
43
+ help="trimap fg threshold",
44
+ )
45
+ @click.option(
46
+ "-ab",
47
+ "--alpha-matting-background-threshold",
48
+ default=10,
49
+ type=int,
50
+ show_default=True,
51
+ help="trimap bg threshold",
52
+ )
53
+ @click.option(
54
+ "-ae",
55
+ "--alpha-matting-erode-size",
56
+ default=10,
57
+ type=int,
58
+ show_default=True,
59
+ help="erode size",
60
+ )
61
+ @click.option(
62
+ "-om",
63
+ "--only-mask",
64
+ is_flag=True,
65
+ show_default=True,
66
+ help="output only the mask",
67
+ )
68
+ @click.option(
69
+ "-ppm",
70
+ "--post-process-mask",
71
+ is_flag=True,
72
+ show_default=True,
73
+ help="post process the mask",
74
+ )
75
+ @click.option(
76
+ "-w",
77
+ "--watch",
78
+ default=False,
79
+ is_flag=True,
80
+ show_default=True,
81
+ help="watches a folder for changes",
82
+ )
83
+ @click.option(
84
+ "-bgc",
85
+ "--bgcolor",
86
+ default=None,
87
+ type=(int, int, int, int),
88
+ nargs=4,
89
+ help="Background color (R G B A) to replace the removed background with",
90
+ )
91
+ @click.option("-x", "--extras", type=str)
92
+ @click.argument(
93
+ "input",
94
+ type=click.Path(
95
+ exists=True,
96
+ path_type=pathlib.Path,
97
+ file_okay=False,
98
+ dir_okay=True,
99
+ readable=True,
100
+ ),
101
+ )
102
+ @click.argument(
103
+ "output",
104
+ type=click.Path(
105
+ exists=False,
106
+ path_type=pathlib.Path,
107
+ file_okay=False,
108
+ dir_okay=True,
109
+ writable=True,
110
+ ),
111
+ )
112
+ def p_command(
113
+ model: str,
114
+ extras: str,
115
+ input: pathlib.Path,
116
+ output: pathlib.Path,
117
+ watch: bool,
118
+ **kwargs,
119
+ ) -> None:
120
+ try:
121
+ kwargs.update(json.loads(extras))
122
+ except Exception:
123
+ pass
124
+
125
+ session = new_session(model)
126
+
127
+ def process(each_input: pathlib.Path) -> None:
128
+ try:
129
+ mimetype = filetype.guess(each_input)
130
+ if mimetype is None:
131
+ return
132
+ if mimetype.mime.find("image") < 0:
133
+ return
134
+
135
+ each_output = (output / each_input.name).with_suffix(".png")
136
+ each_output.parents[0].mkdir(parents=True, exist_ok=True)
137
+
138
+ if not each_output.exists():
139
+ each_output.write_bytes(
140
+ cast(
141
+ bytes,
142
+ remove(each_input.read_bytes(), session=session, **kwargs),
143
+ )
144
+ )
145
+
146
+ if watch:
147
+ print(
148
+ f"processed: {each_input.absolute()} -> {each_output.absolute()}"
149
+ )
150
+ except Exception as e:
151
+ print(e)
152
+
153
+ inputs = list(input.glob("**/*"))
154
+ if not watch:
155
+ inputs = tqdm(inputs)
156
+
157
+ for each_input in inputs:
158
+ if not each_input.is_dir():
159
+ process(each_input)
160
+
161
+ if watch:
162
+ observer = Observer()
163
+
164
+ class EventHandler(FileSystemEventHandler):
165
+ def on_any_event(self, event: FileSystemEvent) -> None:
166
+ if not (
167
+ event.is_directory or event.event_type in ["deleted", "closed"]
168
+ ):
169
+ process(pathlib.Path(event.src_path))
170
+
171
+ event_handler = EventHandler()
172
+ observer.schedule(event_handler, input, recursive=False)
173
+ observer.start()
174
+
175
+ try:
176
+ while True:
177
+ time.sleep(1)
178
+
179
+ finally:
180
+ observer.stop()
181
+ observer.join()
rembg/commands/s_command.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Annotated, Optional, Tuple, cast
3
+
4
+ import aiohttp
5
+ import click
6
+ import uvicorn
7
+ from asyncer import asyncify
8
+ from fastapi import Depends, FastAPI, File, Form, Query
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from starlette.responses import Response
11
+
12
+ from .._version import get_versions
13
+ from ..bg import remove
14
+ from ..session_factory import new_session
15
+ from ..sessions import sessions_names
16
+ from ..sessions.base import BaseSession
17
+
18
+
19
+ @click.command(
20
+ name="s",
21
+ help="for a http server",
22
+ )
23
+ @click.option(
24
+ "-p",
25
+ "--port",
26
+ default=5000,
27
+ type=int,
28
+ show_default=True,
29
+ help="port",
30
+ )
31
+ @click.option(
32
+ "-l",
33
+ "--log_level",
34
+ default="info",
35
+ type=str,
36
+ show_default=True,
37
+ help="log level",
38
+ )
39
+ @click.option(
40
+ "-t",
41
+ "--threads",
42
+ default=None,
43
+ type=int,
44
+ show_default=True,
45
+ help="number of worker threads",
46
+ )
47
+ def s_command(port: int, log_level: str, threads: int) -> None:
48
+ sessions: dict[str, BaseSession] = {}
49
+ tags_metadata = [
50
+ {
51
+ "name": "Background Removal",
52
+ "description": "Endpoints that perform background removal with different image sources.",
53
+ "externalDocs": {
54
+ "description": "GitHub Source",
55
+ "url": "https://github.com/danielgatis/rembg",
56
+ },
57
+ },
58
+ ]
59
+ app = FastAPI(
60
+ title="Rembg",
61
+ description="Rembg is a tool to remove images background. That is it.",
62
+ version=get_versions()["version"],
63
+ contact={
64
+ "name": "Daniel Gatis",
65
+ "url": "https://github.com/danielgatis",
66
+ "email": "danielgatis@gmail.com",
67
+ },
68
+ license_info={
69
+ "name": "MIT License",
70
+ "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
71
+ },
72
+ openapi_tags=tags_metadata,
73
+ )
74
+
75
+ app.add_middleware(
76
+ CORSMiddleware,
77
+ allow_credentials=True,
78
+ allow_origins=["*"],
79
+ allow_methods=["*"],
80
+ allow_headers=["*"],
81
+ )
82
+
83
+ class CommonQueryParams:
84
+ def __init__(
85
+ self,
86
+ model: Annotated[
87
+ str, Query(regex=r"(" + "|".join(sessions_names) + ")")
88
+ ] = Query(
89
+ description="Model to use when processing image",
90
+ ),
91
+ a: bool = Query(default=False, description="Enable Alpha Matting"),
92
+ af: int = Query(
93
+ default=240,
94
+ ge=0,
95
+ le=255,
96
+ description="Alpha Matting (Foreground Threshold)",
97
+ ),
98
+ ab: int = Query(
99
+ default=10,
100
+ ge=0,
101
+ le=255,
102
+ description="Alpha Matting (Background Threshold)",
103
+ ),
104
+ ae: int = Query(
105
+ default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
106
+ ),
107
+ om: bool = Query(default=False, description="Only Mask"),
108
+ ppm: bool = Query(default=False, description="Post Process Mask"),
109
+ bgc: Optional[str] = Query(default=None, description="Background Color"),
110
+ extras: Optional[str] = Query(
111
+ default=None, description="Extra parameters as JSON"
112
+ ),
113
+ ):
114
+ self.model = model
115
+ self.a = a
116
+ self.af = af
117
+ self.ab = ab
118
+ self.ae = ae
119
+ self.om = om
120
+ self.ppm = ppm
121
+ self.extras = extras
122
+ self.bgc = (
123
+ cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
124
+ if bgc
125
+ else None
126
+ )
127
+
128
+ class CommonQueryPostParams:
129
+ def __init__(
130
+ self,
131
+ model: Annotated[
132
+ str, Form(regex=r"(" + "|".join(sessions_names) + ")")
133
+ ] = Form(
134
+ description="Model to use when processing image",
135
+ ),
136
+ a: bool = Form(default=False, description="Enable Alpha Matting"),
137
+ af: int = Form(
138
+ default=240,
139
+ ge=0,
140
+ le=255,
141
+ description="Alpha Matting (Foreground Threshold)",
142
+ ),
143
+ ab: int = Form(
144
+ default=10,
145
+ ge=0,
146
+ le=255,
147
+ description="Alpha Matting (Background Threshold)",
148
+ ),
149
+ ae: int = Form(
150
+ default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
151
+ ),
152
+ om: bool = Form(default=False, description="Only Mask"),
153
+ ppm: bool = Form(default=False, description="Post Process Mask"),
154
+ bgc: Optional[str] = Query(default=None, description="Background Color"),
155
+ extras: Optional[str] = Query(
156
+ default=None, description="Extra parameters as JSON"
157
+ ),
158
+ ):
159
+ self.model = model
160
+ self.a = a
161
+ self.af = af
162
+ self.ab = ab
163
+ self.ae = ae
164
+ self.om = om
165
+ self.ppm = ppm
166
+ self.extras = extras
167
+ self.bgc = (
168
+ cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
169
+ if bgc
170
+ else None
171
+ )
172
+
173
+ def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
174
+ kwargs = {}
175
+
176
+ if commons.extras:
177
+ try:
178
+ kwargs.update(json.loads(commons.extras))
179
+ except Exception:
180
+ pass
181
+
182
+ return Response(
183
+ remove(
184
+ content,
185
+ session=sessions.setdefault(commons.model, new_session(commons.model)),
186
+ alpha_matting=commons.a,
187
+ alpha_matting_foreground_threshold=commons.af,
188
+ alpha_matting_background_threshold=commons.ab,
189
+ alpha_matting_erode_size=commons.ae,
190
+ only_mask=commons.om,
191
+ post_process_mask=commons.ppm,
192
+ bgcolor=commons.bgc,
193
+ **kwargs
194
+ ),
195
+ media_type="image/png",
196
+ )
197
+
198
+ @app.on_event("startup")
199
+ def startup():
200
+ if threads is not None:
201
+ from anyio import CapacityLimiter
202
+ from anyio.lowlevel import RunVar
203
+
204
+ RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
205
+
206
+ @app.get(
207
+ path="/",
208
+ tags=["Background Removal"],
209
+ summary="Remove from URL",
210
+ description="Removes the background from an image obtained by retrieving an URL.",
211
+ )
212
+ async def get_index(
213
+ url: str = Query(
214
+ default=..., description="URL of the image that has to be processed."
215
+ ),
216
+ commons: CommonQueryParams = Depends(),
217
+ ):
218
+ async with aiohttp.ClientSession() as session:
219
+ async with session.get(url) as response:
220
+ file = await response.read()
221
+ return await asyncify(im_without_bg)(file, commons)
222
+
223
+ @app.post(
224
+ path="/",
225
+ tags=["Background Removal"],
226
+ summary="Remove from Stream",
227
+ description="Removes the background from an image sent within the request itself.",
228
+ )
229
+ async def post_index(
230
+ file: bytes = File(
231
+ default=...,
232
+ description="Image file (byte stream) that has to be processed.",
233
+ ),
234
+ commons: CommonQueryPostParams = Depends(),
235
+ ):
236
+ return await asyncify(im_without_bg)(file, commons) # type: ignore
237
+
238
+ uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
rembg/session_base.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple
2
+
3
+ import numpy as np
4
+ import onnxruntime as ort
5
+ from PIL import Image
6
+ from PIL.Image import Image as PILImage
7
+
8
+
9
+ class BaseSession:
10
+ def __init__(self, model_name: str, inner_session: ort.InferenceSession):
11
+ self.model_name = model_name
12
+ self.inner_session = inner_session
13
+
14
+ def normalize(
15
+ self,
16
+ img: PILImage,
17
+ mean: Tuple[float, float, float],
18
+ std: Tuple[float, float, float],
19
+ size: Tuple[int, int],
20
+ ) -> Dict[str, np.ndarray]:
21
+ im = img.convert("RGB").resize(size, Image.LANCZOS)
22
+
23
+ im_ary = np.array(im)
24
+ im_ary = im_ary / np.max(im_ary)
25
+
26
+ tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
27
+ tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
28
+ tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]
29
+ tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]
30
+
31
+ tmpImg = tmpImg.transpose((2, 0, 1))
32
+
33
+ return {
34
+ self.inner_session.get_inputs()[0]
35
+ .name: np.expand_dims(tmpImg, 0)
36
+ .astype(np.float32)
37
+ }
38
+
39
+ def predict(self, img: PILImage) -> List[PILImage]:
40
+ raise NotImplementedError
rembg/session_cloth.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from PIL.Image import Image as PILImage
6
+ from scipy.special import log_softmax
7
+
8
+ from .session_base import BaseSession
9
+
10
+ pallete1 = [
11
+ 0,
12
+ 0,
13
+ 0,
14
+ 255,
15
+ 255,
16
+ 255,
17
+ 0,
18
+ 0,
19
+ 0,
20
+ 0,
21
+ 0,
22
+ 0,
23
+ ]
24
+
25
+ pallete2 = [
26
+ 0,
27
+ 0,
28
+ 0,
29
+ 0,
30
+ 0,
31
+ 0,
32
+ 255,
33
+ 255,
34
+ 255,
35
+ 0,
36
+ 0,
37
+ 0,
38
+ ]
39
+
40
+ pallete3 = [
41
+ 0,
42
+ 0,
43
+ 0,
44
+ 0,
45
+ 0,
46
+ 0,
47
+ 0,
48
+ 0,
49
+ 0,
50
+ 255,
51
+ 255,
52
+ 255,
53
+ ]
54
+
55
+
56
+ class ClothSession(BaseSession):
57
+ def predict(self, img: PILImage) -> List[PILImage]:
58
+ ort_outs = self.inner_session.run(
59
+ None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768))
60
+ )
61
+
62
+ pred = ort_outs
63
+ pred = log_softmax(pred[0], 1)
64
+ pred = np.argmax(pred, axis=1, keepdims=True)
65
+ pred = np.squeeze(pred, 0)
66
+ pred = np.squeeze(pred, 0)
67
+
68
+ mask = Image.fromarray(pred.astype("uint8"), mode="L")
69
+ mask = mask.resize(img.size, Image.LANCZOS)
70
+
71
+ masks = []
72
+
73
+ mask1 = mask.copy()
74
+ mask1.putpalette(pallete1)
75
+ mask1 = mask1.convert("RGB").convert("L")
76
+ masks.append(mask1)
77
+
78
+ mask2 = mask.copy()
79
+ mask2.putpalette(pallete2)
80
+ mask2 = mask2.convert("RGB").convert("L")
81
+ masks.append(mask2)
82
+
83
+ mask3 = mask.copy()
84
+ mask3.putpalette(pallete3)
85
+ mask3 = mask3.convert("RGB").convert("L")
86
+ masks.append(mask3)
87
+
88
+ return masks
rembg/session_factory.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Type
3
+
4
+ import onnxruntime as ort
5
+
6
+ from .sessions import sessions_class
7
+ from .sessions.base import BaseSession
8
+ from .sessions.u2net import U2netSession
9
+
10
+
11
+ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
12
+ session_class: Type[BaseSession] = U2netSession
13
+
14
+ for sc in sessions_class:
15
+ if sc.name() == model_name:
16
+ session_class = sc
17
+ break
18
+
19
+ sess_opts = ort.SessionOptions()
20
+
21
+ if "OMP_NUM_THREADS" in os.environ:
22
+ sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
23
+
24
+ return session_class(model_name, sess_opts, *args, **kwargs)
rembg/session_simple.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ from PIL import Image
5
+ from PIL.Image import Image as PILImage
6
+
7
+ from .session_base import BaseSession
8
+
9
+
10
+ class SimpleSession(BaseSession):
11
+ def predict(self, img: PILImage) -> List[PILImage]:
12
+ ort_outs = self.inner_session.run(
13
+ None,
14
+ self.normalize(
15
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
16
+ ),
17
+ )
18
+
19
+ pred = ort_outs[0][:, 0, :, :]
20
+
21
+ ma = np.max(pred)
22
+ mi = np.min(pred)
23
+
24
+ pred = (pred - mi) / (ma - mi)
25
+ pred = np.squeeze(pred)
26
+
27
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
+ mask = mask.resize(img.size, Image.LANCZOS)
29
+
30
+ return [mask]
rembg/sessions/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from importlib import import_module
2
+ from inspect import isclass
3
+ from pathlib import Path
4
+ from pkgutil import iter_modules
5
+
6
+ from .base import BaseSession
7
+
8
+ sessions_class = []
9
+ sessions_names = []
10
+
11
+ package_dir = Path(__file__).resolve().parent
12
+ for _b, module_name, _p in iter_modules([str(package_dir)]):
13
+ module = import_module(f"{__name__}.{module_name}")
14
+ for attribute_name in dir(module):
15
+ attribute = getattr(module, attribute_name)
16
+ if (
17
+ isclass(attribute)
18
+ and issubclass(attribute, BaseSession)
19
+ and attribute != BaseSession
20
+ ):
21
+ sessions_class.append(attribute)
22
+ sessions_names.append(attribute.name())
rembg/sessions/base.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Tuple
3
+
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+
10
+ class BaseSession:
11
+ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
12
+ self.model_name = model_name
13
+ self.inner_session = ort.InferenceSession(
14
+ str(self.__class__.download_models()),
15
+ providers=ort.get_available_providers(),
16
+ sess_options=sess_opts,
17
+ )
18
+
19
+ def normalize(
20
+ self,
21
+ img: PILImage,
22
+ mean: Tuple[float, float, float],
23
+ std: Tuple[float, float, float],
24
+ size: Tuple[int, int],
25
+ *args,
26
+ **kwargs
27
+ ) -> Dict[str, np.ndarray]:
28
+ im = img.convert("RGB").resize(size, Image.LANCZOS)
29
+
30
+ im_ary = np.array(im)
31
+ im_ary = im_ary / np.max(im_ary)
32
+
33
+ tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
34
+ tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
35
+ tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]
36
+ tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]
37
+
38
+ tmpImg = tmpImg.transpose((2, 0, 1))
39
+
40
+ return {
41
+ self.inner_session.get_inputs()[0]
42
+ .name: np.expand_dims(tmpImg, 0)
43
+ .astype(np.float32)
44
+ }
45
+
46
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
47
+ raise NotImplementedError
48
+
49
+ @classmethod
50
+ def u2net_home(cls, *args, **kwargs):
51
+ return os.path.expanduser(
52
+ os.getenv(
53
+ "U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
54
+ )
55
+ )
56
+
57
+ @classmethod
58
+ def download_models(cls, *args, **kwargs):
59
+ raise NotImplementedError
60
+
61
+ @classmethod
62
+ def name(cls, *args, **kwargs):
63
+ raise NotImplementedError
rembg/sessions/dis.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class DisSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
17
+ )
18
+
19
+ pred = ort_outs[0][:, 0, :, :]
20
+
21
+ ma = np.max(pred)
22
+ mi = np.min(pred)
23
+
24
+ pred = (pred - mi) / (ma - mi)
25
+ pred = np.squeeze(pred)
26
+
27
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
28
+ mask = mask.resize(img.size, Image.LANCZOS)
29
+
30
+ return [mask]
31
+
32
+ @classmethod
33
+ def download_models(cls, *args, **kwargs):
34
+ fname = f"{cls.name()}.onnx"
35
+ pooch.retrieve(
36
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
37
+ "md5:fc16ebd8b0c10d971d3513d564d01e29",
38
+ fname=fname,
39
+ path=cls.u2net_home(),
40
+ progressbar=True,
41
+ )
42
+
43
+ return os.path.join(cls.u2net_home(), fname)
44
+
45
+ @classmethod
46
+ def name(cls, *args, **kwargs):
47
+ return "isnet-general-use"
rembg/sessions/sam.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ import pooch
7
+ from PIL import Image
8
+ from PIL.Image import Image as PILImage
9
+
10
+ from .base import BaseSession
11
+
12
+
13
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
14
+ scale = long_side_length * 1.0 / max(oldh, oldw)
15
+ newh, neww = oldh * scale, oldw * scale
16
+ neww = int(neww + 0.5)
17
+ newh = int(newh + 0.5)
18
+ return (newh, neww)
19
+
20
+
21
+ def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
22
+ old_h, old_w = original_size
23
+ new_h, new_w = get_preprocess_shape(
24
+ original_size[0], original_size[1], target_length
25
+ )
26
+ coords = coords.copy().astype(float)
27
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
28
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
29
+ return coords
30
+
31
+
32
+ def resize_longes_side(img: PILImage, size=1024):
33
+ w, h = img.size
34
+ if h > w:
35
+ new_h, new_w = size, int(w * size / h)
36
+ else:
37
+ new_h, new_w = int(h * size / w), size
38
+
39
+ return img.resize((new_w, new_h))
40
+
41
+
42
+ def pad_to_square(img: np.ndarray, size=1024):
43
+ h, w = img.shape[:2]
44
+ padh = size - h
45
+ padw = size - w
46
+ img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
47
+ img = img.astype(np.float32)
48
+ return img
49
+
50
+
51
+ class SamSession(BaseSession):
52
+ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
53
+ self.model_name = model_name
54
+ paths = self.__class__.download_models()
55
+ self.encoder = ort.InferenceSession(
56
+ str(paths[0]),
57
+ providers=ort.get_available_providers(),
58
+ sess_options=sess_opts,
59
+ )
60
+ self.decoder = ort.InferenceSession(
61
+ str(paths[1]),
62
+ providers=ort.get_available_providers(),
63
+ sess_options=sess_opts,
64
+ )
65
+
66
+ def normalize(
67
+ self,
68
+ img: np.ndarray,
69
+ mean=(123.675, 116.28, 103.53),
70
+ std=(58.395, 57.12, 57.375),
71
+ size=(1024, 1024),
72
+ *args,
73
+ **kwargs,
74
+ ):
75
+ pixel_mean = np.array([*mean]).reshape(1, 1, -1)
76
+ pixel_std = np.array([*std]).reshape(1, 1, -1)
77
+ x = (img - pixel_mean) / pixel_std
78
+ return x
79
+
80
+ def predict(
81
+ self,
82
+ img: PILImage,
83
+ *args,
84
+ **kwargs,
85
+ ) -> List[PILImage]:
86
+ # Preprocess image
87
+ image = resize_longes_side(img)
88
+ image = np.array(image)
89
+ image = self.normalize(image)
90
+ image = pad_to_square(image)
91
+
92
+ input_labels = kwargs.get("input_labels")
93
+ input_points = kwargs.get("input_points")
94
+
95
+ if input_labels is None:
96
+ raise ValueError("input_labels is required")
97
+ if input_points is None:
98
+ raise ValueError("input_points is required")
99
+
100
+ # Transpose
101
+ image = image.transpose(2, 0, 1)[None, :, :, :]
102
+ # Run encoder (Image embedding)
103
+ encoded = self.encoder.run(None, {"x": image})
104
+ image_embedding = encoded[0]
105
+
106
+ # Add a batch index, concatenate a padding point, and transform.
107
+ onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
108
+ None, :, :
109
+ ]
110
+ onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
111
+ None, :
112
+ ].astype(np.float32)
113
+ onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
114
+
115
+ # Create an empty mask input and an indicator for no mask.
116
+ onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
117
+ onnx_has_mask_input = np.zeros(1, dtype=np.float32)
118
+
119
+ decoder_inputs = {
120
+ "image_embeddings": image_embedding,
121
+ "point_coords": onnx_coord,
122
+ "point_labels": onnx_label,
123
+ "mask_input": onnx_mask_input,
124
+ "has_mask_input": onnx_has_mask_input,
125
+ "orig_im_size": np.array(img.size[::-1], dtype=np.float32),
126
+ }
127
+
128
+ masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
129
+ masks = masks > 0.0
130
+ masks = [
131
+ Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
132
+ for i in range(masks.shape[0])
133
+ ]
134
+
135
+ return masks
136
+
137
+ @classmethod
138
+ def download_models(cls, *args, **kwargs):
139
+ fname_encoder = f"{cls.name()}_encoder.onnx"
140
+ fname_decoder = f"{cls.name()}_decoder.onnx"
141
+
142
+ pooch.retrieve(
143
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
144
+ "md5:13d97c5c79ab13ef86d67cbde5f1b250",
145
+ fname=fname_encoder,
146
+ path=cls.u2net_home(),
147
+ progressbar=True,
148
+ )
149
+
150
+ pooch.retrieve(
151
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
152
+ "md5:fa3d1c36a3187d3de1c8deebf33dd127",
153
+ fname=fname_decoder,
154
+ path=cls.u2net_home(),
155
+ progressbar=True,
156
+ )
157
+
158
+ return (
159
+ os.path.join(cls.u2net_home(), fname_encoder),
160
+ os.path.join(cls.u2net_home(), fname_decoder),
161
+ )
162
+
163
+ @classmethod
164
+ def name(cls, *args, **kwargs):
165
+ return "sam"
rembg/sessions/silueta.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class SiluetaSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
39
+ "md5:55e59e0d8062d2f5d013f4725ee84782",
40
+ fname=fname,
41
+ path=cls.u2net_home(),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "silueta"
rembg/sessions/u2net.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class U2netSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
39
+ "md5:60024c5c889badc19c04ad937298a77b",
40
+ fname=fname,
41
+ path=cls.u2net_home(),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "u2net"
rembg/sessions/u2net_cloth_seg.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+ from scipy.special import log_softmax
9
+
10
+ from .base import BaseSession
11
+
12
+ pallete1 = [
13
+ 0,
14
+ 0,
15
+ 0,
16
+ 255,
17
+ 255,
18
+ 255,
19
+ 0,
20
+ 0,
21
+ 0,
22
+ 0,
23
+ 0,
24
+ 0,
25
+ ]
26
+
27
+ pallete2 = [
28
+ 0,
29
+ 0,
30
+ 0,
31
+ 0,
32
+ 0,
33
+ 0,
34
+ 255,
35
+ 255,
36
+ 255,
37
+ 0,
38
+ 0,
39
+ 0,
40
+ ]
41
+
42
+ pallete3 = [
43
+ 0,
44
+ 0,
45
+ 0,
46
+ 0,
47
+ 0,
48
+ 0,
49
+ 0,
50
+ 0,
51
+ 0,
52
+ 255,
53
+ 255,
54
+ 255,
55
+ ]
56
+
57
+
58
+ class Unet2ClothSession(BaseSession):
59
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
60
+ ort_outs = self.inner_session.run(
61
+ None,
62
+ self.normalize(
63
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768)
64
+ ),
65
+ )
66
+
67
+ pred = ort_outs
68
+ pred = log_softmax(pred[0], 1)
69
+ pred = np.argmax(pred, axis=1, keepdims=True)
70
+ pred = np.squeeze(pred, 0)
71
+ pred = np.squeeze(pred, 0)
72
+
73
+ mask = Image.fromarray(pred.astype("uint8"), mode="L")
74
+ mask = mask.resize(img.size, Image.LANCZOS)
75
+
76
+ masks = []
77
+
78
+ mask1 = mask.copy()
79
+ mask1.putpalette(pallete1)
80
+ mask1 = mask1.convert("RGB").convert("L")
81
+ masks.append(mask1)
82
+
83
+ mask2 = mask.copy()
84
+ mask2.putpalette(pallete2)
85
+ mask2 = mask2.convert("RGB").convert("L")
86
+ masks.append(mask2)
87
+
88
+ mask3 = mask.copy()
89
+ mask3.putpalette(pallete3)
90
+ mask3 = mask3.convert("RGB").convert("L")
91
+ masks.append(mask3)
92
+
93
+ return masks
94
+
95
+ @classmethod
96
+ def download_models(cls, *args, **kwargs):
97
+ fname = f"{cls.name()}.onnx"
98
+ pooch.retrieve(
99
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
100
+ "md5:2434d1f3cb744e0e49386c906e5a08bb",
101
+ fname=fname,
102
+ path=cls.u2net_home(),
103
+ progressbar=True,
104
+ )
105
+
106
+ return os.path.join(cls.u2net_home(), fname)
107
+
108
+ @classmethod
109
+ def name(cls, *args, **kwargs):
110
+ return "u2net_cloth_seg"
rembg/sessions/u2net_human_seg.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class U2netHumanSegSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
39
+ "md5:c09ddc2e0104f800e3e1bb4652583d1f",
40
+ fname=fname,
41
+ path=cls.u2net_home(),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "u2net_human_seg"
rembg/sessions/u2netp.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import pooch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PILImage
8
+
9
+ from .base import BaseSession
10
+
11
+
12
+ class U2netpSession(BaseSession):
13
+ def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
14
+ ort_outs = self.inner_session.run(
15
+ None,
16
+ self.normalize(
17
+ img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
18
+ ),
19
+ )
20
+
21
+ pred = ort_outs[0][:, 0, :, :]
22
+
23
+ ma = np.max(pred)
24
+ mi = np.min(pred)
25
+
26
+ pred = (pred - mi) / (ma - mi)
27
+ pred = np.squeeze(pred)
28
+
29
+ mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
30
+ mask = mask.resize(img.size, Image.LANCZOS)
31
+
32
+ return [mask]
33
+
34
+ @classmethod
35
+ def download_models(cls, *args, **kwargs):
36
+ fname = f"{cls.name()}.onnx"
37
+ pooch.retrieve(
38
+ "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
39
+ "md5:8e83ca70e441ab06c318d82300c84806",
40
+ fname=fname,
41
+ path=cls.u2net_home(),
42
+ progressbar=True,
43
+ )
44
+
45
+ return os.path.join(cls.u2net_home(), fname)
46
+
47
+ @classmethod
48
+ def name(cls, *args, **kwargs):
49
+ return "u2netp"