faranbutt789 commited on
Commit
9c6ae09
·
verified ·
1 Parent(s): 5fe0c4a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ import cv2
5
+ import numpy as np
6
+ import gradio as gr
7
+
8
+ # -------------------------
9
+ # Model definition
10
+ # -------------------------
11
+ def get_model():
12
+ model = models.vgg16(pretrained=True)
13
+ for param in model.parameters():
14
+ param.requires_grad = False
15
+ model.avgpool = nn.Sequential(
16
+ nn.Conv2d(512,512,3),
17
+ nn.MaxPool2d(2),
18
+ nn.Flatten()
19
+ )
20
+ model.classifier = nn.Sequential(
21
+ nn.Linear(2048,512),
22
+ nn.ReLU(),
23
+ nn.Dropout(0.5),
24
+ nn.Linear(512,136), # 68 x,y pairs
25
+ nn.Sigmoid()
26
+ )
27
+ return model
28
+
29
+ device = "cuda" if torch.cuda.is_available() else "cpu"
30
+
31
+ # Load model
32
+ model = get_model().to(device)
33
+ model.load_state_dict(torch.load("facial_keypoints.pth", map_location=device))
34
+ model.eval()
35
+
36
+ # -------------------------
37
+ # Image preprocessing
38
+ # -------------------------
39
+ transform = transforms.Compose([
40
+ transforms.ToTensor(),
41
+ transforms.Resize((224,224)),
42
+ transforms.Normalize(mean=[0.485,0.456,0.406],
43
+ std=[0.229,0.224,0.225])
44
+ ])
45
+
46
+ def denormalize_keypoints(pred, img_h=224, img_w=224):
47
+ pred = pred.detach().cpu().numpy()
48
+ x = pred[:,:68] * img_w
49
+ y = pred[:,68:] * img_h
50
+ return np.stack([x,y], axis=2)
51
+
52
+ # -------------------------
53
+ # Inference function for Gradio
54
+ # -------------------------
55
+ def predict_keypoints(image):
56
+ # Convert PIL → CV2 → tensor
57
+ img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) / 255.0
58
+ img_resized = cv2.resize(img, (224,224))
59
+ input_tensor = transform(img_resized).unsqueeze(0).to(device)
60
+
61
+ with torch.no_grad():
62
+ pred = model(input_tensor)
63
+ kps = denormalize_keypoints(pred)[0] # first batch only
64
+
65
+ # Draw keypoints on image
66
+ vis_img = cv2.cvtColor((img_resized*255).astype(np.uint8), cv2.COLOR_BGR2RGB)
67
+ for (x,y) in kps:
68
+ cv2.circle(vis_img, (int(x), int(y)), 2, (255,0,0), -1)
69
+
70
+ return vis_img
71
+
72
+ # -------------------------
73
+ # Gradio Interface
74
+ # -------------------------
75
+ demo = gr.Interface(
76
+ fn=predict_keypoints,
77
+ inputs=gr.Image(type="pil"),
78
+ outputs=gr.Image(type="numpy"),
79
+ title="Facial Keypoints Detection",
80
+ description="Upload a face image and the model will predict 68 facial keypoints."
81
+ )
82
+
83
+ if __name__ == "__main__":
84
+ demo.launch()