# Page header captured with this file (not part of the program):
# Spaces status: Runtime error
# --- Runtime bootstrap -------------------------------------------------------
# This script installs/fetches two dependencies at start-up before importing
# them, so the import order below is deliberate and must not be "tidied" to
# the top of the file.
import importlib.util
import os
import subprocess
import sys

_FACE_DETECTION_SPEC = "git+https://github.com/elliottzheng/face-detection.git@master"
_REPNET_URL = "https://github.com/thohemp/6DRepNet"
_REPNET_DIR = "6DRepNet"

# Install the face detector only when it is not importable yet. Using
# ``sys.executable -m pip`` guarantees the package lands in the interpreter
# that is running this script, and ``check=True`` surfaces an install failure
# here instead of as a confusing ImportError further down.
if importlib.util.find_spec("face_detection") is None:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", _FACE_DETECTION_SPEC],
        check=True,
    )
    importlib.invalidate_caches()

# ``git clone`` refuses to write into an existing directory, so skip the
# clone when a previous run already fetched the repository.
if not os.path.isdir(_REPNET_DIR):
    subprocess.run(["git", "clone", _REPNET_URL, _REPNET_DIR], check=True)

# The cloned repository is not an installed package; expose its modules
# (``model``, ``utils``) by putting the checkout on the import path.
sys.path.append(_REPNET_DIR)

import numpy as np
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from face_detection import RetinaFace  # installed by the bootstrap above
from model import SixDRepNet  # provided by the 6DRepNet checkout
import utils  # provided by the 6DRepNet checkout
import cv2
from PIL import Image
# --- Model and detector setup ------------------------------------------------
# Fetch the pretrained checkpoint from the Hugging Face Hub (cached locally by
# huggingface_hub after the first download).
snapshot_path = hf_hub_download(repo_id="osanseviero/6DRepNet_300W_LP_AFLW2000", filename="model.pth")

# Use GPU 0 when CUDA is available and fall back to the CPU otherwise, instead
# of failing unconditionally on hosts without a GPU.
# NOTE(review): helpers inside the cloned 6DRepNet repo may themselves assume
# CUDA — confirm that the CPU path runs end-to-end.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

model = SixDRepNet(
    backbone_name='RepVGG-B1g2',
    backbone_file='',
    deploy=True,
    pretrained=False,
)

# RetinaFace takes a GPU id; -1 is expected to select the CPU
# (TODO confirm against the face-detection package).
detector = RetinaFace(0 if use_cuda else -1)

# Always deserialize onto the CPU first; the model is moved to its final
# device after the weights are loaded.
saved_state_dict = torch.load(snapshot_path, map_location='cpu')

# The checkpoint is either a bare state dict or a dict that wraps it under
# 'model_state_dict'; accept both layouts.
if 'model_state_dict' in saved_state_dict:
    model.load_state_dict(saved_state_dict['model_state_dict'])
else:
    model.load_state_dict(saved_state_dict)

model.to(device)
model.eval()
def predict(frame):
    """Estimate head pose for the faces in ``frame`` and draw a pose cube on each.

    Every detection with a score of at least 0.95 is cropped (with a 20 %
    margin), passed through the 6DRepNet model, and annotated on the frame
    via ``utils.plot_pose_cube``.

    Args:
        frame: Image array supplied by the Gradio image input, indexed as
            ``[row, column, channel]``. ``None`` is tolerated.

    Returns:
        The image returned by ``utils.plot_pose_cube`` after the last
        annotated face, the input frame unchanged when no detection reaches
        the score threshold, or ``None`` when ``frame`` is ``None``.
    """
    if frame is None:
        # Nothing to process; returning None avoids crashing in the detector.
        return None

    # Run inference on whatever device the model weights live on rather than
    # assuming GPU 0.
    device = next(model.parameters()).device

    for box, landmarks, score in detector(frame):
        # Ignore low-confidence detections.
        if score < .95:
            continue

        x_min = int(box[0])
        y_min = int(box[1])
        x_max = int(box[2])
        y_max = int(box[3])
        bbox_width = abs(x_max - x_min)
        bbox_height = abs(y_max - y_min)

        # Enlarge the box by 20 % on every side, clamping the top-left corner
        # to the image (slicing clamps the bottom-right corner implicitly).
        # NOTE(review): the horizontal margin is derived from the box height
        # and the vertical margin from the box width; kept as in the original
        # because changing it would alter what the model sees — confirm intent.
        x_min = max(0, x_min - int(0.2 * bbox_height))
        y_min = max(0, y_min - int(0.2 * bbox_width))
        x_max = x_max + int(0.2 * bbox_height)
        y_max = y_max + int(0.2 * bbox_width)

        img = frame[y_min:y_max, x_min:x_max]
        if img.size == 0:
            # A degenerate box yields an empty crop, which cv2.resize rejects.
            continue

        # Scale pixel values to [0, 1] and reorder HWC -> CHW for the network.
        # NOTE(review): 244 is kept from the original; the upstream demo may
        # use 224 — confirm the intended input size.
        img = cv2.resize(img, (244, 244)) / 255.0
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float().unsqueeze(0).to(device)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            R_pred = model(img)
            euler = utils.compute_euler_angles_from_rotation_matrices(
                R_pred) * 180 / np.pi

        p_pred_deg = euler[:, 0].cpu()
        y_pred_deg = euler[:, 1].cpu()
        r_pred_deg = euler[:, 2].cpu()

        # Draw the cube centred on the (enlarged) face box.
        center_x = x_min + int(.5 * (x_max - x_min))
        center_y = y_min + int(.5 * (y_max - y_min))
        frame = utils.plot_pose_cube(
            frame, y_pred_deg, p_pred_deg, r_pred_deg,
            center_x, center_y, size=bbox_width)

    return frame
# --- Demo page text ----------------------------------------------------------
title = "6D Rotation Representation for Unconstrained Head Pose Estimation"
description = "Gradio demo for 6DRepNet. To use it, simply click the camera picture. Read more at the links below."

# Footer with links to the code and the paper (adjacent literals are joined
# into a single string at compile time).
article = (
    "<div style='text-align: center;'>"
    "<a href='https://github.com/thohemp/6DRepNet' target='_blank'>Github Repo</a>"
    " | "
    "<a href='https://arxiv.org/abs/2202.12555' target='_blank'>Paper</a>"
    "</div>"
)

# Mirror both the webcam preview and the result horizontally.
image_flip_css = """
.input-image .image-preview img{
-webkit-transform: scaleX(-1);
transform: scaleX(-1) !important;
}
.output-image img {
-webkit-transform: scaleX(-1);
transform: scaleX(-1) !important;
}
"""

# --- Gradio wiring -----------------------------------------------------------
# Webcam frames go in, the annotated image comes out; ``live=True`` re-runs
# ``predict`` as the input changes, without a submit button.
_webcam_input = gr.inputs.Image(label="Input Image", source="webcam")

iface = gr.Interface(
    fn=predict,
    inputs=_webcam_input,
    outputs='image',
    live=True,
    title=title,
    description=description,
    article=article,
    css=image_flip_css,
)

iface.launch()