File size: 4,607 Bytes
5db43ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
import os
#import pprint
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
#import tools.init_paths
from lib.config import cfg
from lib.config import update_config
#from core.loss import JointsMSELoss
#from core.function import validate
#from core.inference import get_max_preds
#from utils.utils import create_logger

#import dataset
#from datasets.mannequin import MannequinDataset
import lib.models as models
import numpy as np

def parse_args():
    """Build the keypoint-network argument parser and return its defaults.

    The parser is always run against an empty argument list, so the process
    command line is never read and every field holds its default value.

    Returns:
        argparse.Namespace: namespace with the fields ``cfg``, ``opts``,
        ``modelDir``, ``logDir``, ``dataDir`` and ``prevModelDir``.
    """
    cli = argparse.ArgumentParser(description='Train keypoints network')

    # general
    cli.add_argument(
        '--cfg',
        type=str,
        # required=True,
        default='experiments/mpii/hrnet/w32_256x256_adam_lr1e-3.yaml',
        help='experiment configure file name',
    )
    cli.add_argument(
        'opts',
        nargs=argparse.REMAINDER,
        default=None,
        help="Modify config options using the command-line",
    )

    # The directory overrides all have the same shape: an optional string
    # flag whose default is the empty string.
    directory_flags = (
        ('--modelDir', 'model directory'),
        ('--logDir', 'log directory'),
        ('--dataDir', 'data directory'),
        ('--prevModelDir', 'prev Model directory'),
    )
    for flag, description in directory_flags:
        cli.add_argument(flag, help=description, type=str, default='')

    # Deliberately parse an empty list rather than sys.argv.
    return cli.parse_args(args=[])


class HumanPose():
    """Inference wrapper around the pose network selected by ``cfg.MODEL.NAME``.

    Construction loads the checkpoint (downloading it first when the file is
    missing), puts the network in eval mode and optionally moves it to the
    GPU. Calling the instance forwards the input to the network.
    """

    def __init__(self, ckpt=None, useCUDA=True):
        """Load the network.

        Args:
            ckpt: Path to the checkpoint file. ``None`` selects
                ``pretrained_models/pose_hrnet/human_pose.pth``.
            useCUDA: Move the network to the GPU when CUDA is available.
                When CUDA is not available the network stays on the CPU.
        """
        self.useCUDA = useCUDA

        self.model = self._load_model(ckpt=ckpt)

    def __call__(self, img):
        """Run the network on ``img`` and return its raw output.

        No preprocessing or device transfer is done here; the caller is
        responsible for passing an input the network accepts.
        """
        return self.model(img)

    @staticmethod
    def _ensure_checkpoint(ckpt):
        """Download the pretrained weights to ``ckpt`` if the file is missing.

        The same Google Drive file is fetched whatever path was requested.

        Raises:
            FileNotFoundError: If the file still does not exist after the
                download attempt.
        """
        if os.path.exists(ckpt):
            return

        print("Model not found, downloading from google drive...")
        # Imported lazily so gdown is only required when a download happens.
        import gdown

        directory = os.path.dirname(ckpt)
        # A bare file name has no directory part; os.makedirs('') would raise.
        if directory:
            os.makedirs(directory, exist_ok=True)

        file_id = "1_wn2ifmoQprBrFvUCDedjPON4Y6jsN-v"
        gdown.download(
            "https://drive.google.com/uc?export=download&confirm=pbef&id=" + file_id,
            ckpt,
        )
        # Fail here with a clear message instead of later inside torch.load.
        if not os.path.exists(ckpt):
            raise FileNotFoundError(
                "Checkpoint download did not produce a file at: " + ckpt
            )

    def _load_model(self, ckpt):
        """Build the network from the config and load ``ckpt`` into it.

        Args:
            ckpt: Checkpoint path, or ``None`` for the default location.

        Returns:
            The network in eval mode, on the GPU when ``useCUDA`` is set and
            CUDA is available, otherwise on the CPU.
        """
        if ckpt is None:
            ckpt = 'pretrained_models/pose_hrnet/human_pose.pth'
        self._ensure_checkpoint(ckpt)

        # Only the parser defaults are used; extra config overrides are cleared.
        args = parse_args()
        args.opts = []
        update_config(cfg, args)

        # cudnn related setting
        cudnn.benchmark = cfg.CUDNN.BENCHMARK
        torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
        torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

        # Attribute lookup instead of eval(): the config string is never
        # executed as code.
        model = getattr(models, cfg.MODEL.NAME).get_pose_net(
            cfg, is_train=False
        )

        # Always deserialize onto the CPU: a checkpoint saved from a GPU would
        # otherwise fail to load on a host without CUDA even though the
        # transfer below is guarded. The network is moved afterwards.
        state_dict = torch.load(ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict, strict=True)

        model.eval()
        if self.useCUDA and torch.cuda.is_available():
            model = model.cuda()
        return model


