# pillipop
# return obj for data layer
# fc8563b unverified
import torch
import numpy as np
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional
from enum import Enum
from PIL import Image
from torchvision import transforms
from app.simple_segmentation.network import U2NET
from app.utils import check_or_download_model, load_checkpoint, image_to_base64
# Remote location of the cloth-segmentation checkpoint; load_seg_model passes
# it to check_or_download_model together with the local checkpoint path.
MODEL_URL = "https://huggingface.co/spaces/wildoctopus/cloth-segmentation/resolve/main/model/cloth_segm.pth"
class Mode(Enum):
    """Rendering style for the masks produced by ``binary_segment``.

    ``BINARY``   -- hard masks: each pixel is 255 for its arg-max class, else 0.
    ``VARIABLE`` -- soft masks: each pixel is the class probability scaled to 0-255.
    """

    BINARY = 'binary'
    VARIABLE = 'variable'
@dataclass
class Result:
    """Holder for the three per-garment segmentation masks.

    Each field is an optional PIL image and defaults to ``None``.
    NOTE(review): binary_segment in this module returns a list of dicts
    rather than a Result -- confirm whether this type is still used.
    """

    # Mask for the upper-body garment class.
    upper_body: Optional[Image.Image] = None
    # Mask for the lower-body garment class.
    lower_body: Optional[Image.Image] = None
    # Mask for the full-body garment class.
    full_body: Optional[Image.Image] = None
def load_seg_model(checkpoint_path, device='cpu'):
    """Build the 4-class U2NET segmenter and load its weights.

    The checkpoint is fetched from ``MODEL_URL`` via check_or_download_model
    (called with ``checkpoint_path``). The weights are then loaded with
    load_checkpoint, and the model is moved to ``device`` and put in eval mode.

    Args:
        checkpoint_path: Local path of the checkpoint file.
        device: Target device for the model (default ``'cpu'``).

    Returns:
        The loaded U2NET model, on ``device`` and in eval mode.
    """
    model = U2NET(in_ch=3, out_ch=4)  # 3-channel RGB in, 4 class channels out
    check_or_download_model(MODEL_URL, checkpoint_path)
    model = load_checkpoint(model, checkpoint_path)
    return model.to(device).eval()
def apply_transform(image):
    """Convert a PIL image to a normalized tensor.

    ToTensor scales pixel values to [0, 1]. Normalize(0.5, 0.5) then maps
    them to [-1, 1] via (x - 0.5) / 0.5.
    """
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(image)
def _model_device(net) -> torch.device:
    """Return the device holding *net*'s parameters (CPU if it has none)."""
    first_param = next(net.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def _mask_to_image(mask: np.ndarray, size) -> Image.Image:
    """Wrap a 2-D uint8 mask as a grayscale image, bicubic-resized to *size*."""
    return Image.fromarray(mask, mode='L').resize(size, Image.BICUBIC)


def binary_segment(image: Image.Image, net: U2NET, mode=Mode.BINARY) -> list:
    """Segment garments in *image* into upper-, lower- and full-body masks.

    Processing steps:
    - The image is resized to 768x768 (bicubic), converted to RGB and
      normalized with apply_transform.
    - It is run through *net*, and a log-softmax is taken over the 4 class
      channels.
    - Channels 1-3 are turned into grayscale masks. Each mask is resized back
      to the original image size and passed through image_to_base64.

    Args:
        image: Input PIL image.
        net: Segmentation model, e.g. from load_seg_model. The input tensor is
            sent to the device of the model's parameters, so a model loaded on
            a non-CPU device also works.
        mode: ``Mode.BINARY`` gives hard 0/255 masks from the per-pixel argmax.
            ``Mode.VARIABLE`` gives soft masks: class probability * 255,
            truncated to uint8. Any value other than ``Mode.VARIABLE`` is
            treated as binary.

    Returns:
        A list of three dicts ``{"label": str, "image": <image_to_base64 output>}``,
        labelled ``"upper_body"``, ``"lower_body"`` and ``"full_body"`` in that order.
    """
    original_size = image.size
    image = image.resize((768, 768), Image.BICUBIC).convert('RGB')
    image_tensor = torch.unsqueeze(apply_transform(image), 0)

    # Follow the model instead of hard-coding "cpu": load_seg_model accepts a
    # device argument, and feeding a CPU tensor to e.g. a CUDA model raises.
    device = _model_device(net)

    with torch.no_grad():
        # NOTE(review): assumes net() returns a sequence whose first element is
        # a (1, 4, 768, 768) score map -- confirm against U2NET.forward.
        outputs = net(image_tensor.to(device))
        # Class log-probabilities, moved to host memory so .numpy() works
        # whatever device the model is on.
        log_probs = torch.nn.functional.log_softmax(outputs[0], dim=1).cpu()

        if mode == Mode.VARIABLE:
            # exp(log_softmax) is the per-pixel class probability in [0, 1].
            probs = torch.exp(log_probs.reshape((4, 768, 768))).numpy()
            masks = [(probs[c] * 255).astype(np.uint8) for c in range(1, 4)]
        else:
            # Index of the highest-scoring class for each pixel: (768, 768).
            class_map = torch.max(log_probs, dim=1)[1][0].numpy()
            masks = [(class_map == c).astype(np.uint8) * 255 for c in range(1, 4)]

    # Class channels 1, 2, 3 map to these labels in order; channel 0 is not
    # returned (presumably background).
    names = ("upper_body", "lower_body", "full_body")
    return [
        {
            "label": name,
            "image": image_to_base64(_mask_to_image(mask, original_size)),
        }
        for name, mask in zip(names, masks)
    ]