|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os.path as osp |
|
|
from copy import deepcopy |
|
|
from typing import Optional |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from mmengine.config import Config |
|
|
from mmengine.dataset import pseudo_collate |
|
|
from mmengine.structures import InstanceData, PixelData |
|
|
|
|
|
from mmpose.structures import MultilevelPixelData, PoseDataSample |
|
|
from mmpose.structures.bbox import bbox_xyxy2cs |
|
|
|
|
|
|
|
|
def get_coco_sample(
        img_shape=(240, 320),
        img_fill: Optional[int] = None,
        num_instances=1,
        with_bbox_cs=True,
        with_img_mask=False,
        random_keypoints_visible=False,
        non_occlusion=False):
    """Create a dummy data sample in COCO style.

    All random content is drawn from a single ``RandomState`` seeded with 0,
    so repeated calls with the same arguments return identical samples and
    the global ``np.random`` state is left untouched.

    Args:
        img_shape (tuple): Image size as ``(h, w)``. Defaults to
            ``(240, 320)``.
        img_fill (int, optional): If given, every pixel of the image is set
            to this value; if ``None`` the image is filled with random
            ``uint8`` values. Defaults to ``None``.
        num_instances (int): Number of instances (bounding boxes) to
            generate. Defaults to 1.
        with_bbox_cs (bool): Whether to add ``'bbox_center'`` and
            ``'bbox_scale'`` computed by ``bbox_xyxy2cs``. Defaults to
            ``True``.
        with_img_mask (bool): Whether to add a random binary ``'img_mask'``
            of shape ``(h, w)``. Defaults to ``False``.
        random_keypoints_visible (bool): If ``True`` the visibility of each
            keypoint is randomly 0 or 1, otherwise all keypoints are
            visible. Defaults to ``False``.
        non_occlusion (bool): If ``True`` each instance's box is generated
            inside its own vertical stripe of width ``w / num_instances``,
            so boxes of different instances do not overlap along x.
            Defaults to ``False``.

    Returns:
        dict: The sample, with image, bbox, keypoint and dataset-meta keys.
    """
    num_keypoints = 17  # COCO-style skeleton size used throughout
    rng = np.random.RandomState(0)
    h, w = img_shape

    # Boxes and keypoints are drawn first so that their values do not depend
    # on which optional random parts (image, visibility, mask) are requested.
    if non_occlusion:
        stripe_w = w / num_instances
        bbox = _rand_bboxes(rng, num_instances, stripe_w, h)
        # Shift the x-coordinates (columns 0 and 2) of the i-th box into the
        # i-th stripe.
        bbox[:, 0::2] += stripe_w * np.arange(num_instances)[:, None]
    else:
        bbox = _rand_bboxes(rng, num_instances, w, h)

    keypoints = _rand_keypoints(rng, bbox, num_keypoints)

    if img_fill is None:
        img = rng.randint(0, 256, (h, w, 3), dtype=np.uint8)
    else:
        img = np.full((h, w, 3), img_fill, dtype=np.uint8)

    if random_keypoints_visible:
        keypoints_visible = rng.randint(
            0, 2, (num_instances, num_keypoints)).astype(np.float32)
    else:
        keypoints_visible = np.ones((num_instances, num_keypoints),
                                    dtype=np.float32)

    upper_body_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    lower_body_ids = [11, 12, 13, 14, 15, 16]
    flip_pairs = [[2, 1], [1, 2], [4, 3], [3, 4], [6, 5], [5, 6], [8, 7],
                  [7, 8], [10, 9], [9, 10], [12, 11], [11, 12], [14, 13],
                  [13, 14], [16, 15], [15, 16]]
    flip_indices = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
    dataset_keypoint_weights = np.array([
        1., 1., 1., 1., 1., 1., 1., 1.2, 1.2, 1.5, 1.5, 1., 1., 1.2, 1.2, 1.5,
        1.5
    ]).astype(np.float32)

    data = {
        'img': img,
        'img_shape': img_shape,
        'ori_shape': img_shape,
        'bbox': bbox,
        'keypoints': keypoints,
        'keypoints_visible': keypoints_visible,
        'upper_body_ids': upper_body_ids,
        'lower_body_ids': lower_body_ids,
        'flip_pairs': flip_pairs,
        'flip_indices': flip_indices,
        'dataset_keypoint_weights': dataset_keypoint_weights,
        'invalid_segs': [],
    }

    if with_bbox_cs:
        data['bbox_center'], data['bbox_scale'] = bbox_xyxy2cs(data['bbox'])

    if with_img_mask:
        data['img_mask'] = rng.randint(0, 2, (h, w), dtype=np.uint8)

    return data
