| import os |
| import sys |
| import tqdm |
| import wget |
| import gdown |
| import torch |
| import shutil |
| import base64 |
| import warnings |
| import importlib |
|
|
| import numpy as np |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
| import albumentations as A |
| import albumentations.pytorch as AP |
|
|
| from PIL import Image |
| from io import BytesIO |
| from packaging import version |
|
|
| filepath = os.path.abspath(__file__) |
| repopath = os.path.split(filepath)[0] |
| sys.path.append(repopath) |
|
|
| from transparent_background.InSPyReNet import InSPyReNet_SwinB |
| from transparent_background.utils import * |
|
|
class Remover:
    """Salient-object background remover built on InSPyReNet (SwinB backbone).

    Loads (and if necessary downloads) a checkpoint, prepares the model on the
    best available device, and exposes :meth:`process` to produce RGBA cutouts,
    recolored backgrounds, blurs, overlays, etc.
    """

    def __init__(self, mode="base", jit=False, device=None, ckpt=None, resize='static'):
        """
        Args:
            mode (str): Choose among below options
                        base -> slow & large gpu memory required, high quality results
                        fast -> resize input into small size for fast computation
                        base-nightly -> nightly release for base mode
            jit (bool): use TorchScript for fast computation
            device (str, optional): specifying device for computation. find available GPU resource if not specified.
            ckpt (str, optional): specifying model checkpoint. find downloaded checkpoint or try download if not specified.
            resize (str): 'static' resizes to the fixed base size; 'dynamic' picks a
                          size from the input (base/base-nightly modes only).
        """
        # Per-user config directory; seed it with the packaged default config on first run.
        cfg_path = os.environ.get('TRANSPARENT_BACKGROUND_FILE_PATH', os.path.abspath(os.path.expanduser('~')))
        home_dir = os.path.join(cfg_path, ".transparent-background")
        os.makedirs(home_dir, exist_ok=True)

        if not os.path.isfile(os.path.join(home_dir, "config.yaml")):
            shutil.copy(os.path.join(repopath, "config.yaml"), os.path.join(home_dir, "config.yaml"))
        self.meta = load_config(os.path.join(home_dir, "config.yaml"))[mode]

        # Device selection: explicit argument wins, otherwise prefer CUDA, then
        # MPS (which requires torch >= 1.13), then CPU.
        if device is not None:
            self.device = device
        else:
            self.device = "cpu"
            if torch.cuda.is_available():
                self.device = "cuda:0"
            elif (
                version.parse(torch.__version__) >= version.parse("1.13")
                and torch.backends.mps.is_available()
            ):
                self.device = "mps:0"

        # Decide whether the checkpoint must be (re-)downloaded: it is missing,
        # or its md5 disagrees with the one declared in the config (when any).
        download = False
        if ckpt is None:
            ckpt_dir = home_dir
            ckpt_name = self.meta.ckpt_name
            ckpt_file = os.path.join(ckpt_dir, ckpt_name)

            if not os.path.isfile(ckpt_file):
                download = True
            elif self.meta.md5 is not None:
                # Close the file handle after hashing (the original leaked it).
                with open(ckpt_file, "rb") as f:
                    local_md5 = hashlib.md5(f.read()).hexdigest()
                if self.meta.md5 != local_md5:
                    download = True

            if download:
                if 'drive.google.com' in self.meta.url:
                    gdown.download(self.meta.url, os.path.join(ckpt_dir, ckpt_name), fuzzy=True, proxy=self.meta.http_proxy)
                elif 'github.com' in self.meta.url:
                    wget.download(self.meta.url, os.path.join(ckpt_dir, ckpt_name))
                else:
                    raise NotImplementedError('Please use valid URL')
        else:
            ckpt_dir, ckpt_name = os.path.split(os.path.abspath(ckpt))

        # Build the network, load weights on CPU first, then move to the target device.
        self.model = InSPyReNet_SwinB(depth=64, pretrained=False, threshold=None, **self.meta)
        self.model.eval()
        self.model.load_state_dict(
            torch.load(os.path.join(ckpt_dir, ckpt_name), map_location="cpu", weights_only=True),
            strict=True,
        )
        self.model = self.model.to(self.device)

        if jit:
            # Reuse a cached TorchScript module for this device if present;
            # otherwise trace the model once and cache the result next to the weights.
            ckpt_name = self.meta.ckpt_name.replace(
                ".pth", "_{}.pt".format(self.device)
            )
            try:
                traced_model = torch.jit.load(
                    os.path.join(ckpt_dir, ckpt_name), map_location=self.device
                )
                del self.model
                self.model = traced_model
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
                # still propagate. Any load failure falls back to re-tracing.
                traced_model = torch.jit.trace(
                    self.model,
                    torch.rand(1, 3, *self.meta.base_size).to(self.device),
                    strict=True,
                )
                del self.model
                self.model = traced_model
                torch.jit.save(self.model, os.path.join(ckpt_dir, ckpt_name))
            if resize != 'static':
                # Traced graphs bake in the input size, so dynamic resize cannot work.
                warnings.warn('Resizing method for TorchScript mode only supports static resize. Fallback to static.')
                resize = 'static'

        # Two parallel preprocessing pipelines: torchvision-style for PIL inputs,
        # albumentations for numpy (cv2) inputs.
        resize_tf = None
        resize_fn = None
        if resize == 'static':
            resize_tf = static_resize(self.meta.base_size)
            resize_fn = A.Resize(*self.meta.base_size)
        elif resize == 'dynamic':
            if 'base' not in mode:
                warnings.warn('Dynamic resizing only supports base and base-nightly mode. It will cause severe performance degradation with fast mode.')
            resize_tf = dynamic_resize(L=1280)
            resize_fn = dynamic_resize_a(L=1280)
        else:
            raise AttributeError(f'Unsupported resizing method {resize}')

        self.transform = transforms.Compose(
            [
                resize_tf,
                tonumpy(),
                normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                totensor(),
            ]
        )

        self.cv2_transform = A.Compose(
            [
                resize_fn,
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                AP.ToTensorV2(),
            ]
        )

        # Cache for background-image compositing so repeated calls with the same
        # background file and input shape skip the imread/resize.
        self.background = {'img': None, 'name': None, 'shape': None}
        desc = "Mode={}, Device={}, Torchscript={}".format(
            mode, self.device, "enabled" if jit else "disabled"
        )
        print("Settings -> {}".format(desc))

    def process(self, img, type="rgba", threshold=None, reverse=False):
        """
        Args:
            img (PIL.Image or np.ndarray): input image as PIL.Image or np.ndarray type
            type (str): output type option as below.
                        'rgba' will generate RGBA output regarding saliency score as an alpha map.
                        'map' will return the saliency map itself as a 3-channel image.
                        'green' will change the background with green screen.
                        'white' will change the background with white color.
                        '[255, 0, 0]' will change the background with color code [255, 0, 0].
                        'blur' will blur the background.
                        'overlay' will cover the salient object with translucent green color, and highlight the edges.
                        Another image file (e.g., 'samples/backgroud.png') will be used as a background, and the object will be overlapped on it.
            threshold (float or str, optional): produce hard prediction w.r.t specified threshold value (0.0 ~ 1.0)
            reverse (bool): if True, invert the mask so the background is kept instead of the object.
        Returns:
            PIL.Image or np.ndarray: output image; matches the input container type.
        """
        # Route through the matching preprocessing pipeline and remember the
        # original spatial size for upsampling the prediction back.
        if isinstance(img, np.ndarray):
            is_numpy = True
            shape = img.shape[:2]
            x = self.cv2_transform(image=img)["image"]
        else:
            is_numpy = False
            shape = img.size[::-1]
            x = self.transform(img)

        x = x.unsqueeze(0)
        x = x.to(self.device)

        with torch.no_grad():
            pred = self.model(x)

        # Upsample the saliency map to the input resolution; values in [0, 1].
        pred = F.interpolate(pred, shape, mode="bilinear", align_corners=True)
        pred = pred.data.cpu()
        pred = pred.numpy().squeeze()

        if threshold is not None:
            pred = (pred > float(threshold)).astype(np.float64)
        if reverse:
            pred = 1 - pred

        img = np.array(img)

        # A '[r, g, b]' string becomes a list and is handled by the
        # len(type) == 3 color branch below.
        if type.startswith("["):
            type = [int(i) for i in type[1:-1].split(",")]

        if type == "map":
            img = (np.stack([pred] * 3, axis=-1) * 255).astype(np.uint8)

        elif type == "rgba":
            if threshold is None:
                # Soft alpha: refine the foreground colors with pymatting,
                # preferring GPU/OpenCL implementations when installed.
                try:
                    from pymatting.foreground.estimate_foreground_ml_cupy import estimate_foreground_ml_cupy as estimate_foreground_ml
                except ImportError:
                    try:
                        from pymatting.foreground.estimate_foreground_ml_pyopencl import estimate_foreground_ml_pyopencl as estimate_foreground_ml
                    except ImportError:
                        from pymatting import estimate_foreground_ml
                img = estimate_foreground_ml(img / 255.0, pred)
                img = 255 * np.clip(img, 0., 1.) + 0.5  # +0.5 rounds on truncation
                img = img.astype(np.uint8)

            r, g, b = cv2.split(img)
            pred = (pred * 255).astype(np.uint8)
            img = cv2.merge([r, g, b, pred])

        elif type == "green":
            bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155]
            img = img * pred[..., np.newaxis] + bg * (1 - pred[..., np.newaxis])

        elif type == "white":
            bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * [255, 255, 255]
            img = img * pred[..., np.newaxis] + bg * (1 - pred[..., np.newaxis])

        elif len(type) == 3:
            # Custom background color parsed from a '[r, g, b]' string.
            # (Removed a leftover debug print(type) that polluted stdout.)
            bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * type
            img = img * pred[..., np.newaxis] + bg * (1 - pred[..., np.newaxis])

        elif type == "blur":
            img = img * pred[..., np.newaxis] + cv2.GaussianBlur(img, (0, 0), 15) * (
                1 - pred[..., np.newaxis]
            )

        elif type == "overlay":
            # Blend a translucent green over the object and outline its edges.
            bg = (
                np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155] + img
            ) // 2
            img = bg * pred[..., np.newaxis] + img * (1 - pred[..., np.newaxis])
            border = cv2.Canny(((pred > 0.5) * 255).astype(np.uint8), 50, 100)
            img[border != 0] = [120, 255, 155]

        elif type.lower().endswith(IMG_EXTS):
            # Background-image compositing with a small cache keyed on file name
            # and input shape (see self.background).
            if self.background['name'] != type:
                background_img = cv2.cvtColor(cv2.imread(type), cv2.COLOR_BGR2RGB)
                background_img = cv2.resize(background_img, img.shape[:2][::-1])

                self.background['img'] = background_img
                self.background['shape'] = img.shape[:2][::-1]
                self.background['name'] = type

            elif self.background['shape'] != img.shape[:2][::-1]:
                self.background['img'] = cv2.resize(self.background['img'], img.shape[:2][::-1])
                self.background['shape'] = img.shape[:2][::-1]

            img = img * pred[..., np.newaxis] + self.background['img'] * (
                1 - pred[..., np.newaxis]
            )

        if is_numpy:
            return img.astype(np.uint8)
        else:
            return Image.fromarray(img.astype(np.uint8))
