# fastfit / parse_utils / dwpose.py
# Uploaded by tigger13 ("Upload 452 files"), revision 2711c5f (verified).
from typing import Callable, Dict, Optional, Union
import cv2
import numpy as np
import PIL
import PIL.Image
import torch
from easy_dwpose.body_estimation import Wholebody, resize_image
from easy_dwpose.draw import draw_openpose
class DWposeDetector:
    """Whole-body pose detector (body, hands, face) built on DWPose ONNX models.

    Wraps ``easy_dwpose``'s ``Wholebody`` estimator. Calling an instance runs
    detection on an image and either returns the raw pose dict or renders it
    with a drawing callable (``draw_openpose`` by default).
    """

    def __init__(self, pretrained_model_name_or_path: str = "RedHash/DWPose", device: str = "cpu"):
        """Load the person-detector and pose-estimator ONNX models.

        Args:
            pretrained_model_name_or_path: Directory containing ``yolox_l.onnx``
                and ``dw-ll_ucoco_384.onnx``. The value is used verbatim as a
                filesystem path (no Hugging Face Hub download happens here), so
                the default only works if a local ``RedHash/DWPose`` directory
                exists relative to the working directory.
            device: Device string forwarded unchanged to ``Wholebody``.
        """
        # NOTE: the default used to be "сpu" spelled with a Cyrillic "с"
        # (U+0441); it never compares equal to "cpu", so any downstream device
        # check against "cpu" silently took the non-CPU branch.
        local_dir = pretrained_model_name_or_path
        self.pose_estimation = Wholebody(
            device=device,
            model_det=f"{local_dir}/yolox_l.onnx",
            model_pose=f"{local_dir}/dw-ll_ucoco_384.onnx",
        )

    def _format_pose(self, candidates, scores, width, height, score_threshold: float = 0.3):
        """Split raw keypoints into the body/hands/face dict used by ``draw_pose``.

        NOTE(review): assumes ``candidates`` is (num_people, num_keypoints, locs)
        with x, y in pixels in the first two ``locs`` entries, and ``scores`` is
        (num_people, num_keypoints), using the whole-body layout implied by the
        slices below (0:18 body, 24:92 face, 92:113 and 113: hands) — TODO confirm
        against ``Wholebody``'s output.

        Side effects: ``candidates`` x/y are divided in place by ``width``/``height``,
        and the first 18 columns of ``scores`` are overwritten (see below).

        Args:
            candidates: Keypoint coordinates returned by ``Wholebody``.
            scores: Per-keypoint confidences returned by ``Wholebody``.
            width: Width of the image the keypoints were detected on.
            height: Height of the image the keypoints were detected on.
            score_threshold: Body keypoints with a score strictly above this value
                are treated as visible.

        Returns:
            dict with keys ``bodies``, ``body_scores``, ``hands``, ``hands_scores``,
            ``faces`` and ``faces_scores``.
        """
        num_candidates, _, locs = candidates.shape

        # Normalise pixel coordinates by the detection image size (in place).
        candidates[..., 0] /= float(width)
        candidates[..., 1] /= float(height)

        # Bodies are flattened to (num_people * 18, locs).
        bodies = candidates[:, :18].copy()
        bodies = bodies.reshape(num_candidates * 18, locs)

        # Encode body visibility: a visible keypoint is replaced by its row in the
        # flattened ``bodies`` array (18 * person + joint), a hidden one by -1.
        # ``body_scores`` is a view into ``scores``, and the result is written
        # back through it so ``scores`` is updated exactly as before.
        body_scores = scores[:, :18]
        rows, cols = np.indices(body_scores.shape)
        body_scores[...] = np.where(body_scores > score_threshold, 18 * rows + cols, -1)

        # Columns 18:24 are not returned (presumably foot keypoints — TODO confirm).
        faces = candidates[:, 24:92]
        faces_scores = scores[:, 24:92]

        # Stack the two hand groups along the person axis.
        hands = np.vstack([candidates[:, 92:113], candidates[:, 113:]])
        hands_scores = np.vstack([scores[:, 92:113], scores[:, 113:]])

        return dict(
            bodies=bodies,
            body_scores=body_scores,
            hands=hands,
            hands_scores=hands_scores,
            faces=faces,
            faces_scores=faces_scores,
        )

    @torch.inference_mode()
    def __call__(
        self,
        image: Union[PIL.Image.Image, np.ndarray],
        detect_resolution: int = 512,
        draw_pose: Optional[Callable] = draw_openpose,
        output_type: str = "pil",
        **kwargs,
    ) -> Union[PIL.Image.Image, np.ndarray, Dict]:
        """Detect whole-body pose in ``image`` and optionally render it.

        Args:
            image: PIL image (converted to RGB) or a NumPy array. The array is
                copied, so the caller's data is not modified.
                NOTE(review): an ndarray must be 3-D (H, W, C) — TODO confirm the
                channel order ``Wholebody`` expects.
            detect_resolution: Target resolution passed to ``resize_image``
                before detection.
            draw_pose: Callable ``draw_pose(pose, height=..., width=..., **kwargs)``
                returning an image array; if falsy, the pose dict is returned.
            output_type: ``"pil"`` or ``"np"``; only used when drawing.
            **kwargs: Forwarded to ``draw_pose``.

        Returns:
            The pose dict from ``_format_pose`` when ``draw_pose`` is falsy;
            otherwise the rendered pose resized to the original image size, as a
            ``PIL.Image.Image`` (``"pil"``) or ``np.ndarray`` (``"np"``).

        Raises:
            ValueError: if drawing and ``output_type`` is neither ``"pil"`` nor ``"np"``.
        """
        if not isinstance(image, np.ndarray):
            image = np.array(image.convert("RGB"))

        image = image.copy()
        original_height, original_width, _ = image.shape

        image = resize_image(image, target_resolution=detect_resolution)
        height, width, _ = image.shape

        candidates, scores = self.pose_estimation(image)

        pose = self._format_pose(candidates, scores, width, height)

        if not draw_pose:
            return pose

        pose_image = draw_pose(pose, height=height, width=width, **kwargs)
        # ``interpolation`` must be passed by keyword: the third positional
        # parameter of cv2.resize is ``dst``, so passing the flag positionally
        # meant Lanczos was never actually used.
        pose_image = cv2.resize(
            pose_image,
            (original_width, original_height),
            interpolation=cv2.INTER_LANCZOS4,
        )

        if output_type == "pil":
            pose_image = PIL.Image.fromarray(pose_image)
        elif output_type == "np":
            pass
        else:
            raise ValueError("output_type should be 'pil' or 'np'")

        return pose_image