| |
| |
| |
| |
|
|
| import numpy as np |
| import torch |
| from munkres import Munkres |
|
|
| from mmpose.core.evaluation import post_dark_udp |
|
|
|
|
def _py_max_match(scores):
    """Solve the assignment problem for a cost matrix with Munkres.

    Args:
        scores (np.ndarray): cost matrix.

    Returns:
        np.ndarray: integer array built from the (row, column) pairs
            returned by ``Munkres.compute``.
    """
    assignment = Munkres().compute(scores)
    return np.array(assignment).astype(int)
|
|
|
|
def _match_by_tag(inp, params):
    """Group keypoint candidates into poses by comparing their tags.

    Joints are visited in ``params.joint_order``. Candidates of the first
    joint that has any detection each start a new pose; candidates of later
    joints are assigned to existing poses with the Munkres algorithm, or
    start a new pose when no existing pose is close enough in tag space.

    Note:
        number of keypoints: K
        max number of people in an image: M (M=30 by default)
        dim of tags: L
            If use flip testing, L=2; else L=1.

    Args:
        inp (tuple):
            tag_k (np.ndarray[KxMxL]): tag corresponding to the
                top k values of feature map per keypoint.
            loc_k (np.ndarray[KxMx2]): top k locations of the
                feature maps for keypoint.
            val_k (np.ndarray[KxM]): top k value of the
                feature maps per keypoint.
        params (_Params): grouping parameters.

    Returns:
        np.ndarray: float32 array stacking one (K, 3 + L) block per pose;
            each row is (x, y, score, tag...).
    """
    assert isinstance(params, _Params), 'params should be class _Params()'

    tag_k, loc_k, val_k = inp

    # Template for a new pose: one all-zero row per joint.
    empty_pose = np.zeros((params.num_joints, 3 + tag_k.shape[2]),
                          dtype=np.float32)

    poses = {}      # pose key -> (K, 3 + L) joint array
    pose_tags = {}  # pose key -> list of tags assigned to that pose

    for order_pos in range(params.num_joints):
        joint_idx = params.joint_order[order_pos]

        # Rows are (x, y, score, tag...) for every candidate of this joint.
        candidates = np.concatenate(
            (loc_k[joint_idx], val_k[joint_idx, :, None], tag_k[joint_idx]),
            1)
        keep = candidates[:, 2] > params.detection_threshold
        cand_tags = tag_k[joint_idx][keep]
        candidates = candidates[keep]

        if candidates.shape[0] == 0:
            continue

        if not poses:
            # Nothing grouped yet: every candidate seeds its own pose,
            # keyed by the first component of its tag.
            for cand_tag, cand in zip(cand_tags, candidates):
                person = cand_tag[0]
                poses.setdefault(person, np.copy(empty_pose))[joint_idx] = cand
                pose_tags[person] = [cand_tag]
            continue

        known = list(poses.keys())[:params.max_num_people]
        if params.ignore_too_much and len(known) == params.max_num_people:
            continue

        mean_tags = np.array([np.mean(pose_tags[p], axis=0) for p in known])

        # Pairwise tag distance between new candidates and known poses.
        offsets = candidates[:, None, 3:] - mean_tags[None, :, :]
        cost = np.linalg.norm(offsets, ord=2, axis=2)
        # Independent copy of the raw distances for the threshold test, so
        # it cannot be affected by anything done to the matrix handed to the
        # matcher (NOTE(review): the solver may modify its input - confirm).
        tag_dist = np.copy(cost)

        if params.use_detection_val:
            cost = np.round(cost) * 100 - candidates[:, 2:3]

        n_new, n_known = offsets.shape[0], offsets.shape[1]
        if n_new > n_known:
            # Pad with very expensive dummy columns so every candidate row
            # can receive an assignment.
            padding = np.full((n_new, n_new - n_known), 1e10,
                              dtype=np.float32)
            cost = np.concatenate((cost, padding), axis=1)

        for row, col in _py_max_match(cost):
            matched = (row < n_new and col < n_known
                       and tag_dist[row][col] < params.tag_threshold)
            if matched:
                person = known[col]
                poses[person][joint_idx] = candidates[row]
                pose_tags[person].append(cand_tags[row])
            else:
                person = cand_tags[row][0]
                poses.setdefault(
                    person, np.copy(empty_pose))[joint_idx] = candidates[row]
                pose_tags[person] = [cand_tags[row]]

    return np.array(list(poses.values())).astype(np.float32)
|
|
|
|
class _Params:
    """Grouping parameters read from a config mapping.

    Args:
        cfg (Config): config providing the keys ``num_joints``,
            ``max_num_people``, ``detection_threshold``, ``tag_threshold``,
            ``use_detection_val`` and ``ignore_too_much``.
    """

    def __init__(self, cfg):
        # Copy the required entries verbatim onto the instance.
        for name in ('num_joints', 'max_num_people', 'detection_threshold',
                     'tag_threshold', 'use_detection_val',
                     'ignore_too_much'):
            setattr(self, name, cfg[name])

        if self.num_joints != 17:
            # Any other skeleton is processed in natural index order.
            self.joint_order = list(np.arange(self.num_joints))
        else:
            # Fixed, zero-based processing order for 17-joint skeletons.
            self.joint_order = [
                0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16
            ]
