# faranbutt789's picture
# Update app.py
# 60364fb verified
import torch
import torch.nn as nn
from torchvision import models, transforms
import cv2
import numpy as np
import gradio as gr
# -------------------------
# Model definition
# -------------------------
def get_model():
    """Build the VGG16-based network used for facial keypoint regression.

    A pretrained VGG16 is loaded and all of its existing parameters are
    frozen. Its ``avgpool`` and ``classifier`` stages are then replaced with
    freshly created layers, ending in 136 sigmoid-activated outputs.

    Returns:
        The assembled ``torch.nn.Module``.
    """
    net = models.vgg16(pretrained=True)

    # Freeze everything that exists so far. The layers attached below are
    # created afterwards, so they keep their default requires_grad=True.
    for weight in net.parameters():
        weight.requires_grad_(False)

    # Replacement for VGG's adaptive average pool: conv -> max-pool -> flatten.
    pooling_stage = [
        nn.Conv2d(512, 512, 3),
        nn.MaxPool2d(2),
        nn.Flatten(),
    ]
    net.avgpool = nn.Sequential(*pooling_stage)

    # Regression head. The 2048 input features equal 512 channels * 2 * 2,
    # which presumably corresponds to 224x224 inputs -- confirm if the input
    # size ever changes.
    regression_head = [
        nn.Linear(2048, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, 136),  # 68 x,y pairs
        nn.Sigmoid(),
    ]
    net.classifier = nn.Sequential(*regression_head)

    return net
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the network, move it to the selected device, restore the weights
# stored in the checkpoint file, and switch the module to evaluation mode.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
model = get_model()
model = model.to(device)
model.load_state_dict(torch.load("facial_keypoints.pth", map_location=device))
model.eval()
# -------------------------
# Image preprocessing
# -------------------------
# Per-channel statistics handed to Normalize. These are the values commonly
# used with ImageNet-pretrained torchvision models.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input image, in this order:
# array -> tensor, resize to 224x224, per-channel normalization.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)
def denormalize_keypoints(pred, img_h=224, img_w=224):
    """Convert normalized keypoint predictions into pixel coordinates.

    Args:
        pred: Tensor of shape ``(batch, 2 * k)``. The first ``k`` columns are
            treated as x coordinates and the last ``k`` columns as y
            coordinates, each expressed as a fraction of the image size.
        img_h: Image height in pixels; y values are multiplied by it.
        img_w: Image width in pixels; x values are multiplied by it.

    Returns:
        NumPy array of shape ``(batch, k, 2)`` whose last axis is ``(x, y)``.

    Raises:
        ValueError: If ``pred`` has an odd number of columns, so it cannot be
            split into matching x and y halves.
    """
    values = pred.detach().cpu().numpy()

    # Derive the split point from the prediction width instead of assuming
    # exactly 68 keypoints; for a 136-wide prediction this is still 68.
    n_columns = values.shape[1]
    if n_columns % 2:
        raise ValueError(
            f"expected an even number of prediction columns, got {n_columns}"
        )
    half = n_columns // 2

    x = values[:, :half] * img_w
    y = values[:, half:] * img_h
    return np.stack([x, y], axis=2)
# -------------------------
# Inference function for Gradio
# -------------------------
def predict_keypoints(image):
    """Predict facial keypoints for an image and draw them on a 224x224 copy.

    Args:
        image: The uploaded picture, normally a PIL image supplied by the
            Gradio input component. ``None`` is accepted and means that
            nothing was uploaded.

    Returns:
        A 224x224 RGB ``uint8`` NumPy array with one filled dot per predicted
        keypoint, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        # Submitting the form without an image passes None; there is nothing
        # to run the model on, so return an empty result instead of crashing.
        return None

    # Force three RGB channels so grayscale, RGBA or palette images do not
    # break the RGB->BGR conversion below. Array-like inputs without a
    # ``convert`` method are passed through unchanged.
    if hasattr(image, "convert"):
        image = image.convert("RGB")

    # Convert PIL -> OpenCV BGR float image scaled to [0, 1].
    # NOTE(review): the network is fed BGR-ordered channels; presumably this
    # mirrors how it was trained -- confirm before changing.
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) / 255.0
    img_resized = cv2.resize(img, (224, 224))

    # The array is already floating point and 224x224, so the pipeline's
    # resize step leaves the size unchanged here.
    input_tensor = transform(img_resized).unsqueeze(0).float().to(device)
    with torch.no_grad():
        pred = model(input_tensor)
    kps = denormalize_keypoints(pred)[0]  # first (only) item of the batch

    # Rebuild a displayable RGB uint8 image. Round before casting so values
    # such as 200.9999 do not get truncated down to 200.
    pixels = np.clip(np.rint(img_resized * 255), 0, 255).astype(np.uint8)
    vis_img = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)

    # Draw each keypoint as a small filled red dot.
    for x, y in kps:
        cv2.circle(vis_img, (int(x), int(y)), 2, (255, 0, 0), -1)
    return vis_img
# -------------------------
# Gradio Interface
# -------------------------
# Web UI: one uploaded image in (delivered to the callback as a PIL image),
# one annotated image out (returned by the callback as a NumPy array).
demo = gr.Interface(
    title="Facial Keypoints Detection",
    description="Upload a face image and the model will predict 68 facial keypoints.",
    fn=predict_keypoints,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="numpy"),
)

# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()