|
|
|
|
|
|
|
|
def get_packed_inputs(batch_size=2,
                      num_instances=1,
                      num_keypoints=17,
                      num_levels=1,
                      img_shape=(256, 192),
                      input_size=(192, 256),
                      heatmap_size=(48, 64),
                      simcc_split_ratio=2.0,
                      with_heatmap=True,
                      with_reg_label=True,
                      with_simcc_label=True):
    """Create a dummy batch of model inputs and data samples.

    All random content comes from ``RandomState`` objects seeded with 0, so
    repeated calls with the same arguments are reproducible and the global
    ``np.random`` state is left untouched.

    Args:
        batch_size (int): Number of samples in the batch. Defaults to 2.
        num_instances (int): Number of instances per sample. Defaults to 1.
        num_keypoints (int): Number of keypoints per instance. Defaults
            to 17.
        num_levels (int): Number of heatmap levels. With more than one
            level the keypoint weights get an extra level axis and the
            heatmaps are stored as a list in a ``MultilevelPixelData``.
            Defaults to 1.
        img_shape (tuple): Image size as ``(h, w)``. Defaults to
            ``(256, 192)``.
        input_size (tuple): Model input size as ``(w, h)``; used to
            normalize the regression labels and to size the SimCC labels.
            Defaults to ``(192, 256)``.
        heatmap_size (tuple): Heatmap size as ``(W, H)``. Defaults to
            ``(48, 64)``.
        simcc_split_ratio (float): Factor applied to ``input_size`` to get
            the SimCC label lengths. Defaults to 2.0.
        with_heatmap (bool): Whether to add ``gt_fields.heatmaps``.
            Defaults to ``True``.
        with_reg_label (bool): Whether to add
            ``gt_instance_labels.keypoint_labels``. Defaults to ``True``.
        with_simcc_label (bool): Whether to add
            ``gt_instance_labels.keypoint_x_labels`` and
            ``keypoint_y_labels``. Defaults to ``True``.

    Returns:
        The result of ``pseudo_collate`` applied to the per-sample dicts,
        each of which holds ``'inputs'`` (image tensor) and
        ``'data_samples'`` (a ``PoseDataSample``).
    """
    rng = np.random.RandomState(0)
    # The flip-index permutation gets its own seeded stream: it is then
    # reproducible, while the numbers drawn from ``rng`` for images, boxes,
    # keypoints and labels stay exactly as they were before.
    flip_rng = np.random.RandomState(0)

    h, w = img_shape

    inputs_list = []
    for idx in range(batch_size):
        inputs = dict()

        # input image, channel-first; note that ``high`` is exclusive, so
        # the value 255 is never generated
        image = rng.randint(0, 255, size=(3, h, w), dtype=np.uint8)
        inputs['inputs'] = torch.from_numpy(image)

        # meta info with a randomly permuted flip-index table
        flip_indices = list(range(num_keypoints))
        flip_rng.shuffle(flip_indices)
        img_meta = {
            'id': idx,
            'img_id': idx,
            'img_path': '<demo>.png',
            'img_shape': img_shape,
            'input_size': input_size,
            'flip': False,
            'flip_direction': None,
            'flip_indices': flip_indices
        }
        data_sample = PoseDataSample(metainfo=img_meta)

        # ground-truth instances
        bboxes = _rand_bboxes(rng, num_instances, w, h)
        bbox_centers, bbox_scales = bbox_xyxy2cs(bboxes)
        keypoints = _rand_keypoints(rng, bboxes, num_keypoints)
        keypoints_visible = np.ones((num_instances, num_keypoints),
                                    dtype=np.float32)

        # Multi-level setups carry one weight vector per level:
        # (N, K) -> (N, num_levels, K).
        if num_levels > 1:
            keypoint_weights = np.tile(keypoints_visible[:, None],
                                       (1, num_levels, 1))
        else:
            keypoint_weights = keypoints_visible.copy()

        gt_instances = InstanceData()
        gt_instances.bboxes = bboxes
        gt_instances.bbox_centers = bbox_centers
        gt_instances.bbox_scales = bbox_scales
        gt_instances.bbox_scores = np.ones((num_instances, ), dtype=np.float32)
        gt_instances.keypoints = keypoints
        gt_instances.keypoints_visible = keypoints_visible

        # instance-level training labels
        gt_instance_labels = InstanceData()
        gt_instance_labels.keypoint_weights = torch.FloatTensor(
            keypoint_weights)

        if with_reg_label:
            # coordinates divided element-wise by ``input_size``
            gt_instance_labels.keypoint_labels = torch.FloatTensor(
                keypoints / input_size)

        if with_simcc_label:
            len_x = np.around(input_size[0] * simcc_split_ratio)
            len_y = np.around(input_size[1] * simcc_split_ratio)
            gt_instance_labels.keypoint_x_labels = torch.FloatTensor(
                _rand_simcc_label(rng, num_instances, num_keypoints, len_x))
            gt_instance_labels.keypoint_y_labels = torch.FloatTensor(
                _rand_simcc_label(rng, num_instances, num_keypoints, len_y))

        # pixel-level ground truth
        if with_heatmap:
            heatmap_w, heatmap_h = heatmap_size
            if num_levels == 1:
                gt_fields = PixelData()
                gt_fields.heatmaps = torch.FloatTensor(
                    rng.rand(num_keypoints, heatmap_h, heatmap_w))
            else:
                # one heatmap tensor per level, all of the same size
                gt_fields = MultilevelPixelData()
                gt_fields.heatmaps = [
                    torch.FloatTensor(
                        rng.rand(num_keypoints, heatmap_h, heatmap_w))
                    for _ in range(num_levels)
                ]
            data_sample.gt_fields = gt_fields

        data_sample.gt_instances = gt_instances
        data_sample.gt_instance_labels = gt_instance_labels

        inputs['data_samples'] = data_sample
        inputs_list.append(inputs)

    return pseudo_collate(inputs_list)
