import base64
import io
import os

import cv2
import numpy as np
import pytorch_lightning as pl
import torch
from PIL import Image

from model import ISNetDIS, ISNetGTEncoder, U2NET, U2NET_full2, U2NET_lite2, MODNet
|
|
| model = None |
|
|
| def get_mask(model, input_img): |
| h, w = input_img.shape[0], input_img.shape[1] |
| ph, pw = 0, 0 |
| tmpImg = np.zeros([h, w, 3], dtype=np.float16) |
| tmpImg[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) / 255 |
| tmpImg = tmpImg.transpose((2, 0, 1)) |
| tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor).to(model.device) |
| with torch.no_grad(): |
| pred = model(tmpImg) |
| pred = pred[0, :, ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] |
| pred = cv2.resize(pred.cpu().numpy().transpose((1, 2, 0)), (w, h))[:, :, np.newaxis] |
| return pred |
|
|
| def get_net(net_name): |
| if net_name == "isnet": |
| return ISNetDIS() |
| elif net_name == "isnet_is": |
| return ISNetDIS() |
| elif net_name == "isnet_gt": |
| return ISNetGTEncoder() |
| elif net_name == "u2net": |
| return U2NET_full2() |
| elif net_name == "u2netl": |
| return U2NET_lite2() |
| elif net_name == "modnet": |
| return MODNet() |
| raise NotImplemented |
|
|
| |
| class AnimeSegmentation(pl.LightningModule): |
| def __init__(self, net_name): |
| super().__init__() |
| assert net_name in ["isnet_is", "isnet", "isnet_gt", "u2net", "u2netl", "modnet"] |
| self.net = get_net(net_name) |
| if net_name == "isnet_is": |
| self.gt_encoder = get_net("isnet_gt") |
| for param in self.gt_encoder.parameters(): |
| param.requires_grad = False |
| else: |
| self.gt_encoder = None |
|
|
| @classmethod |
| def try_load(cls, net_name, ckpt_path, map_location=None): |
| state_dict = torch.load(ckpt_path, map_location=map_location) |
| if "epoch" in state_dict: |
| return cls.load_from_checkpoint(ckpt_path, net_name=net_name, map_location=map_location) |
| else: |
| model = cls(net_name) |
| if any([k.startswith("net.") for k, v in state_dict.items()]): |
| model.load_state_dict(state_dict) |
| else: |
| model.net.load_state_dict(state_dict) |
| return model |
|
|
| def forward(self, x): |
| if isinstance(self.net, ISNetDIS): |
| return self.net(x)[0][0].sigmoid() |
| if isinstance(self.net, ISNetGTEncoder): |
| return self.net(x)[0][0].sigmoid() |
| elif isinstance(self.net, U2NET): |
| return self.net(x)[0].sigmoid() |
| elif isinstance(self.net, MODNet): |
| return self.net(x, True)[2] |
| raise NotImplemented |
|
|
| def load_model(): |
| global model |
|
|
| if torch.cuda.is_available(): |
| device = 'cuda' |
| else: |
| device = 'cpu' |
| |
| model = AnimeSegmentation.try_load('isnet_is', 'anime-seg/isnetis.ckpt', device) |
| model.eval() |
| model.to(device) |
| |
| def animeseg(image): |
| global model |
|
|
| if not image: |
| return None |
|
|
| if not model: |
| model = load_model() |
|
|
| img = np.array(image, dtype=np.uint8) |
| mask = get_mask(model, img) |
| img = np.concatenate((mask * img + 1 - mask, mask * 255), axis=2).astype(np.uint8) |
| return img |
|
|
| def pil_to_webp(img): |
| buffer = io.BytesIO() |
| img.save(buffer, 'webp') |
|
|
| return buffer.getvalue() |
|
|
| def bin_to_base64(bin): |
| return base64.b64encode(bin).decode('ascii') |
|
|