Spaces:
Runtime error
Runtime error
| import os | |
| # from tqdm import tqdm | |
| from PIL import Image | |
| import numpy as np | |
| import sys | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from .u2net_cloth_seg.data.base_dataset import Normalize_image | |
| from .u2net_cloth_seg.utils.saving_utils import load_checkpoint_mgpu | |
| from .u2net_cloth_seg.networks import U2NET | |
class U2NETParser:
    """Produce a binary mask for one class predicted by a U2NET segmentation model.

    The model is built as ``U2NET(in_ch=3, out_ch=4)``, restored from a
    checkpoint, and used to turn a PIL image into a single-channel mask in
    which pixels predicted as the requested class are 255 and all others 0.
    """

    # (width, height) the input image is resized to before inference.
    INPUT_SIZE = (768, 768)

    def __init__(self, checkpoint_path, device=None):
        """Build the preprocessing pipeline and load the model.

        Args:
            checkpoint_path: Path handed to ``load_checkpoint_mgpu`` to restore
                the U2NET weights.
            device: Torch device (or device string) used for inference.
                Defaults to ``"cuda"`` when CUDA is available and ``"cpu"``
                otherwise, instead of always requiring a GPU.
        """
        self.cp_path = checkpoint_path
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        # NOTE(review): Normalize_image(0.5, 0.5) presumably maps [0, 1] to
        # [-1, 1] (mean/std normalisation) — confirm in base_dataset.
        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            Normalize_image(0.5, 0.5),
        ])
        self.model = self.load_model()

    def load_model(self):
        """Create U2NET, restore weights from ``self.cp_path``, move it to
        ``self.device`` and return it in eval mode."""
        model = U2NET(in_ch=3, out_ch=4)
        # NOTE(review): if load_checkpoint_mgpu calls torch.load without a
        # map_location, a CUDA-saved checkpoint may still need a GPU to load
        # — confirm in saving_utils.
        model = load_checkpoint_mgpu(model, self.cp_path)
        model = model.to(self.device)
        return model.eval()

    def get_image_mask(self, img, label=1):
        """Return a mask of the pixels the model assigns to class ``label``.

        Args:
            img: PIL image of any size and mode. It is converted to RGB
                because the network is built with 3 input channels.
            label: Index along the model's class dimension to keep
                (default 1, the class the original code always selected).

        Returns:
            A single-channel (mode "L") PIL image with the same size as
            ``img``: 255 where the predicted class equals ``label``, 0
            elsewhere. The mask is upscaled with bicubic resampling, so
            pixels along region edges may take intermediate grey values.
        """
        orig_size = img.size
        # RGBA / L / P inputs would otherwise produce a 4- or 1-channel tensor
        # and fail inside the 3-channel network.
        rgb = img.convert("RGB").resize(self.INPUT_SIZE, Image.BICUBIC)
        batch = torch.unsqueeze(self.img_transforms(rgb), 0).to(self.device)

        with torch.no_grad():
            # Assumes the model returns a sequence whose first element is the
            # (1, num_classes, H, W) logit map — TODO confirm in networks.
            logits = self.model(batch)[0]

        # log_softmax is monotonic per pixel, so taking the argmax of the raw
        # logits picks the same class without the extra pass.
        pred = torch.argmax(logits, dim=1)[0]  # (H, W) class indices

        mask = np.where(pred.cpu().numpy() == label, 255, 0).astype(np.uint8)
        # A 2-D uint8 array is always read as mode "L"; passing mode= to
        # fromarray is deprecated in recent Pillow releases.
        return Image.fromarray(mask).resize(orig_size, Image.BICUBIC)