Spaces:
Build error
Build error
import functools
from pathlib import Path
from typing import Callable

import gradio as gr
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from rtmlib import YOLOX, RTMPose, draw_bbox, draw_skeleton
# Heading displayed at the top of the Gradio interface built in ``main``.
TITLE = 'Face Parsing'
def get_palette(num_cls):
    """Build a flat colour palette for rendering a label map.

    Each label index is turned into an RGB triple by spreading its bits
    across the three channels: bit 0 of every 3-bit group feeds red,
    bit 1 feeds green and bit 2 feeds blue, starting at the most
    significant bit (value 128) and moving one bit lower per group.

    Args:
        num_cls: How many label indices to generate colours for.

    Returns:
        A list of ``3 * num_cls`` ints laid out as ``[r0, g0, b0, r1, ...]``,
        suitable for ``PIL.Image.Image.putpalette``.
    """
    colours = []
    for label in range(num_cls):
        rgb = [0, 0, 0]
        remaining = label
        shift = 7
        while remaining:
            # Channel c takes bit c of the current 3-bit group.
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << shift
            shift -= 1
            remaining >>= 3
        colours.extend(rgb)
    return colours
def predict(image: PIL.Image.Image, model, transform: Callable,
            device: torch.device, palette) -> tuple:
    """Run pose estimation and face parsing on a single image.

    Args:
        image: Input picture (from a ``gr.Image(type='pil')`` component).
            ``None`` -- e.g. the user submitted without an image -- yields
            ``(None, None)`` instead of raising.
        model: Indexable holding three callables:
            ``model[0]`` the parsing network, called on a batched tensor;
            ``model[1]`` a detector, called on the image array, whose result
            is drawn with ``draw_bbox`` and forwarded as ``bboxes=``;
            ``model[2]`` a pose estimator returning ``(keypoints, scores)``.
        transform: Callable mapping the PIL image to the network's input
            tensor (without the batch dimension).
        device: Device the input tensor is moved to before inference.
        palette: Flat ``[r, g, b, ...]`` list applied with ``putpalette``.

    Returns:
        ``(label_image, blended)``: the colourised per-pixel label map (RGB)
        and that map blended 50/50 over the input annotated with the
        detected boxes and skeletons.
    """
    if image is None:
        return None, None
    # ``transform`` normalises 3 channels and ``Image.blend`` needs matching
    # modes, so grayscale / RGBA / palette uploads are normalised to RGB.
    image = image.convert('RGB')
    parsing_model, det_model, pose_model = model[0], model[1], model[2]

    # Detection and pose share one array; drawing happens on a separate copy.
    frame = np.array(image)
    bboxes = det_model(frame)
    keypoints, scores = pose_model(frame, bboxes=bboxes)
    annotated = draw_bbox(frame.copy(), bboxes)
    annotated = draw_skeleton(annotated, keypoints, scores)

    data = transform(image).unsqueeze(0).to(device)
    # Pure inference: disable autograd so no graph/activations are retained.
    with torch.no_grad():
        logits = parsing_model(data)
        # Resize logits back to the input's (height, width); PIL size is (w, h).
        logits = F.interpolate(logits, size=[image.size[1], image.size[0]],
                               mode="bilinear", align_corners=False)
    # Class index per pixel = argmax over the channel axis of batch item 0.
    parsing = torch.argmax(logits[0], dim=0).cpu().numpy()

    label_im = Image.fromarray(np.asarray(parsing, dtype=np.uint8))
    label_im.putpalette(palette)
    label_im = label_im.convert('RGB')

    overlay = Image.fromarray(np.asarray(annotated, dtype=np.uint8))
    blended = Image.blend(overlay.convert('RGB'), label_im, 0.5)
    return label_im, blended
def load_parsing_model(path="models/faceparsing_512_512.pt", device="cpu"):
    """Load the TorchScript face-parsing network in evaluation mode.

    Args:
        path: Location of the serialized TorchScript module. Defaults to the
            model file bundled with the app (relative to the working dir).
        device: Where the weights are placed on load (``map_location``).
            Defaults to CPU, matching the device used by ``main``, so a
            checkpoint exported from a GPU machine still loads on a CPU host.

    Returns:
        The loaded ``torch.jit.ScriptModule`` with ``eval()`` applied.
    """
    model = torch.jit.load(Path(path), map_location=device)
    model.eval()
    return model
def main():
    """Build all models and launch the Gradio face-parsing demo.

    Loads the TorchScript parsing network plus the rtmlib YOLOX detector and
    RTMPose estimator (ONNX models run through onnxruntime), binds them into
    ``predict`` and serves it as a single-image ``gr.Interface`` with two
    image outputs.
    """
    device = torch.device('cpu')

    parsing_model = load_parsing_model()
    transform = T.Compose([
        # Use torchvision's InterpolationMode enum; passing the raw PIL
        # integer constant is a deprecated legacy path. Same NEAREST mode.
        T.Resize((512, 512), interpolation=T.InterpolationMode.NEAREST),
        T.ToTensor(),
        # Standard ImageNet per-channel mean / std.
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    palette = get_palette(20)

    # The ONNX models run on the same device as the parsing network.
    det_model = YOLOX('models/det.onnx', model_input_size=(640, 640),
                      backend='onnxruntime', device=device.type)
    pose_model = RTMPose('models/pose.onnx', model_input_size=(192, 256),
                         to_openpose=False, backend='onnxruntime',
                         device=device.type)

    # Order matters: ``predict`` indexes [parsing, detector, pose].
    models = [parsing_model, det_model, pose_model]

    func = functools.partial(predict,
                             model=models,
                             transform=transform,
                             device=device,
                             palette=palette)
    gr.Interface(
        fn=func,
        inputs=gr.Image(label='Input', type='pil'),
        outputs=[
            gr.Image(label='Predicted Labels', type='pil'),
            gr.Image(label='Masked', type='pil'),
        ],
        title=TITLE,
    ).queue().launch(show_api=False)
# Launch the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()