from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from app.simple_segmentation.network import U2NET
from app.utils import check_or_download_model, load_checkpoint, image_to_base64

MODEL_URL = "https://huggingface.co/spaces/wildoctopus/cloth-segmentation/resolve/main/model/cloth_segm.pth"

# Side length (pixels) of the square image fed to the network.
_INPUT_SIZE = 768
# Number of output channels the network is built with (see load_seg_model).
_NUM_CLASSES = 4
# Labels for output channels 1..3, in channel order. Channel 0 is never
# exported (presumably background -- confirm against the model card).
_LABELS = ("upper_body", "lower_body", "full_body")


class Mode(Enum):
    """Selects how segmentation masks are rendered.

    BINARY produces hard 0/255 masks from the per-pixel arg-max class;
    VARIABLE produces soft masks from the per-class softmax probability.
    """

    BINARY = 'binary'
    VARIABLE = 'variable'


@dataclass
class Result:
    """Container for one optional mask image per garment region."""

    upper_body: Optional[Image.Image] = None
    lower_body: Optional[Image.Image] = None
    full_body: Optional[Image.Image] = None


def load_seg_model(checkpoint_path, device='cpu'):
    """Build the U2NET segmentation network and load its weights.

    Args:
        checkpoint_path: Location of the checkpoint file. It is passed to
            ``check_or_download_model`` together with ``MODEL_URL`` first
            (presumably fetching the file when absent -- verify in
            ``app.utils``), then to ``load_checkpoint``.
        device: Device the network is moved to. Defaults to ``'cpu'``.

    Returns:
        The network, moved to ``device`` and switched to eval mode.
    """
    net = U2NET(in_ch=3, out_ch=_NUM_CLASSES)
    check_or_download_model(MODEL_URL, checkpoint_path)
    net = load_checkpoint(net, checkpoint_path)
    net = net.to(device)
    net = net.eval()
    return net


def apply_transform(image):
    """Convert an image to a tensor and normalize it with mean 0.5, std 0.5.

    Args:
        image: Input accepted by ``transforms.ToTensor`` (e.g. a PIL image).

    Returns:
        The normalized tensor produced by the composed transform.
    """
    transform_rgb = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
    return transform_rgb(image)


def _model_device(net) -> torch.device:
    """Return the device of the model's first parameter (CPU if it has none)."""
    first_param = next(net.parameters(), None)
    if first_param is None:
        return torch.device("cpu")
    return first_param.device


def _mask_to_image(mask: np.ndarray, size) -> Image.Image:
    """Wrap a 2-D uint8 array as an 'L' image and resize it to ``size``."""
    mask_img = Image.fromarray(mask, mode='L')
    return mask_img.resize(size, Image.BICUBIC)


def binary_segment(image: Image.Image, net: U2NET,
                   mode: Mode = Mode.BINARY) -> List[Dict[str, Any]]:
    """Segment a clothing image into upper-body, lower-body and full-body masks.

    The image is resized to 768x768, converted to RGB, normalized and run
    through ``net`` on whatever device the network lives on. Masks are
    resized back to the original image size before encoding.

    Args:
        image: Source PIL image.
        net: Segmentation network, e.g. from ``load_seg_model``.
        mode: ``Mode.VARIABLE`` yields soft masks (class probability scaled
            to 0..255); any other value yields hard 0/255 masks from the
            per-pixel arg-max class.

    Returns:
        A list of three dicts, in the order upper_body, lower_body,
        full_body, each of the form
        ``{"label": <name>, "image": image_to_base64(<mask image>)}``.
    """
    original_size = image.size
    resized = image.resize(
        (_INPUT_SIZE, _INPUT_SIZE), Image.BICUBIC).convert('RGB')
    batch = torch.unsqueeze(apply_transform(resized), 0)

    # Run on the model's own device instead of assuming CPU, so a network
    # loaded with load_seg_model(..., device='cuda') works as well.
    device = _model_device(net)
    with torch.no_grad():
        outputs = net(batch.to(device))
        # NOTE(review): assumes net returns a sequence whose first element is
        # the (1, 4, 768, 768) prediction -- confirm against U2NET.forward.
        log_probs = torch.nn.functional.log_softmax(outputs[0], dim=1)

        class_indices = range(1, _NUM_CLASSES)
        if mode == Mode.VARIABLE:
            per_class = log_probs.reshape(
                (_NUM_CLASSES, _INPUT_SIZE, _INPUT_SIZE))
            # exp(log_softmax) recovers probabilities in [0, 1]; .cpu() is
            # required before .numpy() when the model runs on an accelerator.
            masks = [
                (torch.exp(per_class[idx]).cpu().numpy() * 255).astype(np.uint8)
                for idx in class_indices
            ]
        else:
            winners = torch.max(log_probs, dim=1, keepdim=True)[1]
            winners = torch.squeeze(winners, dim=0)
            winners_arr = winners.cpu().numpy()
            # winners_arr keeps a leading singleton axis; [0] drops it.
            masks = [
                ((winners_arr == idx).astype(np.uint8) * 255)[0]
                for idx in class_indices
            ]

    images = [_mask_to_image(mask, original_size) for mask in masks]
    return [
        {"label": label, "image": image_to_base64(mask_img)}
        for label, mask_img in zip(_LABELS, images)
    ]