# NOTE: removed non-Python page residue ("Spaces:", "Build error" lines)
# left over from the hosting page this file was extracted from.
# Modified from:
# https://github.com/anibali/pytorch-stacked-hourglass
# https://github.com/bearpaw/pytorch-pose
import torch
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../'))
from src.stacked_hourglass.utils.evaluation import final_preds_untransformed
from src.stacked_hourglass.utils.imfit import fit, calculate_fit_contain_output_area
from src.stacked_hourglass.utils.transforms import color_normalize, fliplr, flip_back
| def _check_batched(images): | |
| if isinstance(images, (tuple, list)): | |
| return True | |
| if images.ndimension() == 4: | |
| return True | |
| return False | |
class HumanPosePredictor:
    def __init__(self, model, device=None, data_info=None, input_shape=None):
        """Helper class for predicting 2D human pose joint locations.

        Args:
            model: The model for generating joint heatmaps.
            device: The computational device to use for inference. Defaults to
                CUDA when available, otherwise CPU.
            data_info: Specifications of the data. Must provide ``rgb_mean``,
                ``rgb_stddev`` and ``hflip_indices`` attributes (e.g.
                ``Mpii.DATA_INFO``). Required.
            input_shape: The input dimensions of the model as (height, width).
                An int is interpreted as a square shape; defaults to (256, 256).

        Raises:
            ValueError: If ``data_info`` is not supplied.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        device = torch.device(device)
        model.to(device)
        self.model = model
        self.device = device
        if data_info is None:
            # Upstream code defaulted to Mpii.DATA_INFO; that dependency was
            # removed here, so the caller must supply data_info explicitly.
            raise ValueError('data_info must be specified (e.g. Mpii.DATA_INFO)')
        self.data_info = data_info
        # Input shape ordering: H, W
        if input_shape is None:
            self.input_shape = (256, 256)
        elif isinstance(input_shape, int):
            self.input_shape = (input_shape, input_shape)
        else:
            self.input_shape = input_shape

    def do_forward(self, input_tensor):
        """Run the model on ``input_tensor`` in eval mode, without gradients."""
        self.model.eval()
        with torch.no_grad():
            output = self.model(input_tensor)
        return output

    def prepare_image(self, image):
        """Convert one image to a normalized float32 tensor of ``self.input_shape``.

        Fixed-point (integer) images are rescaled from [0, 255] to [0, 1], the
        image is letterboxed to the model's input shape with fit_mode='contain',
        and colour-normalized using the data_info statistics.
        """
        was_fixed_point = not image.is_floating_point()
        # Copy into a fresh float32 tensor so the caller's image is not mutated.
        image = torch.empty_like(image, dtype=torch.float32).copy_(image)
        if was_fixed_point:
            image /= 255.0
        if image.shape[-2:] != self.input_shape:
            image = fit(image, self.input_shape, fit_mode='contain')
        image = color_normalize(image, self.data_info.rgb_mean, self.data_info.rgb_stddev)
        return image

    def estimate_heatmaps(self, images, flip=False):
        """Return joint heatmaps (on CPU) for a single image or a batch.

        Args:
            images: A single image tensor, or a list/tuple/4-D batch of them.
            flip (bool): If True, also evaluate horizontally-flipped inputs and
                average the two sets of heatmaps.

        Returns:
            Heatmaps from the model's final stack; batched iff the input was.
        """
        is_batched = _check_batched(images)
        raw_images = images if is_batched else images.unsqueeze(0)
        input_tensor = torch.empty((len(raw_images), 3, *self.input_shape),
                                   device=self.device, dtype=torch.float32)
        for i, raw_image in enumerate(raw_images):
            input_tensor[i] = self.prepare_image(raw_image)
        # The model returns one output per hourglass stack; use the last one.
        heatmaps = self.do_forward(input_tensor)[-1].cpu()
        if flip:
            flip_input = fliplr(input_tensor)
            flip_heatmaps = self.do_forward(flip_input)[-1].cpu()
            # flip_back also swaps left/right joint channels via hflip_indices.
            heatmaps += flip_back(flip_heatmaps, self.data_info.hflip_indices)
            heatmaps /= 2
        if is_batched:
            return heatmaps
        else:
            return heatmaps[0]

    def estimate_joints(self, images, flip=False):
        """Estimate human joint locations from input images.

        Images are expected to be centred on a human subject and scaled reasonably.

        Args:
            images: The images to estimate joint locations for. Can be a single
                image or a list of images.
            flip (bool): If set to true, evaluates on flipped versions of the
                images as well and averages the results.

        Returns:
            The predicted human joint locations in image pixel space.
        """
        is_batched = _check_batched(images)
        raw_images = images if is_batched else images.unsqueeze(0)
        heatmaps = self.estimate_heatmaps(raw_images, flip=flip).cpu()
        # final_preds_untransformed compares the first component of shape with x
        # and second with y. This relates to the image (Width, Height); the
        # heatmap tensor itself has shape (Height, Width), hence the [::-1].
        coords = final_preds_untransformed(heatmaps, heatmaps.shape[-2:][::-1])
        # Rescale coords to pixel space of specified images.
        for i, image in enumerate(raw_images):
            # When returning to original image space we need to compensate for
            # the fact that we used fit_mode='contain' when preparing the
            # images for inference: undo heatmap->input scaling, remove the
            # letterbox offset, then scale to the original image size.
            y_off, x_off, height, width = calculate_fit_contain_output_area(*image.shape[-2:], *self.input_shape)
            coords[i, :, 1] *= self.input_shape[-2] / heatmaps.shape[-2]
            coords[i, :, 1] -= y_off
            coords[i, :, 1] *= image.shape[-2] / height
            coords[i, :, 0] *= self.input_shape[-1] / heatmaps.shape[-1]
            coords[i, :, 0] -= x_off
            coords[i, :, 0] *= image.shape[-1] / width
        if is_batched:
            return coords
        else:
            return coords[0]