File size: 3,917 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import copy
import os
import pickle
import time
from argparse import ArgumentParser

import cv2
import numpy as np
import torch
from mmdet.apis import init_detector
from mmengine.dataset import Compose, pseudo_collate
from mmengine.registry import init_default_scope
from mmpose.apis import init_model
from PIL import Image


def parse_args(argv=None):
    """Parse the command-line options of the keypoint-extraction script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so a
            plain ``parse_args()`` call behaves exactly as before.

    Returns:
        argparse.Namespace with the attributes ``root``, ``pose_config``,
        ``pose_ckpt``, ``det_config``, ``det_ckpt`` and ``device``. The
        config/checkpoint options are ``None`` when not given on the
        command line; ``device`` defaults to ``'cuda:0'``.
    """
    parser = ArgumentParser()
    parser.add_argument('root', help='Video folder root')
    parser.add_argument('--pose_config', help='Pose config file')
    parser.add_argument('--pose_ckpt', help='Pose checkpoint file')
    parser.add_argument('--det_config', help='Hand detection config file')
    parser.add_argument('--det_ckpt', help='Hand detection checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    return parser.parse_args(argv)


@torch.no_grad()
def inference_topdown(model, pose_pipeline, det_model, det_pipeline, folder):
    """Detect a box in every image of ``folder`` and estimate keypoints in it.

    All images of the folder are pushed through the detector in one batch.
    For each image the first predicted box decides the region handed to the
    pose model:

    * score > 0.5: the detected box is used as is;
    * 0.3 < score <= 0.5: the detected box is blended with the full frame;
    * otherwise (or when nothing was detected): the full frame is used.

    The pose model then processes all images in one batch as well.

    Args:
        model: Pose estimator; must provide ``dataset_meta`` and
            ``test_step``.
        pose_pipeline: Callable applied to each pose ``data_info`` dict.
        det_model: Detector; must provide ``test_step``.
        det_pipeline: Callable applied to each detector input dict
            (``img`` array plus ``img_id``).
        folder: Directory whose entries are all read as images.

    Returns:
        dict mapping each image path (``f'{folder}/{name}'``) to a tuple
        ``(keypoints, keypoint_scores, (w, h))`` where ``(w, h)`` is the
        size of the first listed image. An empty dict is returned when the
        folder has no entries.
    """
    img_paths = [f'{folder}/{img}' for img in os.listdir(folder)]
    if not img_paths:
        # Nothing to process; avoids an IndexError on img_paths[0] below.
        return {}

    # The size is read from the first image only.
    # NOTE(review): this assumes every frame in the folder has the same
    # resolution - confirm for the dataset in use.
    w, h = Image.open(img_paths[0]).size
    full_frame = np.array([[0, 0, w, h]], dtype=np.float32)

    imgs = [cv2.imread(img_path) for img_path in img_paths]

    # The detector pipeline receives a copy, so the arrays kept in ``imgs``
    # are the ones later handed to the pose pipeline.
    det_inputs = [
        det_pipeline(dict(img=copy.deepcopy(img), img_id=idx))
        for idx, img in enumerate(imgs)
    ]
    det_results = det_model.test_step(pseudo_collate(det_inputs))

    pose_inputs = []
    for img, det_result in zip(imgs, det_results):
        instances = det_result.pred_instances
        # Keep at most the first predicted box; the slice yields an empty
        # array (different shape from ``full_frame``) when nothing was found.
        bbox = instances.bboxes[:1].cpu().numpy()
        det_score = instances.scores[0].item() if len(instances.scores) else 0

        data_info = dict(img=img)
        if bbox.shape == full_frame.shape and det_score > 0.3:
            if det_score > 0.5:
                data_info['bbox'] = bbox
            else:
                # Blend weight grows from 0.5 (score just above 0.3) to 1.0
                # (score 0.5). It must not reuse the name ``w``: that is
                # the image width stored in the returned lookup.
                blend = (det_score - 0.1) / 0.4
                data_info['bbox'] = blend * bbox + (1 - blend) * full_frame
        else:
            data_info['bbox'] = full_frame
        data_info['bbox_score'] = np.ones(1, dtype=np.float32)  # shape (1,)
        data_info.update(model.dataset_meta)
        pose_inputs.append(pose_pipeline(data_info))

    pose_results = model.test_step(pseudo_collate(pose_inputs))

    lookup = {}
    for img_path, result in zip(img_paths, pose_results):
        pred = result.pred_instances
        lookup[img_path] = (pred.keypoints, pred.keypoint_scores, (w, h))
    return lookup


def main():
    """Extract keypoints for every video folder under root and pickle them.

    Builds the detector and the pose model from the command-line options,
    runs ``inference_topdown`` on each sub-directory of ``args.root`` and
    writes the merged ``{image_path: result}`` dict to ``jester.pkl`` in the
    current working directory. Every 100 folders an estimate of the
    remaining time is printed.
    """
    args = parse_args()

    det_model = init_detector(
        args.det_config, args.det_ckpt, device=args.device)
    # Images are passed to the detector as already-decoded arrays, so the
    # first pipeline step is switched to the ndarray loader.
    det_model.cfg.test_dataloader.dataset.pipeline[
        0].type = 'mmdet.LoadImageFromNDArray'
    det_pipeline = Compose(det_model.cfg.test_dataloader.dataset.pipeline)

    # The CLI option is ``--pose_ckpt``; ``args.pose_checkpoint`` does not
    # exist and raised AttributeError.
    model = init_model(args.pose_config, args.pose_ckpt, device=args.device)
    init_default_scope(model.cfg.get('default_scope', 'mmpose'))
    pose_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

    # Only directories are video folders; stray files under root would make
    # os.listdir() fail inside inference_topdown.
    candidates = (f'{args.root}/{name}' for name in os.listdir(args.root))
    folders = [path for path in candidates if os.path.isdir(path)]

    # Run detection + pose estimation folder by folder.
    lookup = {}
    total = len(folders)
    start = time.time()
    for idx, folder in enumerate(folders):
        results = inference_topdown(model, pose_pipeline, det_model,
                                    det_pipeline, folder)
        lookup.update(results)
        done = idx + 1
        if done % 100 == 0:
            # Average time per finished folder times the folders still left.
            eta = (time.time() - start) / done * (total - done) / 3600
            print('Require %.2f hours' % eta)

    with open('jester.pkl', 'wb') as f:
        pickle.dump(lookup, f)


# Run the extraction only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()