|
|
|
|
|
|
|
|
def _rand_keypoints(rng, bboxes, num_keypoints):
    """Sample random keypoints lying inside the given bounding boxes.

    Args:
        rng: Random generator providing ``rand`` (e.g. a
            ``np.random.RandomState``).
        bboxes (np.ndarray): Boxes whose first four columns are read as
            ``(x1, y1, x2, y2)``; one row per instance.
        num_keypoints (int): Number of keypoints per instance.

    Returns:
        np.ndarray: Keypoint coordinates of shape ``(n, num_keypoints, 2)``.
    """
    num_boxes = bboxes.shape[0]
    # Blend factor in [0, 1) between the two box corners, per coordinate.
    mix = rng.rand(num_boxes, num_keypoints, 2)

    # Add a keypoint axis so that the corners broadcast against ``mix``.
    corner_a = bboxes[:, None, :2]
    corner_b = bboxes[:, None, 2:4]

    return mix * corner_a + (1 - mix) * corner_b
|
|
|
|
|
|
|
|
def _rand_simcc_label(rng, num_instances, num_keypoints, len_feats):
    """Return a random SimCC label array.

    Args:
        rng: Random generator providing ``rand``.
        num_instances (int): Size of the first axis.
        num_keypoints (int): Size of the second axis.
        len_feats: Size of the last axis; converted with ``int`` so that
            float values (e.g. the output of ``np.around``) are accepted.

    Returns:
        np.ndarray: Uniform samples of shape
        ``(num_instances, num_keypoints, int(len_feats))``.
    """
    label_shape = (num_instances, num_keypoints, int(len_feats))
    return rng.rand(*label_shape)
