| from typing import Dict, Any | |
| from io import BytesIO | |
| import base64 | |
| from model import ISNetDIS | |
| import torch | |
| import os | |
| from PIL import Image | |
| from torchvision.transforms import Compose, Normalize, functional | |
def process_image(image: torch.Tensor):
    """Normalize an image tensor and prepend a batch dimension.

    Each of the three channels is normalized with mean 0.5 and std 0.5,
    i.e. ``(x - 0.5) / 0.5``.

    Args:
        image: Image tensor with three channels, as required by the
            three-element mean/std passed to ``Normalize``.

    Returns:
        The normalized tensor with a new leading dimension of size 1.
    """
    channel_stats = (0.5, 0.5, 0.5)
    normalize = Normalize(channel_stats, channel_stats)
    normalized = normalize(image)
    return normalized.unsqueeze(0)
def get_model(device="cpu"):
    """Build an ISNetDIS network and load its weights from ``isnet.pth``.

    The checkpoint is looked up in the directory containing this source file.

    Args:
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    module_dir = os.path.dirname(__file__)
    checkpoint_path = os.path.join(module_dir, "isnet.pth")
    state_dict = torch.load(checkpoint_path, map_location=device)

    net = ISNetDIS()
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net
class EndpointHandler:
    """Endpoint that cuts the foreground out of a base64-encoded image.

    The segmentation model is loaded once in ``__init__`` and reused for
    every request.
    """

    def __init__(self, path=""):
        """Load the segmentation model.

        Args:
            path: Accepted so the constructor signature stays compatible with
                callers that pass a path; it is not used, because
                ``get_model`` locates the weights next to this source file.
        """
        self._model = get_model()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run background removal on one request.

        Args:
            data: Request payload. The image is read from
                ``data["inputs"]["image"]`` or, when there is no ``"inputs"``
                key, from ``data["image"]``. It must be a base64-encoded
                image file that PIL can open.

        Returns:
            A dict with ``"status"`` (200) and ``"image"``, a base64-encoded
            RGBA PNG in which pixels outside the predicted mask are white and
            fully transparent.

        Raises:
            KeyError: If the payload has no ``"image"`` entry.
        """
        inputs = data.pop("inputs", data)
        raw_bytes = base64.b64decode(inputs['image'])
        # Force three channels: grayscale, palette or RGBA uploads would
        # otherwise break the 3-channel normalization and the RGBA assembly.
        image = Image.open(BytesIO(raw_bytes)).convert("RGB")
        pixels = functional.pil_to_tensor(image).float().divide(255.0)
        batch = process_image(pixels)

        # Reuse the model loaded in __init__ instead of re-reading the weights
        # on every request, and skip autograd bookkeeping during inference.
        with torch.no_grad():
            outputs = self._model(batch)[0]
        # First output, first batch item.
        # NOTE(review): assumed to be a (1, H, W) map matching the input size - confirm.
        pred = outputs[0][0, :, :, :]

        # Min-max rescale to [0, 1]; the clamp avoids 0/0 -> NaN when the
        # prediction is constant (the mask is then empty, as before).
        low = torch.min(pred)
        span = torch.max(pred) - low
        pred = (pred - low) / span.clamp_min(1e-8)

        mask = torch.gt(pred, 0.1)
        # Keep original colors inside the mask, white elsewhere.
        rgb = torch.where(mask, pixels, torch.ones_like(pixels))
        # The mask doubles as the alpha channel; cast explicitly rather than
        # relying on bool/float promotion inside cat.
        rgba = torch.cat([rgb, mask.to(pixels.dtype)], dim=0)

        # No squeeze here: it would drop a height or width of 1.
        result_image = functional.to_pil_image(rgba)
        stream = BytesIO()
        result_image.save(stream, format="png")
        return {
            "status": 200,
            "image": base64.b64encode(stream.getvalue()).decode("utf8"),
        }
if __name__ == "__main__":
    # Manual smoke test: run the handler on an image file given on the
    # command line. Calling the handler with an empty payload, as before,
    # always failed with KeyError('image').
    import sys

    if len(sys.argv) != 2:
        raise SystemExit(f"usage: {sys.argv[0]} IMAGE_PATH")
    with open(sys.argv[1], "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("utf8")

    h = EndpointHandler()
    v = h({"inputs": {"image": encoded}})
    print(v)