Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| from __future__ import annotations | |
| import functools | |
| import os | |
| import pathlib | |
| import sys | |
| import tarfile | |
| import urllib.request | |
| from typing import Callable | |
| import cv2 | |
| import gradio as gr | |
| import huggingface_hub | |
| import numpy as np | |
| import PIL.Image | |
| import torch | |
| import torchvision.transforms as T | |
| sys.path.insert(0, 'anime_face_landmark_detection') | |
| from CFA import CFA | |
# Markdown heading shown at the top of the demo page, linking to the upstream
# project this Space wraps.
DESCRIPTION: str = ('# [kanosawa/anime_face_landmark_detection]'
                    '(https://github.com/kanosawa/anime_face_landmark_detection)')
# Number of facial landmark points handled by this demo.
NUM_LANDMARK: int = 24
# Side length, in pixels, of the square face crop given to the landmark model.
CROP_SIZE: int = 128
def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the sample images, downloading them on first use.

    If ``images/`` does not exist in the working directory, ``images.tar.gz``
    is fetched from the ``hysts/sample-images-TADNE`` dataset on the Hugging
    Face Hub and extracted into the working directory.

    Returns:
        Every entry matching ``images/*``, sorted by path.
    """
    image_dir = pathlib.Path('images')
    if not image_dir.exists():
        dataset_repo = 'hysts/sample-images-TADNE'
        path = huggingface_hub.hf_hub_download(dataset_repo,
                                               'images.tar.gz',
                                               repo_type='dataset')
        with tarfile.open(path) as f:
            # The 'data' filter (PEP 706) refuses absolute paths, '..'
            # components, device files and links escaping the destination.
            # Older interpreters without extraction filters keep the legacy
            # unfiltered behaviour.
            if hasattr(tarfile, 'data_filter'):
                f.extractall(filter='data')
            else:
                f.extractall()
    return sorted(image_dir.glob('*'))
def load_face_detector() -> cv2.CascadeClassifier:
    """Return the lbpcascade_animeface face detector, downloading it if needed.

    The cascade XML is cached as ``lbpcascade_animeface.xml`` in the working
    directory and reused on later calls.

    Returns:
        A loaded OpenCV cascade classifier.

    Raises:
        RuntimeError: If OpenCV cannot load a cascade from the cached file.
    """
    url = 'https://raw.githubusercontent.com/nagadomi/lbpcascade_animeface/master/lbpcascade_animeface.xml'
    path = pathlib.Path('lbpcascade_animeface.xml')
    if not path.exists():
        # Download under a temporary name and rename afterwards, so an
        # interrupted download never leaves a truncated file that later runs
        # would mistake for a valid cache.
        tmp_path = path.with_name(path.name + '.part')
        urllib.request.urlretrieve(url, tmp_path.as_posix())
        tmp_path.replace(path)
    detector = cv2.CascadeClassifier(path.as_posix())
    # The constructor does not raise on an unreadable/invalid file; it yields
    # an empty classifier that only fails later inside detectMultiScale.
    if detector.empty():
        raise RuntimeError(f'Failed to load face detector from {path}')
    return detector
def load_landmark_detector(device: torch.device) -> torch.nn.Module:
    """Build the CFA landmark network from its Hugging Face Hub checkpoint.

    Args:
        device: Device the network is moved to.

    Returns:
        The CFA network on ``device``, switched to evaluation mode.
    """
    checkpoint_path = huggingface_hub.hf_hub_download(
        'public-data/anime_face_landmark_detection',
        'checkpoint_landmark_191116.pth')
    # NUM_LANDMARK + 1 output channels: presumably one extra (background)
    # channel beyond the landmarks -- TODO confirm against the CFA definition.
    network = CFA(checkpoint_name=checkpoint_path,
                  output_channel_num=NUM_LANDMARK + 1)
    # nn.Module.to() and nn.Module.eval() both return the module itself.
    return network.to(device).eval()
def _expand_face_box(x: int, y: int, w: int, h: int, image_w: int,
                     image_h: int) -> tuple[int, int, int, int]:
    """Grow a cascade detection box to cover the whole face, clipped to the image.

    The box is widened by 1/8 of its width on each side and extended upward by
    1/4 of its height, then clipped to ``[0, image_w] x [0, image_h]``.

    Returns:
        ``(x0, y0, x1, y1)`` as plain Python ints.
    """
    x0 = int(round(max(x - w / 8, 0)))
    x1 = int(round(min(x + w * 9 / 8, image_w)))
    y0 = int(round(max(y - h / 4, 0)))
    y1 = int(min(y + h, image_h))
    return x0, y0, x1, y1


def detect(image_path: str, face_detector: cv2.CascadeClassifier,
           device: torch.device, transform: Callable,
           landmark_detector: torch.nn.Module) -> np.ndarray:
    """Detect anime faces in an image and draw their landmarks.

    Each face found by ``face_detector`` is enlarged, cropped, resized to a
    ``CROP_SIZE`` square, passed through ``transform`` and
    ``landmark_detector``. For each of the first ``NUM_LANDMARK`` heatmaps of
    the network's last output, the peak is drawn as a point.

    Args:
        image_path: Path of the image file to process.
        face_detector: OpenCV cascade used to locate faces.
        device: Device the landmark network runs on.
        transform: Converts a PIL crop into the network's input tensor.
        landmark_detector: Network whose output is indexed with ``[-1]`` to get
            the heatmaps that are used.

    Returns:
        An RGB copy of the image with green face boxes and red landmark dots.

    Raises:
        ValueError: If no path is given or the file cannot be read as an image.
    """
    # Gradio passes None when Run is clicked without an uploaded image.
    if not image_path:
        raise ValueError('No input image was provided.')
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise ValueError(f'Could not read image: {image_path}')

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    preds = face_detector.detectMultiScale(gray,
                                           scaleFactor=1.1,
                                           minNeighbors=5,
                                           minSize=(24, 24))

    image_h, image_w = image.shape[:2]
    # OpenCV loads BGR; PIL and the transform work on RGB.
    pil_image = PIL.Image.fromarray(image[:, :, ::-1].copy())
    res = image.copy()
    for x_orig, y_orig, w_orig, h_orig in preds:
        x0, y0, x1, y1 = _expand_face_box(x_orig, y_orig, w_orig, h_orig,
                                          image_w, image_h)
        w = x1 - x0
        h = y1 - y0

        crop = pil_image.crop((x0, y0, x1, y1))
        crop = crop.resize((CROP_SIZE, CROP_SIZE), PIL.Image.BICUBIC)
        # Pure inference: skip autograd bookkeeping. This also keeps the output
        # from requiring grad, which .numpy() would otherwise reject.
        with torch.inference_mode():
            data = transform(crop).to(device).unsqueeze(0)
            outputs = landmark_detector(data)
            heatmaps = outputs[-1].cpu().numpy()[0]

        cv2.rectangle(res, (x0, y0), (x1, y1), (0, 255, 0), 2)
        for i in range(NUM_LANDMARK):
            heatmap = cv2.resize(heatmaps[i], (CROP_SIZE, CROP_SIZE),
                                 interpolation=cv2.INTER_CUBIC)
            pty, ptx = np.unravel_index(np.argmax(heatmap), heatmap.shape)
            # Map the peak from CROP_SIZE-square coordinates back to the
            # (w x h) crop, then offset by the crop origin in the full image.
            pt_crop = np.round(np.array([ptx * w, pty * h]) /
                               CROP_SIZE).astype(int)
            pt = np.array([x0, y0]) + pt_crop
            cv2.circle(res, tuple(pt.tolist()), 2, (0, 0, 255), cv2.FILLED)
    # Convert BGR back to RGB for display.
    return res[:, :, ::-1]
# Use the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Loading order is kept as-is: each loader may download files on first run.
image_paths = load_sample_image_paths()
examples = [[sample.as_posix()] for sample in image_paths]
face_detector = load_face_detector()
landmark_detector = load_landmark_detector(device)

# ToTensor scales 8-bit images to [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Bind models and preprocessing so the UI callback only takes the image path.
fn = functools.partial(
    detect,
    face_detector=face_detector,
    device=device,
    transform=transform,
    landmark_detector=landmark_detector,
)
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: input image and the button that triggers inference.
        with gr.Column():
            image = gr.Image(label='Input', type='filepath')
            run_button = gr.Button('Run')
        # Right column: annotated result.
        with gr.Column():
            result = gr.Image(label='Result')

    # Example outputs are precomputed only when CACHE_EXAMPLES is set to '1'.
    gr.Examples(
        examples=examples,
        inputs=image,
        outputs=result,
        fn=fn,
        cache_examples=(os.getenv('CACHE_EXAMPLES') == '1'),
    )

    run_button.click(
        fn=fn,
        inputs=image,
        outputs=result,
        api_name='predict',
    )

# Cap the request queue at 15 pending events, then start the server.
demo.queue(max_size=15)
demo.launch()