class HumanJointHeatmap:
    """Produce heatmaps for a fixed subset of joints from an image.

    Wraps a default-constructed ``HumanPose`` detector with a normalize and
    resize step in front of it and a resize step behind it.
    """

    def __init__(self):
        """Create the detector and the pre/post transforms."""
        self.joint_detector = HumanPose()
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        resize2 = transforms.Resize((256, 256))
        # Applied in this order: normalize first, then resize to 256x256.
        self.joint_pre_transform = transforms.Compose([normalize, resize2])
        # Heatmaps are resized to 512x512 before being returned.
        self.post_transform = transforms.Resize((512, 512))
        # Indices of the detector output channels that are kept.
        self.joint_id_list = [2, 3, 6, 7, 11, 12, 13, 14]

    def _detector_device(self):
        """Return the device holding the detector's weights, or ``None``.

        ``None`` means the device could not be determined (the detector has
        no ``model`` module or the module has no parameters).
        """
        network = getattr(self.joint_detector, 'model', None)
        if not isinstance(network, torch.nn.Module):
            return None
        first_weight = next(network.parameters(), None)
        return None if first_weight is None else first_weight.device

    def forward(self, img):
        """Compute the selected joint heatmaps for ``img``.

        Args:
            img: Either a ``torch.Tensor`` that is passed on unchanged apart
                from a device transfer (assumed to already be a batched,
                channel-first, 0-1 scaled image -- TODO confirm with callers),
                or an array-like image that is divided by 255, has its last
                axis moved to the front and gets a leading batch axis (so it
                is expected to be height x width x channel).

        Returns:
            torch.Tensor: the detector output restricted to the channels in
            ``joint_id_list`` and passed through ``post_transform``.
        """
        if not isinstance(img, torch.Tensor):
            # The division allocates a new array, so the caller's image is
            # never modified.
            scaled = np.asarray(img, dtype=np.float32) / 255.0
            img = torch.from_numpy(scaled).permute(2, 0, 1).unsqueeze(0)

        # Follow the detector's actual device rather than guessing from CUDA
        # availability; this also covers tensors handed in on another device.
        device = self._detector_device()
        if device is not None:
            img = img.to(device)

        with torch.no_grad():
            heatmaps = self.joint_detector(self.joint_pre_transform(img))
        heatmaps = heatmaps[:, self.joint_id_list, :, :]
        return self.post_transform(heatmaps)


if __name__ == '__main__':
    # Smoke test: push an all-zero 256x256 image through the network and
    # print the shape of the result.
    human_pose = HumanPose()
    # Create the input on the device the weights ended up on; HumanPose moves
    # the network to the GPU when CUDA is available, and a CPU input would
    # then fail with a device mismatch.
    model_device = next(human_pose.model.parameters()).device
    example_input = torch.zeros(1, 3, 256, 256, device=model_device)
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        output = human_pose(example_input)
    print(output.shape)