|
|
|
|
class HeatmapParser:
    """The heatmap parser for post processing.

    Args:
        cfg (Config): config providing ``tag_per_joint``, ``nms_kernel``,
            ``nms_padding`` and the keys required by ``_Params``; the
            optional keys ``use_udp`` and ``score_per_joint`` default to
            ``False``.
    """

    def __init__(self, cfg):
        self.params = _Params(cfg)
        self.tag_per_joint = cfg['tag_per_joint']
        # Stride-1 max pooling used by ``nms`` to find local maxima.
        self.pool = torch.nn.MaxPool2d(cfg['nms_kernel'], 1,
                                       cfg['nms_padding'])
        self.use_udp = cfg.get('use_udp', False)
        self.score_per_joint = cfg.get('score_per_joint', False)

    def nms(self, heatmaps):
        """Non-Maximum Suppression for heatmaps.

        Args:
            heatmaps (torch.Tensor): Heatmaps before nms.

        Returns:
            torch.Tensor: Heatmaps after nms; every position whose value
                differs from the pooled maximum of its window is zeroed.
        """
        pooled = self.pool(heatmaps)
        is_peak = torch.eq(pooled, heatmaps).float()
        return heatmaps * is_peak

    def match(self, tag_k, loc_k, val_k):
        """Group keypoints to human poses in a batch.

        Args:
            tag_k (np.ndarray[NxKxMxL]): tag corresponding to the
                top k values of feature map per keypoint.
            loc_k (np.ndarray[NxKxMx2]): top k locations of the
                feature maps for keypoint.
            val_k (np.ndarray[NxKxM]): top k value of the
                feature maps per keypoint.

        Returns:
            list: one ``_match_by_tag`` result per batch element.
        """
        return [
            _match_by_tag(sample, self.params)
            for sample in zip(tag_k, loc_k, val_k)
        ]

    def top_k(self, heatmaps, tags):
        """Find top_k values in an image.

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            max number of people: M
            dim of tags: L
                If use flip testing, L=2; else L=1.

        Args:
            heatmaps (torch.Tensor[NxKxHxW])
            tags (torch.Tensor[NxKxHxWxL])

        Returns:
            dict: A dict containing top_k values.

            - tag_k (np.ndarray[NxKxMxL]):
                tag corresponding to the top k values of
                feature map per keypoint.
            - loc_k (np.ndarray[NxKxMx2]):
                top k location of feature map per keypoint.
            - val_k (np.ndarray[NxKxM]):
                top k value of feature map per keypoint.
        """
        heatmaps = self.nms(heatmaps)
        N, K, H, W = heatmaps.size()
        # ``reshape`` behaves like ``view`` when a view is possible and
        # copies otherwise, so unusual memory layouts do not raise.
        heatmaps = heatmaps.reshape(N, K, -1)
        val_k, ind = heatmaps.topk(self.params.max_num_people, dim=2)

        tags = tags.reshape(tags.size(0), tags.size(1), W * H, -1)
        if not self.tag_per_joint:
            # A single shared tag map is broadcast to every joint.
            tags = tags.expand(-1, self.params.num_joints, -1, -1)

        # Pick, for every tag dimension, the tag at each top-k location.
        tag_k = torch.stack(
            [torch.gather(tags[..., i], 2, ind) for i in range(tags.size(3))],
            dim=3)

        # Convert flat indices back to (x, y) heatmap coordinates.
        x = ind % W
        y = ind // W
        ind_k = torch.stack((x, y), dim=3)

        return {
            'tag_k': tag_k.cpu().numpy(),
            'loc_k': ind_k.cpu().numpy(),
            'val_k': val_k.cpu().numpy()
        }

    @staticmethod
    def adjust(results, heatmaps):
        """Adjust the coordinates for better accuracy.

        Every joint with a positive score is shifted a quarter pixel toward
        the larger of its two neighbours along each axis and then offset by
        0.5. ``results`` is modified in place.

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W

        Args:
            results (list(np.ndarray)): Keypoint predictions.
            heatmaps (torch.Tensor[NxKxHxW]): Heatmaps.

        Returns:
            list(np.ndarray): the same ``results`` object, adjusted.
        """
        _, _, H, W = heatmaps.shape
        for batch_id, people in enumerate(results):
            for people_id, people_i in enumerate(people):
                for joint_id, joint in enumerate(people_i):
                    if not joint[2] > 0:
                        continue
                    x, y = joint[0:2]
                    xx, yy = int(x), int(y)
                    tmp = heatmaps[batch_id][joint_id]

                    # Neighbour lookups are clamped to the heatmap border.
                    if tmp[min(H - 1, yy + 1), xx] > tmp[max(0, yy - 1), xx]:
                        y += 0.25
                    else:
                        y -= 0.25

                    if tmp[yy, min(W - 1, xx + 1)] > tmp[yy, max(0, xx - 1)]:
                        x += 0.25
                    else:
                        x -= 0.25

                    results[batch_id][people_id, joint_id,
                                      0:2] = (x + 0.5, y + 0.5)
        return results

    @staticmethod
    def refine(heatmap, tag, keypoints, use_udp=False):
        """Given initial keypoint predictions, we identify missing joints.

        The mean tag of the joints that already have a positive score is
        used as the reference tag of the person. For every joint whose score
        is exactly 0, the heatmap position maximising
        ``heatmap - round(tag distance)`` is taken; if the heatmap value
        there is positive, the joint is filled in. ``keypoints`` is modified
        in place. If no joint has a positive score there is no reference
        tag, and ``keypoints`` is returned unchanged.

        Note:
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            dim of tags: L
                If use flip testing, L=2; else L=1.

        Args:
            heatmap: np.ndarray(K, H, W).
            tag: np.ndarray(K, H, W) |  np.ndarray(K, H, W, L)
            keypoints: np.ndarray of size (K, 3 + L)
                        last dim is (x, y, score, tag).
            use_udp: bool-unbiased data processing

        Returns:
            np.ndarray: The refined keypoints.
        """
        K, H, W = heatmap.shape
        if len(tag.shape) == 3:
            tag = tag[..., None]

        # Collect the tag value at every already detected keypoint.
        tags = []
        for i in range(K):
            if keypoints[i, 2] > 0:
                x, y = keypoints[i][:2].astype(int)
                x = np.clip(x, 0, W - 1)
                y = np.clip(y, 0, H - 1)
                tags.append(tag[i, y, x])

        if not tags:
            # No detected joint means no reference tag; averaging an empty
            # list would give NaN and break the indexing below.
            return keypoints

        # Mean tag of the current person.
        prev_tag = np.mean(tags, axis=0)

        for i in range(K):
            # Only joints with a score of exactly 0 are ever overwritten,
            # so the search is skipped for all other joints.
            if keypoints[i, 2] != 0:
                continue

            _heatmap = heatmap[i]
            _tag = tag[i]

            # Distance of every tag value to the person's mean tag.
            distance_tag = (((_tag -
                              prev_tag[None, None, :])**2).sum(axis=2)**0.5)
            norm_heatmap = _heatmap - np.round(distance_tag)

            # Position of the maximum and the detection score there.
            yy, xx = np.unravel_index(np.argmax(norm_heatmap), _heatmap.shape)
            val = _heatmap[yy, xx]
            if not val > 0:
                continue

            x = float(xx)
            y = float(yy)
            if not use_udp:
                # Offset by 0.5.
                x += 0.5
                y += 0.5

            # Add a quarter offset toward the larger neighbour.
            if _heatmap[yy, min(W - 1, xx + 1)] > _heatmap[yy, max(0, xx - 1)]:
                x += 0.25
            else:
                x -= 0.25

            if _heatmap[min(H - 1, yy + 1), xx] > _heatmap[max(0, yy - 1), xx]:
                y += 0.25
            else:
                y -= 0.25

            keypoints[i, :3] = (x, y, val)

        return keypoints

    def parse(self, heatmaps, tags, adjust=True, refine=True):
        """Group keypoints into poses given heatmap and tag.

        Scores and refinement only read index 0 of the batch
        (``results[0]``, ``heatmaps[0]``, ``tags[0]``).

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            dim of tags: L
                If use flip testing, L=2; else L=1.

        Args:
            heatmaps (torch.Tensor[NxKxHxW]): model output heatmaps.
            tags (torch.Tensor[NxKxHxWxL]): model output tagmaps.
            adjust (bool): whether to adjust the grouped coordinates
                (``post_dark_udp`` when ``use_udp`` is set, otherwise
                ``self.adjust``).
            refine (bool): whether to fill in missing joints with
                ``self.refine``.

        Returns:
            tuple: A tuple containing keypoint grouping results.

            - results (list(np.ndarray)): Pose results.
            - scores (list/list(np.ndarray)): Score of people.
        """
        results = self.match(**self.top_k(heatmaps, tags))

        if adjust:
            if self.use_udp:
                for i, poses in enumerate(results):
                    if poses.shape[0] > 0:
                        poses[..., :2] = post_dark_udp(
                            poses[..., :2].copy(), heatmaps[i:i + 1, :])
            else:
                results = self.adjust(results, heatmaps)

        if self.score_per_joint:
            scores = [person[:, 2] for person in results[0]]
        else:
            scores = [person[:, 2].mean() for person in results[0]]

        if refine:
            results = results[0]
            if len(results) > 0:
                # Convert once: these arrays are the same for every person
                # and ``refine`` does not modify them.
                heatmap_numpy = heatmaps[0].cpu().numpy()
                tag_numpy = tags[0].cpu().numpy()
                if not self.tag_per_joint:
                    tag_numpy = np.tile(tag_numpy,
                                        (self.params.num_joints, 1, 1, 1))
                for i in range(len(results)):
                    results[i] = self.refine(
                        heatmap_numpy,
                        tag_numpy,
                        results[i],
                        use_udp=self.use_udp)
            results = [results]

        return results, scores
|
|