|
|
from typing import Dict, Any |
|
|
import os |
|
|
import requests |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from transformers import AutoModelForImageSegmentation |
|
|
|
|
|
|
|
|
# Float32 matmul precision: "high" favours speed; use "highest" for full precision.
torch.set_float32_matmul_precision("high")

# Run on CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
class EndpointHandler:
    """Background-removal handler built on the BiRefNet segmentation model.

    The model (``zhengpeng7/BiRefNet``) is loaded once at construction. Each
    call predicts a foreground mask for the input image and returns that image
    with the mask attached as its alpha channel.
    """

    # Seconds to wait when downloading an image URL before giving up.
    download_timeout = 30

    def __init__(self, path=''):
        """Load the model onto ``device`` and build the preprocessing pipeline.

        Args:
            path: Accepted but unused; weights are always fetched from the
                ``zhengpeng7/BiRefNet`` repository. (Presumably kept to match
                the Hugging Face Inference Endpoints custom-handler signature
                — confirm.)
        """
        # fp16 only on CUDA: many CPU kernels (e.g. convolutions) have no Half
        # implementation, so CPU inference stays in fp32.
        use_half = torch.device(device).type == "cuda"
        self.dtype = torch.float16 if use_half else torch.float32

        self.model = AutoModelForImageSegmentation.from_pretrained(
            'zhengpeng7/BiRefNet',
            trust_remote_code=True,
        )
        self.model.to(device)
        self.model.eval()
        if use_half:
            self.model.half()

        # Built once here rather than on every request.
        self.transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean / std normalization.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _load_image(self, image_src):
        """Decode the request payload into a PIL image.

        Accepts a PIL image (returned as-is), an ``http://``/``https://`` URL,
        a local file path, or anything else ``BytesIO`` accepts (encoded
        image bytes).

        Raises:
            requests.HTTPError: if downloading a URL returns an error status.
        """
        if isinstance(image_src, Image.Image):
            return image_src
        if isinstance(image_src, str):
            # Match the scheme exactly so a local path such as
            # "http_images/cat.png" is not mistaken for a URL.
            if image_src.startswith(("http://", "https://")):
                response = requests.get(image_src, timeout=self.download_timeout)
                # Fail with the HTTP error instead of letting PIL choke on
                # the body of an error page.
                response.raise_for_status()
                return Image.open(BytesIO(response.content))
            return Image.open(image_src)
        return Image.open(BytesIO(image_src))

    def __call__(self, data: Dict[str, Any]) -> "Image.Image":
        """Remove the background from the image in ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` holds the image as a PIL
                image, an http(s) URL, a local file path, or encoded bytes.

        Returns:
            An RGBA PIL image at the input's original size whose alpha channel
            is the predicted foreground mask.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` key.
            requests.HTTPError: if downloading an image URL fails.
        """
        # convert() returns a new image, so putalpha() below never mutates a
        # PIL image supplied by the caller.
        image = self._load_image(data["inputs"]).convert("RGB")
        orig_size = image.size

        input_tensor = self.transform(image).unsqueeze(0).to(
            device=device, dtype=self.dtype
        )

        with torch.no_grad():
            # Use the last element of the model output (presumably the final
            # refinement stage's logits — confirm against BiRefNet) and map it
            # to per-pixel foreground probabilities.
            preds = self.model(input_tensor)[-1].sigmoid()

        # Cast to fp32 on the CPU so ToPILImage's 0-255 scaling is done in full
        # precision regardless of the inference dtype.
        pred = preds.float().cpu()[0].squeeze()

        mask_pil = transforms.ToPILImage()(pred)
        # The model ran at 1024x1024; bring the mask back to the input size.
        mask_pil = mask_pil.resize(orig_size, resample=Image.Resampling.LANCZOS)

        image.putalpha(mask_pil)
        return image