# AZIIIIIIIIZ's picture
# Upload 1039 files
# d670799 verified
# raw
# history blame
# 3.92 kB
import copy
import os
import pickle
import time
from argparse import ArgumentParser
import cv2
import numpy as np
import torch
from mmdet.apis import init_detector
from mmengine.dataset import Compose, pseudo_collate
from mmengine.registry import init_default_scope
from mmpose.apis import init_model
from PIL import Image
def parse_args():
    """Read the command-line options of the keypoint extraction script.

    Returns:
        argparse.Namespace: Carries ``root`` (positional), the optional
        ``pose_config``, ``pose_ckpt``, ``det_config`` and ``det_ckpt``
        values (``None`` when omitted) and ``device`` (``'cuda:0'`` when
        omitted).
    """
    cli = ArgumentParser()
    cli.add_argument('root', help='Video folder root')

    # The four model-file options only differ in flag name and help text.
    model_file_options = (
        ('--pose_config', 'Pose config file'),
        ('--pose_ckpt', 'Pose checkpoint file'),
        ('--det_config', 'Hand detection config file'),
        ('--det_ckpt', 'Hand detection checkpoint file'),
    )
    for flag, description in model_file_options:
        cli.add_argument(flag, help=description)

    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    return cli.parse_args()
@torch.no_grad()
def inference_topdown(model, pose_pipeline, det_model, det_pipeline, folder):
    """Detect a box in every frame of ``folder`` and estimate keypoints.

    Every entry of ``folder`` is read as an image and sent through the
    detector in one batch. The first detection of each frame decides the box
    handed to the pose model: used as-is, blended with the full-frame box, or
    replaced by the full-frame box, depending on its score.

    Args:
        model: Pose estimator exposing ``dataset_meta`` and ``test_step``.
        pose_pipeline: Callable that turns a data-info dict (``img``,
            ``bbox``, ``bbox_score`` plus the dataset meta) into a sample.
        det_model: Detector exposing ``test_step``.
        det_pipeline: Callable that turns ``dict(img=..., img_id=...)`` into
            a detector sample.
        folder (str): Directory whose entries are all treated as images.

    Returns:
        dict: Maps each image path to
        ``(keypoints, keypoint_scores, (width, height))`` where the size is
        that of the first listed image. Empty when ``folder`` has no entries.
    """
    img_paths = [f'{folder}/{img}' for img in os.listdir(folder)]
    if not img_paths:
        # Nothing to process; indexing img_paths[0] below would raise.
        return {}

    # Only the first image is measured.
    # NOTE(review): assumes all frames of a folder share one size - confirm.
    with Image.open(img_paths[0]) as first_frame:
        width, height = first_frame.size
    full_frame_bbox = np.array([[0, 0, width, height]], dtype=np.float32)

    imgs = [cv2.imread(img_path) for img_path in img_paths]

    # Detection pass; the detector pipeline receives a copy of each frame so
    # the original array is the one later handed to the pose pipeline.
    det_inputs = [
        det_pipeline(dict(img=copy.deepcopy(img), img_id=idx))
        for idx, img in enumerate(imgs)
    ]
    det_results = det_model.test_step(pseudo_collate(det_inputs))

    pose_inputs = []
    for img, det_result in zip(imgs, det_results):
        instances = det_result.pred_instances
        # Keep at most one box: shape (1, 4), or (0, 4) without detections.
        bbox = instances.bboxes[:1].cpu().numpy()
        # A frame without detections counts as score 0.
        det_scores = instances.scores
        det_score = det_scores[0].item() if len(det_scores) > 0 else 0

        data_info = dict(img=img)
        if bbox.shape == full_frame_bbox.shape and det_score > 0.3:
            if det_score > 0.5:
                data_info['bbox'] = bbox
            else:
                # Scores in (0.3, 0.5] give a weight in (0.5, 1.0]: the less
                # certain the detection, the more the full frame is mixed
                # in. A separate name is used so the image width stays
                # intact for the returned (width, height).
                blend = (det_score - 0.1) / 0.4
                data_info['bbox'] = (
                    blend * bbox + (1 - blend) * full_frame_bbox)
        else:
            data_info['bbox'] = full_frame_bbox
        data_info['bbox_score'] = np.ones(1, dtype=np.float32)  # shape (1,)
        data_info.update(model.dataset_meta)
        pose_inputs.append(pose_pipeline(data_info))

    pose_results = model.test_step(pseudo_collate(pose_inputs))

    lookup = {}
    for img_path, pose_result in zip(img_paths, pose_results):
        keypoints = pose_result.pred_instances.keypoints
        keypoint_scores = pose_result.pred_instances.keypoint_scores
        lookup[img_path] = (keypoints, keypoint_scores, (width, height))
    return lookup
def main():
    """Run keypoint extraction over every folder under the root and save it.

    Builds the detector and the pose model from the command-line options,
    calls ``inference_topdown`` on each entry of ``args.root``, merges the
    per-folder results into one dict and pickles it to ``jester.pkl`` in the
    current working directory. Progress is reported every 100 folders.
    """
    args = parse_args()

    det_model = init_detector(
        args.det_config, args.det_ckpt, device=args.device)
    # Frames are passed to the detector as in-memory arrays, so the first
    # pipeline step is switched to the ndarray loader.
    det_model.cfg.test_dataloader.dataset.pipeline[
        0].type = 'mmdet.LoadImageFromNDArray'
    det_pipeline = Compose(det_model.cfg.test_dataloader.dataset.pipeline)

    # The parser defines ``--pose_ckpt``; ``args.pose_checkpoint`` does not
    # exist and raised AttributeError.
    model = init_model(
        args.pose_config, args.pose_ckpt, device=args.device)
    init_default_scope(model.cfg.get('default_scope', 'mmpose'))
    folders = [f'{args.root}/{folder}' for folder in os.listdir(args.root)]
    pose_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

    # Run inference folder by folder and merge the per-image results.
    lookup = {}
    num_folders = len(folders)
    start = time.time()
    for idx, folder in enumerate(folders):
        results = inference_topdown(model, pose_pipeline, det_model,
                                    det_pipeline, folder)
        lookup.update(results)
        if idx % 100 == 99:
            # Average time per finished folder times the folders still left.
            done = idx + 1
            eta = (time.time() - start) / done * (num_folders - done) / 3600
            print('Require %.2f hours' % eta)

    with open('jester.pkl', 'wb') as f:
        pickle.dump(lookup, f)


if __name__ == '__main__':
    main()