import torch import torch.nn as nn # ------------------------------- # PyTorch model definition # ------------------------------- class KeyPointClassifierModel(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(42, 20) self.relu1 = nn.ReLU() self.fc2 = nn.Linear(20, 10) self.relu2 = nn.ReLU() self.fc3 = nn.Linear(10, 12) # match checkpoint output classes def forward(self, x): x = self.fc1(x) x = self.relu1(x) x = self.fc2(x) x = self.relu2(x) x = self.fc3(x) return x # ------------------------------- # Wrapper class for easy usage # ------------------------------- class KeyPointClassifier: def __init__(self, model_path="keypoint_classifier_pytorch.pth", device='cpu'): self.device = device self.model = KeyPointClassifierModel() # Load the checkpoint self.model.load_state_dict(torch.load(model_path, map_location=self.device)) self.model.to(self.device) self.model.eval() def __call__(self, landmark_list): with torch.no_grad(): x = torch.tensor([landmark_list], dtype=torch.float32).to(self.device) output = self.model(x) prob = torch.softmax(output, dim=1) conf, pred = torch.max(prob, dim=1) return pred.item(), conf.item()