"""Endpoint handler that predicts an alpha matte with ViTMatte.

The request payload carries a base64-encoded image and a base64-encoded
trimap; the response carries the predicted alpha matte as a PIL image.
"""
import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import VitMatteImageProcessor, VitMatteForImageMatting

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class EndpointHandler():
    """Callable wrapper around the ViTMatte processor and matting model."""

    def __init__(self, path=""):
        """Load the processor and model and move the model to ``device``.

        Args:
            path: Accepted for interface compatibility; it is not used. The
                weights are always fetched from the
                ``hustvl/vitmatte-small-composition-1k`` checkpoint.
        """
        self.processor = VitMatteImageProcessor.from_pretrained(
            "hustvl/vitmatte-small-composition-1k")
        self.model = VitMatteForImageMatting.from_pretrained(
            "hustvl/vitmatte-small-composition-1k")
        self.model = self.model.to(device)

    @staticmethod
    def _decode_image(encoded, mode):
        """Decode a base64 payload into a PIL image converted to ``mode``."""
        return Image.open(BytesIO(base64.b64decode(encoded))).convert(mode)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Predict the alpha matte for one image/trimap pair.

        Args:
            data: Request payload. The image and trimap are read from
                ``data["inputs"]`` when that key exists (it is popped from
                ``data``), otherwise from ``data`` itself. Both ``"image"``
                and ``"trimap"`` must be base64-encoded image files.

        Returns:
            ``{"result": alpha}`` where ``alpha`` is a PIL image with the
            same width and height as the decoded input image.

        Raises:
            KeyError: If ``"image"`` or ``"trimap"`` is missing.
        """
        inputs = data.pop("inputs", data)
        image = self._decode_image(inputs['image'], "RGB")
        trimap = self._decode_image(inputs['trimap'], "L")
        width, height = image.size

        batch = self.processor(
            images=image, trimaps=trimap, return_tensors="pt").to(device)
        with torch.no_grad():
            alphas = self.model(**batch).alphas

        # The processor pads the bottom/right edges so that both sides are
        # divisible by its size_divisibility; crop the prediction back to
        # the original size so the matte lines up with the input image.
        # Explicit indexing (rather than squeeze) keeps a 2-D tensor even
        # when the image is a single pixel tall or wide.
        alpha = alphas[0, 0, :height, :width]
        return {"result": T.ToPILImage()(alpha)}