|
|
|
|
|
|
|
|
def _rand_bboxes(rng, num_instances, img_w, img_h):
    """Generate random ``(x1, y1, x2, y2)`` boxes clipped to the image.

    Each box has a random relative centre in [0, 1) and a relative width and
    height in [0.2, 1.0); the resulting corners are clipped to
    ``[0, img_w]`` along x and ``[0, img_h]`` along y.

    Args:
        rng: Random generator providing ``rand``.
        num_instances (int): Number of boxes.
        img_w: Image width used for scaling and clipping.
        img_h: Image height used for scaling and clipping.

    Returns:
        np.ndarray: Boxes of shape ``(num_instances, 4)``.
    """
    # Two separate draws, in this order: centres first, then sizes.
    rel_centers = rng.rand(num_instances, 2)
    rel_sizes = 0.2 + 0.8 * rng.rand(num_instances, 2)

    def _edges(rel_center, rel_size, extent):
        # Low and high edge along one axis, clipped to [0, extent].
        center = rel_center * extent
        half = extent * rel_size / 2
        return (center - half).clip(0, extent), (center + half).clip(
            0, extent)

    x_min, x_max = _edges(rel_centers[:, 0], rel_sizes[:, 0], img_w)
    y_min, y_max = _edges(rel_centers[:, 1], rel_sizes[:, 1], img_h)

    return np.array([x_min, y_min, x_max, y_max]).T
|
|
|
|
|
|
|
|
def get_repo_dir():
    """Return the path of the MMPose repo directory.

    The path is found by going three ``dirname`` steps up from this file.
    When ``__file__`` is not defined, the parent directory of the imported
    ``mmpose`` package is returned instead.
    """
    try:
        this_file = __file__
    except NameError:
        # No ``__file__`` available (presumably an interactive session):
        # locate the repo through the imported package instead.
        import mmpose
        return osp.dirname(osp.dirname(mmpose.__file__))

    # Strip the file name, then two directory levels.
    path = this_file
    for _ in range(3):
        path = osp.dirname(path)
    return path
|
|
|
|
|
|
|
|
def get_config_file(fn: str):
    """Return full path of a config file from the given relative path.

    Args:
        fn (str): Config path relative to the repo directory. A path that
            does not start with ``'configs'`` is looked up inside the
            repo's ``configs`` directory.

    Returns:
        str: The full path of the config file.

    Raises:
        FileNotFoundError: If the resolved path is not an existing file.
    """
    path_parts = [get_repo_dir()]
    if not fn.startswith('configs'):
        path_parts.append('configs')
    path_parts.append(fn)
    fn_config = osp.join(*path_parts)

    if osp.isfile(fn_config):
        return fn_config

    raise FileNotFoundError(f'Cannot find config file {fn_config}')
|
|
|
|
|
|
|
|
def get_pose_estimator_cfg(fn: str):
    """Load model config from a config file.

    Args:
        fn (str): Config path, resolved with ``get_config_file``.

    Returns:
        A deep copy of the ``model`` field of the loaded config, so that
        callers can modify it without touching the ``Config`` object.
    """
    model_cfg = Config.fromfile(get_config_file(fn)).model
    return deepcopy(model_cfg)
|
|
|