|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import models, transforms |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model():
    """Build the VGG16-based facial keypoint regressor.

    Starts from torchvision's VGG16 loaded with ``pretrained=True`` and freezes
    every existing parameter, so only the layers attached afterwards (the new
    pooling stage and the regression head) have ``requires_grad=True``.

    The replacement ``avgpool`` applies a 3x3 conv (no padding) and a 2x2
    max-pool, then flattens; the head expects 2048 = 512 * 2 * 2 features, which
    assumes a 512x7x7 feature map, i.e. 224x224 inputs (TODO confirm if input
    size changes). The head ends in a Sigmoid, so each of the 136 outputs lies
    in [0, 1].

    Returns:
        The modified VGG16 ``nn.Module``.
    """
    net = models.vgg16(pretrained=True)
    for weight in net.parameters():
        weight.requires_grad = False

    # Replaces VGG's pooling stage; ends in Flatten so the head gets a vector.
    pool_layers = [
        nn.Conv2d(512, 512, 3),
        nn.MaxPool2d(2),
        nn.Flatten(),
    ]
    net.avgpool = nn.Sequential(*pool_layers)

    # Regression head: 2048 features -> 136 values squashed into [0, 1].
    head_layers = [
        nn.Linear(2048, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, 136),
        nn.Sigmoid(),
    ]
    net.classifier = nn.Sequential(*head_layers)
    return net
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the network, move it to the chosen device, then restore the trained
# weights; map_location remaps the stored tensors onto that same device.
model = get_model().to(device)
model.load_state_dict(
    torch.load("facial_keypoints.pth", map_location=device)
)
# Inference only: switches layers such as the head's Dropout to eval behavior.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Preprocessing for an HxWx3 image array: convert to a CxHxW tensor, resize to
# 224x224, then normalize each channel with the usual ImageNet mean/std.
# NOTE(review): these statistics are RGB-ordered, but predict_keypoints feeds
# BGR channels; presumably this matches how the checkpoint was trained — confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
def denormalize_keypoints(pred, img_h=224, img_w=224):
    """Convert normalized keypoint predictions to pixel coordinates.

    Each row of ``pred`` holds all K x-coordinates followed by all K
    y-coordinates, each relative to the image size (the model head emits
    2 * 68 = 136 sigmoid outputs, so K = 68 for this app).

    Args:
        pred: Predictions of shape (batch, 2 * K), either a ``torch.Tensor``
            (any device, may require grad) or anything ``np.asarray`` accepts.
        img_h: Image height in pixels; scales the y-coordinates.
        img_w: Image width in pixels; scales the x-coordinates.

    Returns:
        ``np.ndarray`` of shape (batch, K, 2) with (x, y) pixel pairs.

    Raises:
        ValueError: if ``pred`` is not 2-D with an even number of columns.
    """
    if isinstance(pred, torch.Tensor):
        # Detach from autograd and move to host memory before converting.
        pred = pred.detach().cpu().numpy()
    else:
        pred = np.asarray(pred)

    if pred.ndim != 2 or pred.shape[1] % 2:
        raise ValueError(
            "expected predictions of shape (batch, 2 * K), got "
            f"{tuple(pred.shape)}"
        )

    # First half of the columns are x, second half are y.
    n_kp = pred.shape[1] // 2
    x = pred[:, :n_kp] * img_w
    y = pred[:, n_kp:] * img_h
    return np.stack([x, y], axis=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_keypoints(image):
    """Predict facial keypoints for one image and draw them on a 224x224 copy.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``), an
            HxWx3 RGB array, or ``None`` when no image was provided.

    Returns:
        A 224x224 RGB ``uint8`` numpy array with every predicted keypoint drawn
        as a filled red dot of radius 2, or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        # Nothing to annotate (e.g. the user submitted an empty input).
        return None

    if hasattr(image, "convert"):
        # Force 3-channel RGB: grayscale ("L"), palette ("P") or CMYK images
        # would otherwise yield arrays the RGB->BGR conversion cannot handle.
        image = image.convert("RGB")
    rgb = np.array(image)

    # Model input: BGR channels scaled to [0, 1], resized to 224x224.
    # NOTE(review): assumes the checkpoint was trained on BGR images with this
    # same normalization — confirm before changing the channel order.
    img = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) / 255.0
    img_resized = cv2.resize(img, (224, 224))
    input_tensor = transform(img_resized).unsqueeze(0).float().to(device)

    with torch.no_grad():
        pred = model(input_tensor)

    # Pixel coordinates in the 224x224 resized frame (the helper's default size).
    kps = denormalize_keypoints(pred)[0]

    vis_img = cv2.cvtColor((img_resized * 255).astype(np.uint8), cv2.COLOR_BGR2RGB)
    for x, y in kps:
        cv2.circle(vis_img, (int(x), int(y)), 2, (255, 0, 0), -1)

    return vis_img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Web UI: one uploaded image in (as a PIL image), the annotated image out
# (as a numpy array), wired to predict_keypoints.
demo = gr.Interface(
    title="Facial Keypoints Detection",
    description="Upload a face image and the model will predict 68 facial keypoints.",
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="numpy"),
    fn=predict_keypoints,
)


if __name__ == "__main__":
    # Only start the server when executed as a script, not when imported.
    demo.launch()
|
|
|