"""Gradio demo that predicts 68 facial keypoints with a VGG16-based regressor.

The fine-tuned weights are read from ``facial_keypoints.pth`` at import time.
"""

import torch
import torch.nn as nn
from torchvision import models, transforms
import cv2
import numpy as np
import gradio as gr

# Number of facial landmarks the head predicts (one x and one y each).
NUM_KEYPOINTS = 68
# Side length, in pixels, of the square image fed to the network.
INPUT_SIZE = 224
# Checkpoint loaded at startup.
CHECKPOINT_PATH = "facial_keypoints.pth"


# -------------------------
# Model definition
# -------------------------
def _build_vgg16(pretrained):
    """Return torchvision's VGG16, optionally with ImageNet weights.

    Uses the ``weights=`` API when available (torchvision >= 0.13, where
    ``pretrained=`` is deprecated) and falls back to ``pretrained=`` otherwise.
    """
    weights_enum = getattr(models, "VGG16_Weights", None)
    if weights_enum is None:
        # Older torchvision only understands the legacy boolean flag.
        return models.vgg16(pretrained=pretrained)
    return models.vgg16(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)


def get_model(pretrained=True):
    """Build a VGG16 backbone with a facial-keypoint regression head.

    The original VGG16 parameters are frozen; the replacement ``avgpool`` and
    ``classifier`` modules (created after freezing) stay trainable.

    Args:
        pretrained: If True (default, the original behaviour), initialise the
            backbone with torchvision's ImageNet weights. Pass False when a
            complete checkpoint is loaded afterwards, to skip the download.

    Returns:
        An ``nn.Module`` whose output has ``2 * NUM_KEYPOINTS`` values in
        (0, 1): first all x fractions, then all y fractions (see
        ``denormalize_keypoints``).
    """
    model = _build_vgg16(pretrained)
    for param in model.parameters():
        param.requires_grad = False
    # For a 224x224 input VGG16's features are 512x7x7; the unpadded 3x3 conv
    # gives 5x5 and max-pooling 2x2, hence 512 * 2 * 2 = 2048 flat features.
    model.avgpool = nn.Sequential(
        nn.Conv2d(512, 512, 3),
        nn.MaxPool2d(2),
        nn.Flatten(),
    )
    model.classifier = nn.Sequential(
        nn.Linear(2048, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, 2 * NUM_KEYPOINTS),  # 68 x values followed by 68 y values
        nn.Sigmoid(),  # coordinates as fractions of image width/height
    )
    return model


device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model. load_state_dict is strict, so the checkpoint overwrites every
# parameter including the backbone; downloading ImageNet weights first would
# be wasted work, hence pretrained=False.
model = get_model(pretrained=False).to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()


# -------------------------
# Image preprocessing
# -------------------------
# Expects an HxWx3 float image already scaled to [0, 1] (ToTensor does not
# rescale non-uint8 arrays). predict_keypoints resizes to 224x224 beforehand,
# so Resize is effectively a no-op there. The ImageNet mean/std are applied in
# the array's channel order, which is BGR in predict_keypoints — presumably
# matching the training pipeline; TODO confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def denormalize_keypoints(pred, img_h=INPUT_SIZE, img_w=INPUT_SIZE):
    """Convert the model's fractional outputs into pixel coordinates.

    Args:
        pred: Tensor of shape (batch, 136): 68 x fractions then 68 y fractions.
        img_h: Height in pixels to scale the y fractions by.
        img_w: Width in pixels to scale the x fractions by.

    Returns:
        ``np.ndarray`` of shape (batch, 68, 2) holding (x, y) per keypoint.
    """
    pred = pred.detach().cpu().numpy()
    x = pred[:, :NUM_KEYPOINTS] * img_w
    y = pred[:, NUM_KEYPOINTS:] * img_h
    return np.stack([x, y], axis=2)


# -------------------------
# Inference function for Gradio
# -------------------------
def predict_keypoints(image):
    """Predict facial keypoints and draw them on a 224x224 copy of *image*.

    Args:
        image: PIL image from the Gradio input, or None if nothing was given.

    Returns:
        A 224x224 RGB ``uint8`` array with red dots at the predicted
        keypoints, or None when *image* is None.
    """
    if image is None:
        # Gradio may invoke fn with None when the input is empty/cleared.
        return None

    # Force 3 channels: uploads can be grayscale, palette or RGBA images.
    rgb = np.array(image.convert("RGB"))
    # BGR floats in [0, 1], as the original pipeline fed the model
    # (presumably matching how it was trained).
    img = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) / 255.0
    img_resized = cv2.resize(img, (INPUT_SIZE, INPUT_SIZE))
    input_tensor = transform(img_resized).unsqueeze(0).float().to(device)

    with torch.no_grad():
        pred = model(input_tensor)
    kps = denormalize_keypoints(pred)[0]  # the batch holds a single image

    # Draw on the resized image converted back to RGB for display.
    vis_img = cv2.cvtColor((img_resized * 255).astype(np.uint8), cv2.COLOR_BGR2RGB)
    for x, y in kps:
        cv2.circle(vis_img, (int(x), int(y)), 2, (255, 0, 0), -1)
    return vis_img


# -------------------------
# Gradio Interface
# -------------------------
demo = gr.Interface(
    fn=predict_keypoints,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="numpy"),
    title="Facial Keypoints Detection",
    description="Upload a face image and the model will predict 68 facial keypoints.",
)

if __name__ == "__main__":
    demo.launch()