|
|
from typing import Dict, Any, Tuple |
|
|
import os |
|
|
import requests |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from transformers import AutoModelForImageSegmentation |
|
|
|
|
|
|
|
|
# Use the "high" float32 matmul precision setting ("highest" is the other
# value the original toggle listed).
torch.set_float32_matmul_precision("high")

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
# Maps a usage preset name to the weights name that is appended to the
# "zhengpeng7/" prefix when the model is loaded.
usage_to_weights_file = {
    "General": "BiRefNet",
    "General-Lite": "BiRefNet_lite",
    "General-Lite-2K": "BiRefNet_lite-2K",
    "General-reso_512": "BiRefNet-reso_512",
    "General-HR": "BiRefNet_HR",
}
|
|
|
|
|
# Inference configuration read by the handler.
usage = "General"  # key looked up in usage_to_weights_file
resolution = (1024, 1024)  # target size passed to transforms.Resize
half_precision = True  # when true, model and input batch are cast with .half()
|
|
|
|
|
class ImagePreprocessor:
    """Convert a PIL image into a resized, normalized tensor.

    The pipeline resizes the image to ``resolution``, converts it to a
    tensor, and normalizes it with fixed per-channel mean and std values.
    """

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        """Build the transform pipeline for the given target ``resolution``."""
        resize = transforms.Resize(resolution)
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean
            [0.229, 0.224, 0.225],  # per-channel std
        )
        self.transform_image = transforms.Compose([resize, to_tensor, normalize])

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Run ``image`` through the transform pipeline and return the result."""
        return self.transform_image(image)
|
|
|
|
|
class EndpointHandler:
    """Callable handler that cuts out an image's foreground with BiRefNet.

    The model is loaded once at construction. Each call resolves the request
    input to an RGB image, predicts a mask, and returns the image with that
    mask installed as its alpha channel.
    """

    # Seconds to wait on a remote image download. Without a timeout
    # ``requests.get`` may block forever on an unresponsive server.
    _REQUEST_TIMEOUT = 30

    def __init__(self, path=''):
        """Load the weights selected by the module-level ``usage`` setting.

        Args:
            path: Accepted for interface compatibility; not used here.
        """
        # NOTE(security): trust_remote_code=True runs Python code shipped in
        # the model repository; only point this at a repository you trust.
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            '/'.join(('zhengpeng7', usage_to_weights_file[usage])),
            trust_remote_code=True,
        )
        self.birefnet.to(device)
        self.birefnet.eval()
        if half_precision:
            self.birefnet.half()

    def __call__(self, data: Dict[str, Any]):
        """Predict a mask for ``data["inputs"]`` and return the cut-out image.

        Args:
            data: Request payload. ``data["inputs"]`` may be a PIL image (or
                any object exposing ``convert``), a string holding a local
                file path or a URL, encoded image bytes, or an array accepted
                by ``Image.fromarray``.

        Returns:
            The input as a PIL image at its original size, with the predicted
            mask set as the alpha channel.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
            requests.RequestException: If a URL input cannot be downloaded
                (including timeouts and non-success HTTP status codes).
            ValueError: If the input cannot be interpreted as an image.
        """
        image = self._load_image(data["inputs"])

        image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
        # Add a leading batch dimension before feeding the model.
        batch = image_preprocessor.proc(image).unsqueeze(0)

        with torch.no_grad():
            batch = batch.to(device)
            if half_precision:
                # Match the dtype the model was cast to in __init__.
                batch = batch.half()
            # The last element of the model output is used as the prediction.
            preds = self.birefnet(batch)[-1].sigmoid().cpu()

        pred = preds[0].squeeze()

        # Turn the prediction into a PIL mask and scale it back to the size
        # of the input image so it can serve as the alpha channel.
        mask_pil = transforms.ToPILImage()(pred)
        mask_pil = mask_pil.resize(image.size, resample=Image.Resampling.LANCZOS)

        image.putalpha(mask_pil)
        return image

    def _load_image(self, image_src: Any) -> Image.Image:
        """Resolve a request input into a new RGB PIL image."""
        # PIL images, and anything else that offers ``convert``, are used
        # directly.
        if hasattr(image_src, 'convert'):
            return image_src.convert('RGB')

        if isinstance(image_src, str):
            if os.path.isfile(image_src):
                # NOTE(security): this reads any local path named by the
                # request; restrict it if inputs come from untrusted callers.
                # The ``with`` block closes the file once the pixels have
                # been copied by ``convert``.
                with Image.open(image_src) as opened:
                    return opened.convert('RGB')
            # Anything that is not an existing file is treated as a URL.
            response = requests.get(image_src, timeout=self._REQUEST_TIMEOUT)
            # Fail on HTTP errors instead of trying to decode an error page.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')

        # Remaining inputs: try encoded image bytes first, then an array.
        try:
            image_ori = Image.open(BytesIO(image_src))
        except Exception:
            try:
                image_ori = Image.fromarray(image_src)
            except Exception as err:
                # Previously the raw input fell through and failed later with
                # an unrelated AttributeError on ``.convert``.
                raise ValueError(
                    "Could not interpret 'inputs' of type "
                    f"{type(image_src).__name__} as an image"
                ) from err
        return image_ori.convert('RGB')