|
|
def to_base64(image):
    """Serialize *image* (a PIL.Image-like object) to JPEG and return it as a base64 string."""
    jpeg_buffer = BytesIO()
    image.save(jpeg_buffer, format="JPEG")
    return base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")
|
|
def entry_point(out_type, mode, device, ckpt, source, dest, jit, threshold, resize, save_format=None, reverse=False, flet_progress=None, flet_page=None, preview=None, preview_out=None, options=None):
    """Batch driver used by both the CLI (`console`) and the flet GUI.

    Args:
        out_type (str): output type forwarded to Remover.process ('rgba', 'green', ...).
        mode, device, ckpt, jit, resize: forwarded to Remover().
        source (str): numeric string -> webcam index; otherwise a file or directory path.
        dest (str, optional): output directory; overrides the per-source default.
        threshold (float, optional): hard-threshold for the saliency map.
        save_format (str, optional): force the output file extension.
        reverse (bool): invert foreground/background in the mask.
        flet_progress, flet_page, preview, preview_out, options: optional flet GUI
            hooks; all default to None and are skipped when absent.
    """
    warnings.filterwarnings("ignore")

    remover = Remover(mode=mode, jit=jit, device=device, ckpt=ckpt, resize=resize)

    if source.isnumeric():
        # Webcam input: stream frames, preferably into a virtual camera.
        save_dir = None
        _format = "Webcam"
        if importlib.util.find_spec('pyvirtualcam') is not None:
            try:
                import pyvirtualcam
                vcam = pyvirtualcam.Camera(width=640, height=480, fps=30)
            except Exception:
                # No virtual-cam backend available; fall back to cv2.imshow below.
                vcam = None
        else:
            raise ImportError("pyvirtualcam not found. Install with \"pip install transparent-background[webcam]\"")

    elif os.path.isdir(source):
        save_dir = os.path.join(os.getcwd(), source.split(os.sep)[-1])
        _format = get_format(os.listdir(source))

    elif os.path.isfile(source):
        save_dir = os.getcwd()
        _format = get_format([source])

    else:
        raise FileNotFoundError("File or directory {} is invalid.".format(source))

    if out_type == "rgba" and _format == "Video":
        raise AttributeError("type 'rgba' cannot be applied to video input.")

    if dest is not None:
        save_dir = dest

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    # Resolve ImageLoader / VideoLoader / WebcamLoader by name from this module's
    # namespace (equivalent to the previous eval(), without evaluating code).
    loader = globals()[_format + "Loader"](source)
    frame_progress = tqdm.tqdm(
        total=len(loader),
        position=1 if (_format == "Video" and len(loader) > 1) else 0,
        leave=False,
        bar_format="{desc:<15}{percentage:3.0f}%|{bar:50}{r_bar}",
    )
    # A second, outer bar counts videos when processing a batch of them.
    sample_progress = (
        tqdm.tqdm(
            total=len(loader),
            desc="Total:",
            position=0,
            bar_format="{desc:<15}{percentage:3.0f}%|{bar:50}{r_bar}",
        )
        if (_format == "Video" and len(loader) > 1)
        else None
    )
    if flet_progress is not None:
        assert flet_page is not None
        flet_progress.value = 0
        flet_step = 1 / frame_progress.total

    writer = None

    for img, name in loader:
        filename, ext = os.path.splitext(name)
        ext = ext[1:]
        ext = save_format if save_format is not None else ext
        frame_progress.set_description("{}".format(name))
        if out_type.lower().endswith(IMG_EXTS):
            # Background-image mode: suffix the output with the background's stem.
            outname = "{}_{}".format(
                filename,
                os.path.splitext(os.path.split(out_type)[-1])[0],
            )
        else:
            outname = "{}_{}".format(filename, out_type)

        if reverse:
            outname += '_reverse'

        if _format == "Video" and writer is None:
            # Open the writer lazily, once the first frame reveals the size,
            # and re-sync the frame bar to this video's frame count.
            writer = cv2.VideoWriter(
                os.path.join(save_dir, f"{outname}.{ext}"),
                cv2.VideoWriter_fourcc(*"mp4v"),
                loader.fps,
                img.size,
            )
            writer.set(cv2.VIDEOWRITER_PROP_QUALITY, 100)
            frame_progress.refresh()
            frame_progress.reset()
            frame_progress.total = int(loader.cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if sample_progress is not None:
                sample_progress.update()

            if flet_progress is not None:
                assert flet_page is not None
                flet_progress.value = 0
                flet_step = 1 / frame_progress.total
                flet_progress.update()

        if _format == "Video" and img is None:
            # A None frame marks the end of one video in the batch.
            if writer is not None:
                writer.release()
            writer = None
            continue

        out = remover.process(img, type=out_type, threshold=threshold, reverse=reverse)

        if _format == "Image":
            if out_type == "rgba" and ext.lower() != 'png':
                warnings.warn('Output format for rgba mode only supports png format. Fallback to png output.')
                ext = 'png'
            out.save(os.path.join(save_dir, f"{outname}.{ext}"))
        elif _format == "Video" and writer is not None:
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_BGR2RGB))
        elif _format == "Webcam":
            if vcam is not None:
                vcam.send(np.array(out))
                vcam.sleep_until_next_frame()
            else:
                cv2.imshow(
                    "transparent-background", cv2.cvtColor(np.array(out), cv2.COLOR_BGR2RGB)
                )
        frame_progress.update()
        if flet_progress is not None:
            flet_progress.value += flet_step
            flet_progress.update()

        # GUI preview thumbnails. Guarded: preview/preview_out default to None
        # (the unguarded original crashed with AttributeError from the CLI path).
        if preview is not None and preview_out is not None:
            if out_type == 'rgba':
                # Pre-multiply alpha so the RGBA result renders sensibly on RGB.
                o = np.array(out).astype(np.float64)
                o[:, :, :3] *= (o[:, :, -1:] / 255)
                out = Image.fromarray(o[:, :, :3].astype(np.uint8))

            preview.src_base64 = to_base64(img.resize((480, 300)).convert('RGB'))
            preview_out.src_base64 = to_base64(out.resize((480, 300)).convert('RGB'))
            preview.update()
            preview_out.update()

        if options is not None and options['abort']:
            break

    # save_dir is None for webcam streaming (nothing written to disk).
    if save_dir is not None:
        print("\nDone. Results are saved in {}".format(os.path.abspath(save_dir)))
    else:
        print("\nDone.")
|
|
def console():
    """CLI entry point: parse command-line arguments and run the batch driver."""
    args = parse_args()
    entry_point(
        args.type,
        args.mode,
        args.device,
        args.ckpt,
        args.source,
        args.dest,
        args.jit,
        args.threshold,
        args.resize,
        args.format,
        args.reverse,
    )
|
|