# keypoint_classifier_pytorch / keypoint_classifier_pytorch.py
# Author: BobbyDUVA
# Commit: "pytorch initial commit (#1)" (c0edf01, verified)
import torch
import torch.nn as nn
# -------------------------------
# PyTorch model definition
# -------------------------------
class KeyPointClassifierModel(nn.Module):
    """Three-layer MLP mapping a flat feature vector to class logits.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear. No softmax is
    applied here; ``forward`` returns raw logits.

    The submodule attribute names (``fc1``, ``relu1``, ``fc2``, ``relu2``,
    ``fc3``) define the ``state_dict`` keys, so they must not be renamed or
    existing checkpoints will no longer load.

    Args:
        input_size: Number of input features per sample. Defaults to 42
            (presumably 21 keypoints x (x, y) -- confirm against the
            preprocessing code).
        hidden1: Width of the first hidden layer. Defaults to 20.
        hidden2: Width of the second hidden layer. Defaults to 10.
        num_classes: Number of output logits. Defaults to 12 to match the
            shipped checkpoint.
    """

    def __init__(self, *, input_size=42, hidden1=20, hidden2=10, num_classes=12):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden2, num_classes)  # match checkpoint output classes

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Float tensor whose last dimension is ``input_size``.

        Returns:
            Tensor of raw logits whose last dimension is ``num_classes``.
        """
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)
# -------------------------------
# Wrapper class for easy usage
# -------------------------------
class KeyPointClassifier:
    """Load a ``KeyPointClassifierModel`` checkpoint and classify one sample.

    Args:
        model_path: Path to a file holding a ``state_dict`` for
            ``KeyPointClassifierModel`` (default architecture).
        device: Torch device the weights are mapped to and inference runs on.
    """

    def __init__(self, model_path="keypoint_classifier_pytorch.pth", device='cpu'):
        self.device = device
        self.model = KeyPointClassifierModel()
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered .pth file cannot run arbitrary code on
        # load. A state_dict needs nothing more. (Argument exists since
        # torch 1.13 and is the default from torch 2.6.)
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, landmark_list):
        """Classify a single feature vector.

        Args:
            landmark_list: One sample's features as a sequence of numbers,
                a NumPy array, or a 1-D tensor. Its length must equal the
                model's input size (42 for the default model).

        Returns:
            Tuple ``(class_index, confidence)`` as Python ``int`` and
            ``float``, where ``confidence`` is the softmax probability of
            the predicted class.
        """
        with torch.no_grad():
            # as_tensor handles lists, NumPy arrays and tensors alike; the
            # previous torch.tensor([landmark_list]) raised for multi-element
            # tensors. unsqueeze(0) adds the batch dimension of size 1.
            x = torch.as_tensor(landmark_list, dtype=torch.float32, device=self.device).unsqueeze(0)
            logits = self.model(x)
            prob = torch.softmax(logits, dim=1)
            conf, pred = torch.max(prob, dim=1)
        return pred.item(), conf.item()