| import torch |
| import numpy as np |
| from collections import OrderedDict |
| from dataclasses import dataclass |
| from typing import Optional |
| from enum import Enum |
|
|
| from PIL import Image |
| from torchvision import transforms |
|
|
| from app.simple_segmentation.network import U2NET |
| from app.utils import check_or_download_model, load_checkpoint, image_to_base64 |
|
|
| MODEL_URL = "https://huggingface.co/spaces/wildoctopus/cloth-segmentation/resolve/main/model/cloth_segm.pth" |
|
|
class Mode(Enum):
    """Selects how `binary_segment` renders its per-garment masks."""

    # Hard 0/255 masks from the per-pixel argmax class.
    BINARY = "binary"
    # Soft grayscale masks from the per-class probabilities.
    VARIABLE = "variable"
|
|
@dataclass
class Result:
    """Container for the three per-garment segmentation masks.

    NOTE(review): `binary_segment` is annotated to return this type but
    actually returns a list of dicts — confirm whether any caller still
    constructs or consumes `Result`.
    """

    # Alpha mask for the upper-body garment class (class index 1).
    upper_body: Optional[Image.Image] = None
    # Alpha mask for the lower-body garment class (class index 2).
    lower_body: Optional[Image.Image] = None
    # Alpha mask for the full-body garment class (class index 3).
    full_body: Optional[Image.Image] = None
|
|
def load_seg_model(checkpoint_path, device='cpu'):
    """Build the 4-class U2NET cloth-segmentation model ready for inference.

    Downloads the checkpoint from MODEL_URL to `checkpoint_path` if it is
    not already present, restores the weights, moves the network to
    `device`, and switches it to eval mode.

    Args:
        checkpoint_path: Local path where the checkpoint lives (or will
            be downloaded to).
        device: Target device for the network. Defaults to 'cpu'.

    Returns:
        The loaded U2NET in eval mode on `device`.
    """
    model = U2NET(in_ch=3, out_ch=4)
    check_or_download_model(MODEL_URL, checkpoint_path)
    model = load_checkpoint(model, checkpoint_path)
    return model.to(device).eval()
|
|
# Built once at import time: the pipeline is stateless and identical for
# every image, so rebuilding the Compose object per call was pure overhead.
_TRANSFORM_RGB = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])


def apply_transform(image):
    """Convert a PIL image to a normalized float tensor in [-1, 1].

    ToTensor scales pixel values to [0, 1]; Normalize(0.5, 0.5) then maps
    them to [-1, 1], the input range the segmentation network was trained on.

    Args:
        image: PIL RGB image.

    Returns:
        A (C, H, W) float tensor with values in [-1, 1].
    """
    return _TRANSFORM_RGB(image)
|
|
def binary_segment(image: Image.Image, net: U2NET, mode=Mode.BINARY,
                   device='cpu') -> list:
    """Segment garments in `image` into three per-class alpha masks.

    Args:
        image: Input photo (converted to RGB and resized to the 768x768
            input the network expects; masks are resized back).
        net: A loaded 4-class U2NET (0=background, 1=upper body,
            2=lower body, 3=full body).
        mode: Mode.BINARY yields hard 0/255 masks from the per-pixel
            argmax class; Mode.VARIABLE yields soft grayscale masks from
            the per-class probabilities.
        device: Device to run inference on. Defaults to 'cpu', which was
            previously hard-coded; pass the device `net` lives on.

    Returns:
        A list of three {"label": ..., "image": <base64>} dicts in the
        order upper_body, lower_body, full_body.  (The original
        annotation said `-> Result`, but a list of dicts is what has
        always been returned.)
    """
    original_size = image.size

    resized = image.resize((768, 768), Image.BICUBIC).convert('RGB')
    image_tensor = torch.unsqueeze(apply_transform(resized), 0)

    with torch.no_grad():
        # U2NET returns several side outputs; [0] is the fused prediction
        # with shape (1, 4, 768, 768).
        logits = net(image_tensor.to(device))[0]
        log_probs = torch.nn.functional.log_softmax(logits, dim=1)

    if mode == Mode.VARIABLE:
        masks = _soft_masks(log_probs, original_size)
    else:
        masks = _hard_masks(log_probs, original_size)

    return [
        {"label": "upper_body", "image": image_to_base64(masks[0])},
        {"label": "lower_body", "image": image_to_base64(masks[1])},
        {"label": "full_body", "image": image_to_base64(masks[2])},
    ]


def _soft_masks(log_probs, original_size):
    """Grayscale probability masks for classes 1-3, resized to `original_size`."""
    # exp() of log-softmax recovers per-class probabilities; .cpu() makes
    # .numpy() safe when inference ran on an accelerator.
    probs = torch.exp(log_probs).reshape((4, 768, 768)).cpu()
    masks = []
    for cls in range(1, 4):
        arr = (probs[cls].numpy() * 255).astype(np.uint8)
        mask_img = Image.fromarray(arr, mode='L').resize(
            original_size, Image.BICUBIC)
        masks.append(mask_img)
    return masks


def _hard_masks(log_probs, original_size):
    """Binary 0/255 masks (argmax class per pixel) for classes 1-3."""
    labels = torch.max(log_probs, dim=1, keepdim=True)[1]
    # Shape after squeeze: (1, 768, 768); .cpu() for non-CPU devices.
    labels = torch.squeeze(labels, dim=0).cpu().numpy()
    masks = []
    for cls in range(1, 4):
        alpha = (labels == cls).astype(np.uint8) * 255
        mask_img = Image.fromarray(alpha[0], mode='L').resize(
            original_size, Image.BICUBIC)
        masks.append(mask_img)
    return masks
|
|