File size: 1,432 Bytes
c0edf01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn

# -------------------------------
# PyTorch model definition
# -------------------------------
class KeyPointClassifierModel(nn.Module):
    """Small MLP classifier: input_size -> 20 -> 10 -> num_classes with ReLU activations.

    The attribute names ``fc1``/``fc2``/``fc3`` define the ``state_dict`` keys,
    so they must stay unchanged for existing checkpoints to keep loading.

    Args:
        input_size: Number of input features per sample (default 42).
        num_classes: Number of output logits (default 12, matching the
            existing checkpoint's output layer).
    """

    def __init__(self, input_size=42, num_classes=12):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 20)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(20, 10)
        self.relu2 = nn.ReLU()
        # Output width must match the checkpoint being loaded.
        self.fc3 = nn.Linear(10, num_classes)

    def forward(self, x):
        """Map ``x`` of shape (..., input_size) to raw logits of shape (..., num_classes)."""
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)

# -------------------------------
# Wrapper class for easy usage
# -------------------------------
class KeyPointClassifier:
    """Load a KeyPointClassifierModel checkpoint and classify one sample per call.

    Args:
        model_path: Path to a file containing a ``state_dict`` for
            ``KeyPointClassifierModel`` (read with ``torch.load``).
        device: Device the model and the inputs are placed on, e.g. ``'cpu'``.
    """

    def __init__(self, model_path="keypoint_classifier_pytorch.pth", device='cpu'):
        self.device = device
        self.model = KeyPointClassifierModel()
        # NOTE(security): torch.load unpickles the file, which can execute code.
        # Only load trusted checkpoints; on torch >= 1.13 consider passing
        # weights_only=True, since only a state_dict is expected here.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()  # evaluation mode for inference

    def __call__(self, landmark_list):
        """Classify a single landmark vector.

        Args:
            landmark_list: One sample, either as a flat sequence of floats or as
                any array-like (nested list of pairs, NumPy array, tensor) whose
                elements, flattened in row-major order, match the model's input
                size (42 for the default model).

        Returns:
            ``(pred, conf)``: the index of the most probable class (``int``) and
            its softmax probability (``float``).
        """
        with torch.no_grad():
            # as_tensor accepts lists, NumPy arrays and tensors alike (unlike
            # torch.tensor([...]) which fails on a tensor input). reshape adds the
            # batch dimension and flattens nested inputs such as [[x, y], ...].
            x = torch.as_tensor(landmark_list, dtype=torch.float32, device=self.device)
            x = x.reshape(1, -1)
            logits = self.model(x)
            prob = torch.softmax(logits, dim=1)
            conf, pred = torch.max(prob, dim=1)
        return pred.item(), conf.item()