import argparse
import itertools
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Union

import cv2
import imageio.v2 as imageio
import numpy as np
import torch
from PIL import Image

from preprocessing.dwpose.pose import convert_to_numpy, draw_pose
from preprocessing.dwpose.wholebody import HWC3, Wholebody, resize_image
from shared.utils import files_locator as fl
from shared.utils.utils import process_images_multithread
from shared.utils.utils import convert_tensor_to_image, convert_image_to_tensor

ArrayImage = Union[np.ndarray, Image.Image, torch.Tensor]


def _to_bgr_image(image: ArrayImage) -> np.ndarray:
    """Convert supported image types to uint8 BGR HWC."""
    if isinstance(image, torch.Tensor):
        if image.dim() == 4:
            # Assume (C, F, H, W) and keep only the first frame.
            image = image[:, 0]
        img = image.detach().cpu()
        # Already CHW; otherwise move channels first when the tensor is HWC.
        if img.shape[0] in (1, 3, 4):
            pass
        elif img.shape[-1] in (1, 3, 4):
            img = img.permute(2, 0, 1)

        if img.min() < 0:
            img = (img + 1.0) * 127.5
        elif img.max() <= 1.0:
            img = img * 255.0
        arr = img.clamp(0, 255).byte().permute(1, 2, 0).numpy()
        if arr.shape[2] == 1:
            arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
        elif arr.shape[2] == 3:
            arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
        elif arr.shape[2] == 4:
            arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2BGR)
        return HWC3(arr)

    was_pil = isinstance(image, Image.Image)
    arr = convert_to_numpy(image)

    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] != arr.shape[1]:
        arr = np.transpose(arr, (1, 2, 0))

    if arr.ndim == 2:
        arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
    elif arr.shape[2] == 4:
        arr = cv2.cvtColor(arr, cv2.COLOR_RGBA2BGR)
    elif was_pil:
        arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)

    if arr.dtype != np.uint8:
        if arr.max() <= 1.0:
            arr = arr * 255.0
        arr = np.clip(arr, 0, 255).astype(np.uint8)

    return HWC3(arr)


def _resize_to_height(img: np.ndarray, target_h: int) -> np.ndarray:
    """Resize while preserving aspect ratio to a specific height."""
    h, w = img.shape[:2]
    if h == target_h:
        return img
    new_w = int(round(w * target_h / float(h)))
    return cv2.resize(img, (new_w, target_h), interpolation=cv2.INTER_CUBIC)


def _frames_to_tensor(frames: List[np.ndarray]) -> torch.Tensor:
    """Convert a list of RGB uint8 frames to a tensor in [-1, 1] with shape (3, F, H, W)."""
    if not frames:
        return torch.empty(0)
    arr = np.stack(frames).astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(arr).permute(3, 0, 1, 2)


def _tensor_to_frames(tensor: torch.Tensor) -> List[np.ndarray]:
    """Convert a tensor in [-1, 1] shaped (3, F, H, W) back to RGB uint8 frames."""
    if tensor.numel() == 0:
        return []
    arr = tensor.permute(1, 2, 3, 0).cpu().numpy()
    arr = ((arr + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    return [frame for frame in arr]
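
# Example (sketch): _frames_to_tensor and _tensor_to_frames are inverses up to
# uint8 rounding. The 64x64 dummy frames below are purely illustrative.
#   frames = [np.zeros((64, 64, 3), np.uint8), np.full((64, 64, 3), 255, np.uint8)]
#   t = _frames_to_tensor(frames)      # shape (3, 2, 64, 64), values in [-1, 1]
#   back = _tensor_to_frames(t)        # two uint8 RGB frames again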


def _ensure_xyz(arr: np.ndarray) -> np.ndarray:
    """Ensure the last dimension is 3 by zero-padding if needed."""
    if arr.shape[-1] >= 3:
        return arr
    pad_width = [(0, 0)] * arr.ndim
    pad_width[-1] = (0, 3 - arr.shape[-1])
    return np.pad(arr, pad_width, mode="constant", constant_values=0)


def _xy_only(arr: np.ndarray) -> np.ndarray:
    """Return only the x,y channels to satisfy the drawing utilities."""
    if arr.shape[-1] <= 2:
        return arr
    return arr[..., :2]


def _mask_to_float01(mask: Union[np.ndarray, Image.Image, torch.Tensor]) -> np.ndarray:
    """Normalize any mask type to float32 single-channel in [0, 1], binarised at 0.5."""
    if isinstance(mask, Image.Image):
        mask = np.array(mask)
    elif isinstance(mask, torch.Tensor):
        m = mask.detach().cpu()
        if m.dim() == 4:
            m = m[0]
        if m.shape[0] in (1, 3, 4):
            m = m.permute(1, 2, 0)
        mask = m.numpy()
    mask = np.asarray(mask)
    # Collapse any channel dimension so the result is always 2D.
    if mask.ndim == 3:
        mask = mask.mean(axis=2)
    mask = mask.astype(np.float32)
    mask_min, mask_max = float(mask.min()), float(mask.max())
    if mask_max > 1.0:
        mask = mask / 255.0
    elif mask_min < 0.0:
        mask = (mask + 1.0) / 2.0
    mask = np.clip(mask, 0.0, 1.0)
    mask = (mask > 0.5).astype(np.float32)
    return mask
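
# Example (sketch): a uint8 mask in [0, 255] is rescaled to [0, 1] and binarised
# at 0.5, so 64 maps to 0.0 while 128 and 255 map to 1.0 (illustrative values).
#   _mask_to_float01(np.array([[0, 64], [128, 255]], dtype=np.uint8))
#   # -> array([[0., 0.], [1., 1.]], dtype=float32)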


def _safe_ratio(num: float, den: float) -> float:
    """Ratio that falls back to 1.0 when the denominator is zero or non-finite."""
    if den == 0 or not np.isfinite(den):
        return 1.0
    val = num / den
    return float(val) if np.isfinite(val) else 1.0


def _nan_to_one(val: float) -> float:
    return 1.0 if not np.isfinite(val) else float(val)


def _augment_pose(
    pose: Dict,
    offset_x: Tuple[float, float],
    offset_y: Tuple[float, float],
    scale_range: Tuple[float, float],
    aspect_ratio_range: Tuple[float, float],
    fixed_params: Tuple[float, float, float, float] | None = None,
) -> Dict:
    """Lightweight pose jitter used by the diff-aug variant."""
    pose_aug = {
        "bodies": {"candidate": pose["bodies"]["candidate"].copy(), "subset": pose["bodies"]["subset"].copy()},
        "hands": pose["hands"].copy(),
        "faces": pose["faces"].copy(),
    }

    if fixed_params is None:
        sx = np.random.uniform(scale_range[0], scale_range[1])
        aspect = np.random.uniform(aspect_ratio_range[0], aspect_ratio_range[1])
        scale_x = sx * aspect
        scale_y = sx / max(aspect, 1e-6)
        dx = np.random.uniform(offset_x[0], offset_x[1])
        dy = np.random.uniform(offset_y[0], offset_y[1])
    else:
        dx, dy, scale_x, scale_y = fixed_params

    def _apply(arr: np.ndarray) -> np.ndarray:
        # Only jitter valid keypoints; entries marked -1 stay untouched.
        arr = arr.copy()
        mask = arr[..., 0] >= 0
        arr[..., 0] = np.where(mask, arr[..., 0] * scale_x + dx, arr[..., 0])
        arr[..., 1] = np.where(mask, arr[..., 1] * scale_y + dy, arr[..., 1])
        return arr

    pose_aug["bodies"]["candidate"] = _apply(pose_aug["bodies"]["candidate"])
    pose_aug["hands"] = _apply(pose_aug["hands"])
    pose_aug["faces"] = _apply(pose_aug["faces"])
    return pose_aug
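
# Example (sketch): a deterministic jitter so paired outputs line up across calls.
# fixed_params is (dx, dy, scale_x, scale_y); the range arguments are ignored when
# it is supplied. `pose` is a pose dict of the form found in PoseDetection.pose.
#   jittered = _augment_pose(pose, (-0.2, 0.2), (-0.2, 0.2), (0.7, 1.3), (0.6, 1.4),
#                            fixed_params=(0.05, -0.05, 1.1, 0.9))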


def _save_video(path: str, frames: List[np.ndarray], fps: float) -> None:
    if not frames:
        return
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with imageio.get_writer(path, fps=fps, codec="libx264") as writer:
        for frame in frames:
            writer.append_data(frame)


def align_img(img: np.ndarray, pose_ori: Dict, scales: Dict[str, float]) -> Dict:
    """Align a single pose dictionary using pre-computed scale factors."""
    body_pose = pose_ori["bodies"]["candidate"].copy()
    hands = pose_ori["hands"].copy()
    faces = pose_ori["faces"].copy()

    H_in, W_in, _ = img.shape
    video_ratio = W_in / H_in
    body_pose[:, 0] *= video_ratio
    hands[:, :, 0] *= video_ratio
    faces[:, :, 0] *= video_ratio

    scale_neck = scales["scale_neck"]
    scale_face_left = scales["scale_face_left"]
    scale_face_right = scales["scale_face_right"]
    scale_shoulder = scales["scale_shoulder"]
    scale_arm_upper = scales["scale_arm_upper"]
    scale_arm_lower = scales["scale_arm_lower"]
    scale_hand = scales["scale_hand"]
    scale_body_len = scales["scale_body_len"]
    scale_leg_upper = scales["scale_leg_upper"]
    scale_leg_lower = scales["scale_leg_lower"]

    scale_list = [
        scale_neck,
        scale_face_left,
        scale_face_right,
        scale_shoulder,
        scale_arm_upper,
        scale_arm_lower,
        scale_hand,
        scale_body_len,
        scale_leg_upper,
        scale_leg_lower,
    ]
    finite_vals = [v for v in scale_list if np.isfinite(v)]
    mean_scale = np.mean(finite_vals) if finite_vals else 1.0
    scale_list = [mean_scale if np.isinf(v) else v for v in scale_list]
    (
        scale_neck,
        scale_face_left,
        scale_face_right,
        scale_shoulder,
        scale_arm_upper,
        scale_arm_lower,
        scale_hand,
        scale_body_len,
        scale_leg_upper,
        scale_leg_lower,
    ) = scale_list

    offset = {
        "14_16_to_0": body_pose[[14, 16], :] - body_pose[[0], :],
        "15_17_to_0": body_pose[[15, 17], :] - body_pose[[0], :],
        "3_to_2": body_pose[[3], :] - body_pose[[2], :],
        "4_to_3": body_pose[[4], :] - body_pose[[3], :],
        "6_to_5": body_pose[[6], :] - body_pose[[5], :],
        "7_to_6": body_pose[[7], :] - body_pose[[6], :],
        "9_to_8": body_pose[[9], :] - body_pose[[8], :],
        "10_to_9": body_pose[[10], :] - body_pose[[9], :],
        "12_to_11": body_pose[[12], :] - body_pose[[11], :],
        "13_to_12": body_pose[[13], :] - body_pose[[12], :],
        "hand_left_to_4": hands[1, :, :] - body_pose[[4], :],
        "hand_right_to_7": hands[0, :, :] - body_pose[[7], :],
    }

    def _warp(target: np.ndarray, center: np.ndarray, scale: float) -> np.ndarray:
        M = cv2.getRotationMatrix2D((center[0], center[1]), 0, scale)
        return warpAffine_kps(target, M)

    # Head and shoulders, scaled about keypoints 1 and 0.
    body_pose[[0], :] = _warp(body_pose[[0], :], body_pose[1], scale_neck)
    body_pose[[14, 16], :] = _warp(offset["14_16_to_0"] + body_pose[[0], :], body_pose[0], scale_face_left)
    body_pose[[15, 17], :] = _warp(offset["15_17_to_0"] + body_pose[[0], :], body_pose[0], scale_face_right)
    body_pose[[2, 5], :] = _warp(body_pose[[2, 5], :], body_pose[1], scale_shoulder)

    # Arm chain 2 -> 3 -> 4 and the hand anchored at keypoint 4.
    body_pose[[3], :] = _warp(offset["3_to_2"] + body_pose[[2], :], body_pose[2], scale_arm_upper)
    body_pose[[4], :] = _warp(offset["4_to_3"] + body_pose[[3], :], body_pose[3], scale_arm_lower)
    hands[1, :, :] = _warp(offset["hand_left_to_4"] + body_pose[[4], :], body_pose[4], scale_hand)

    # Arm chain 5 -> 6 -> 7 and the hand anchored at keypoint 7.
    body_pose[[6], :] = _warp(offset["6_to_5"] + body_pose[[5], :], body_pose[5], scale_arm_upper)
    body_pose[[7], :] = _warp(offset["7_to_6"] + body_pose[[6], :], body_pose[6], scale_arm_lower)
    hands[0, :, :] = _warp(offset["hand_right_to_7"] + body_pose[[7], :], body_pose[7], scale_hand)

    # Torso and leg chains.
    body_pose[[8, 11], :] = _warp(body_pose[[8, 11], :], body_pose[1], scale_body_len)
    body_pose[[9], :] = _warp(offset["9_to_8"] + body_pose[[8], :], body_pose[8], scale_leg_upper)
    body_pose[[10], :] = _warp(offset["10_to_9"] + body_pose[[9], :], body_pose[9], scale_leg_lower)
    body_pose[[12], :] = _warp(offset["12_to_11"] + body_pose[[11], :], body_pose[11], scale_leg_upper)
    body_pose[[13], :] = _warp(offset["13_to_12"] + body_pose[[12], :], body_pose[12], scale_leg_lower)

    # Keep missing keypoints marked as -1 and scrub any NaNs introduced by the warps.
    body_pose_none = pose_ori["bodies"]["candidate"] == -1.0
    hands_none = pose_ori["hands"] == -1.0
    faces_none = pose_ori["faces"] == -1.0

    body_pose[body_pose_none] = -1.0
    hands[hands_none] = -1.0
    faces[faces_none] = -1.0

    body_pose = np.nan_to_num(body_pose, nan=-1.0)
    hands = np.nan_to_num(hands, nan=-1.0)
    faces = np.nan_to_num(faces, nan=-1.0)

    pose_align = dict(
        bodies={"candidate": body_pose, "subset": pose_ori["bodies"]["subset"]},
        hands=hands,
        faces=faces,
    )
    return pose_align


def warpAffine_kps(kps: np.ndarray, M: np.ndarray) -> np.ndarray:
    """Apply a 2x3 affine matrix to the x,y channels of a keypoint array."""
    kps_t = kps.copy()
    kps_t[..., 0] = kps[..., 0] * M[0, 0] + kps[..., 1] * M[0, 1] + M[0, 2]
    kps_t[..., 1] = kps[..., 0] * M[1, 0] + kps[..., 1] * M[1, 1] + M[1, 2]
    return kps_t
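
# Example (sketch): the same pattern align_img uses in its _warp helper.
# Scaling by 2.0 about (0.5, 0.5) moves a keypoint at (0.75, 0.5) to (1.0, 0.5);
# the numbers are illustrative only.
#   M = cv2.getRotationMatrix2D((0.5, 0.5), 0, 2.0)
#   warpAffine_kps(np.array([[0.75, 0.5]]), M)   # -> array([[1. , 0.5]])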


@dataclass
class PoseDetection:
    pose: Dict
    pose_map_rgb: np.ndarray
    frame_rgb: np.ndarray
    orig_hw: Tuple[int, int]


class PoseAligner:
    def __init__(self, detect_resolution: int = 1024, device: str | None = None, detection_workers: int = 2) -> None:
        det_model = fl.locate_file("pose/yolox_l.onnx")
        pose_model = fl.locate_file("pose/dw-ll_ucoco_384.onnx")
        resolved_device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
        self.det_model = det_model
        self.pose_model = pose_model
        self.device = resolved_device
        self.pose_estimation = Wholebody(det_model, pose_model, device=resolved_device)
        self.detection_workers = max(1, int(detection_workers))
        self.detect_resolution = detect_resolution

    def _detect_pose_session(self, image: ArrayImage, pose_estimation: Wholebody) -> PoseDetection | None:
        bgr_orig = _to_bgr_image(image)
        resized = resize_image(bgr_orig, self.detect_resolution)
        H, W, _ = resized.shape
        candidate, subset, _ = pose_estimation(resized)
        if len(candidate) == 0:
            return None

        nums, keys, locs = candidate.shape
        if keys < 18:
            return None

        candidate = _ensure_xyz(candidate.copy())
        subset = subset.copy()

        # Normalise keypoint coordinates to [0, 1] in both axes.
        candidate[..., 0] /= float(W)
        candidate[..., 1] /= float(H)

        locs = candidate.shape[2]

        # Body: the first 18 keypoints of the first detected person.
        body = candidate[0, :18].copy()
        body_scores = subset[0, :18].copy()

        score = np.zeros((1, 18), dtype=np.float64)
        for j in range(18):
            if body_scores[j] <= 0.3:
                body[j] = -1
                score[0, j] = -1
            else:
                score[0, j] = j

        # Face keypoints, taken from slice 24:92 of the candidate array.
        faces = candidate[0:1, 24:92].copy()
        if subset.shape[1] > 24:
            face_scores = subset[0, 24:92]
            for j in range(faces.shape[1]):
                if j < len(face_scores) and face_scores[j] <= 0.3:
                    faces[0, j] = -1
        faces = _ensure_xyz(faces)

        # Hand keypoints: 21 per hand, slices 92:113 and 113:134.
        right_hand = candidate[0, 92:113].copy() if candidate.shape[1] > 92 else np.zeros((21, locs))
        left_hand = candidate[0, 113:134].copy() if candidate.shape[1] > 113 else np.zeros((21, locs))

        if subset.shape[1] > 92:
            hand_scores = subset[0, 92:134]
            for j in range(21):
                if j < len(hand_scores) and hand_scores[j] <= 0.3:
                    right_hand[j] = -1
                if j + 21 < len(hand_scores) and hand_scores[j + 21] <= 0.3:
                    left_hand[j] = -1

        hands = _ensure_xyz(np.stack([right_hand, left_hand]))

        pose = {"bodies": {"candidate": body, "subset": score}, "hands": hands, "faces": faces}

        pose_for_draw = {
            "bodies": {"candidate": _xy_only(body), "subset": score},
            "hands": _xy_only(hands),
            "faces": _xy_only(faces),
        }
        pose_map = draw_pose(pose_for_draw, H, W, use_body=True, use_hand=True, use_face=False)
        pose_map_rgb = cv2.cvtColor(pose_map, cv2.COLOR_BGR2RGB)
        frame_rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        return PoseDetection(pose=pose, pose_map_rgb=pose_map_rgb, frame_rgb=frame_rgb, orig_hw=bgr_orig.shape[:2])

    def _detect_pose(self, image: ArrayImage) -> PoseDetection | None:
        return self._detect_pose_session(image, self.pose_estimation)

    def _compute_alignment(self, ref_pose: Dict, first_pose: Dict, ref_ratio: float, video_ratio: float) -> Tuple[Dict, np.ndarray]:
        """Compute per-limb scale ratios (reference / first frame) and a keypoint-1-anchored offset."""
        body_ref = ref_pose["bodies"]["candidate"].copy()
        hands_ref = ref_pose["hands"].copy()
        faces_ref = ref_pose["faces"].copy()

        body_first = first_pose["bodies"]["candidate"].copy()
        hands_first = first_pose["hands"].copy()
        faces_first = first_pose["faces"].copy()

        body_ref[:, 0] *= ref_ratio
        hands_ref[:, :, 0] *= ref_ratio
        faces_ref[:, :, 0] *= ref_ratio

        body_first[:, 0] *= video_ratio
        hands_first[:, :, 0] *= video_ratio
        faces_first[:, :, 0] *= video_ratio

        def dist(body: np.ndarray, a: int, b: int) -> float:
            pa, pb = body[a, :2], body[b, :2]
            if (pa < 0).any() or (pb < 0).any():
                return np.nan
            return float(np.linalg.norm(pa - pb))

        def hand_dist(hand: np.ndarray, idx_a: int, idx_b: int) -> float:
            pa, pb = hand[idx_a, :2], hand[idx_b, :2]
            if (pa < 0).any() or (pb < 0).any():
                return np.nan
            return float(np.linalg.norm(pa - pb))

        align_args = {
            "scale_neck": _safe_ratio(dist(body_ref, 0, 1), dist(body_first, 0, 1)),
            "scale_face_left": _safe_ratio(
                dist(body_ref, 16, 14) + dist(body_ref, 14, 0),
                dist(body_first, 16, 14) + dist(body_first, 14, 0),
            ),
            "scale_face_right": _safe_ratio(
                dist(body_ref, 17, 15) + dist(body_ref, 15, 0),
                dist(body_first, 17, 15) + dist(body_first, 15, 0),
            ),
            "scale_shoulder": _safe_ratio(dist(body_ref, 2, 5), dist(body_first, 2, 5)),
            "scale_arm_upper": np.nanmean(
                [
                    _safe_ratio(dist(body_ref, 2, 3), dist(body_first, 2, 3)),
                    _safe_ratio(dist(body_ref, 5, 6), dist(body_first, 5, 6)),
                ]
            ),
            "scale_arm_lower": np.nanmean(
                [
                    _safe_ratio(dist(body_ref, 3, 4), dist(body_first, 3, 4)),
                    _safe_ratio(dist(body_ref, 6, 7), dist(body_first, 6, 7)),
                ]
            ),
            "scale_body_len": _safe_ratio(
                dist(body_ref, 1, 8) if not np.isnan(dist(body_ref, 1, 8)) else dist(body_ref, 1, 11),
                dist(body_first, 1, 8) if not np.isnan(dist(body_first, 1, 8)) else dist(body_first, 1, 11),
            ),
            "scale_leg_upper": np.nanmean(
                [
                    _safe_ratio(dist(body_ref, 8, 9), dist(body_first, 8, 9)),
                    _safe_ratio(dist(body_ref, 11, 12), dist(body_first, 11, 12)),
                ]
            ),
            "scale_leg_lower": np.nanmean(
                [
                    _safe_ratio(dist(body_ref, 9, 10), dist(body_first, 9, 10)),
                    _safe_ratio(dist(body_ref, 12, 13), dist(body_first, 12, 13)),
                ]
            ),
        }

        hand_pairs = [(0, 1), (0, 5), (0, 9), (0, 13), (0, 17)]
        hand_ratios = []
        for idx_a, idx_b in hand_pairs:
            hand_ratios.append(
                _safe_ratio(
                    hand_dist(hands_ref[0], idx_a, idx_b),
                    hand_dist(hands_first[0], idx_a, idx_b),
                )
            )
            hand_ratios.append(
                _safe_ratio(
                    hand_dist(hands_ref[1], idx_a, idx_b),
                    hand_dist(hands_first[1], idx_a, idx_b),
                )
            )
        hand_ratios = [v for v in hand_ratios if np.isfinite(v)]
        align_args["scale_hand"] = np.mean(hand_ratios) if hand_ratios else (
            align_args["scale_arm_upper"] + align_args["scale_arm_lower"]
        ) / 2

        align_args = {k: _nan_to_one(v) for k, v in align_args.items()}
        offset = np.array(
            [body_ref[1, 0] - body_first[1, 0], body_ref[1, 1] - body_first[1, 1], 0.0],
            dtype=np.float32,
        )
        return align_args, offset

    def align(
        self,
        ref_video_frames: List[ArrayImage],
        ref_image: ArrayImage,
        ref_video_mask: List[ArrayImage] | None = None,
        align_frame: int = 0,
        max_frames: int | None = None,
        include_composite: bool = False,
        augment: bool = True,
        augment_mode: str = "per_frame",
        augment_params: Dict | None = None,
        cpu_resize_workers: int | None = None,
        resize_ref_video: bool = False,
        detection_chunk_size: int = 8,
        expand_scale: int = 0,
        verbose: int = 0,
    ) -> Dict[str, torch.Tensor]:
        t0 = time.perf_counter()
        ref_detection = self._detect_pose(ref_image)
        if ref_detection is None:
            raise ValueError("Unable to detect pose in the reference image.")

        target_hw = ref_detection.frame_rgb.shape[:2]

        def _tensor_video_to_list(t: torch.Tensor) -> List[np.ndarray]:
            if t.dim() != 4 or t.shape[0] not in (1, 3, 4):
                raise ValueError("Video tensor must be 4D with shape (C, F, H, W)")
            vid = t.detach().cpu()
            if vid.min() < 0:
                vid = (vid + 1.0) * 127.5
            elif vid.max() <= 1.0:
                vid = vid * 255.0
            vid = vid.clamp(0, 255).byte()

            frames = []
            for i in range(vid.shape[1]):
                rgb = vid[:, i].permute(1, 2, 0).numpy()
                bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
                frames.append(bgr)
            return frames

        if isinstance(ref_video_frames, torch.Tensor):
            ref_video_frames = _tensor_video_to_list(ref_video_frames)

        def _preprocess_item(idx_frame_mask):
            idx, frame, mask = idx_frame_mask

            if mask is not None:
                mask = _mask_to_float01(mask)
            if resize_ref_video and frame.shape[:2] != target_hw:
                frame = cv2.resize(frame, (target_hw[1], target_hw[0]), interpolation=cv2.INTER_LANCZOS4)

            if mask is not None:
                mask_f = mask
                if mask_f.shape != frame.shape[:2]:
                    mask_f = cv2.resize(mask_f, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_LINEAR)
                if expand_scale != 0:
                    # Grow (dilate) or shrink (erode) the mask before applying it.
                    kernel_size = abs(expand_scale)
                    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
                    op_expand = cv2.dilate if expand_scale > 0 else cv2.erode
                    mask_f = op_expand(mask_f, kernel, iterations=3)
                frame = (frame.astype(np.float32) * mask_f[..., None]).clip(0, 255).astype(np.uint8)
            return idx, frame

        items = []
        for idx, frame in enumerate(ref_video_frames):
            mask = None
            if ref_video_mask is not None:
                mask = ref_video_mask[:, idx] if torch.is_tensor(ref_video_mask) else ref_video_mask[idx]
            items.append((idx, frame, mask))
        if cpu_resize_workers is None:
            cpu_resize_workers = max(1, int(os.cpu_count() / 2))
        preprocessed = process_images_multithread(
            _preprocess_item,
            items,
            process_type="prephase",
            wrap_in_list=False,
            max_workers=cpu_resize_workers,
            in_place=False,
        )
        idx_to_frame = {idx: frame for idx, frame in preprocessed}
        t_preprocess = time.perf_counter() - t0
        if verbose:
            print(f"[pose_align] preprocess done in {t_preprocess:.2f}s")

        ref_ratio = ref_detection.frame_rgb.shape[1] / ref_detection.frame_rgb.shape[0]
        pose_list: List[Dict] = []
        video_frames: List[np.ndarray] = []
        video_pose_frames: List[np.ndarray] = []

        align_args = None
        offset = None

        # One detector per worker; the first reuses the instance built in __init__.
        detectors = [self.pose_estimation]
        for _ in range(max(1, self.detection_workers) - 1):
            detectors.append(Wholebody(self.det_model, self.pose_model, device=self.device))

        sequential_done = False
        pending = []
        t_first_stage = 0.0

        # Walk frames sequentially until the first successful detection fixes the
        # alignment parameters; everything else is handled in parallel below.
        for idx in sorted(idx_to_frame.keys()):
            if idx < align_frame:
                continue
            if max_frames is not None and len(pose_list) >= max_frames:
                break

            frame = idx_to_frame[idx]
            t_detect_start = time.perf_counter()
            detection = self._detect_pose(frame) if not sequential_done else None
            if detection is None:
                pending.append((idx, frame))
                continue

            if align_args is None:
                video_ratio = detection.frame_rgb.shape[1] / detection.frame_rgb.shape[0]
                align_args, offset = self._compute_alignment(ref_detection.pose, detection.pose, ref_ratio, video_ratio)

            pose_aligned = align_img(detection.frame_rgb, detection.pose, align_args)
            mask_body = pose_aligned["bodies"]["candidate"][..., 0] < 0
            mask_hands = pose_aligned["hands"][..., 0] < 0
            mask_faces = pose_aligned["faces"][..., 0] < 0

            pose_aligned["bodies"]["candidate"] = _ensure_xyz(pose_aligned["bodies"]["candidate"]) + offset
            pose_aligned["hands"] = _ensure_xyz(pose_aligned["hands"]) + offset
            pose_aligned["faces"] = _ensure_xyz(pose_aligned["faces"]) + offset

            pose_aligned["bodies"]["candidate"][mask_body] = -1
            pose_aligned["hands"][mask_hands] = -1
            pose_aligned["faces"][mask_faces] = -1

            pose_aligned["bodies"]["candidate"][:, 0] /= ref_ratio
            pose_aligned["hands"][:, :, 0] /= ref_ratio
            pose_aligned["faces"][:, :, 0] /= ref_ratio

            pose_list.append(pose_aligned)
            if include_composite:
                video_frames.append(detection.frame_rgb)
                video_pose_frames.append(detection.pose_map_rgb)

            sequential_done = True
            start_idx = idx + 1
            pending.extend([(j, idx_to_frame[j]) for j in sorted(idx_to_frame.keys()) if j >= start_idx])
            t_first_stage = time.perf_counter() - t_detect_start
            if verbose:
                print(f"[pose_align] first frame detection+align done in {t_first_stage:.2f}s")
            break

        if align_args is None:
            return {"composite": torch.empty(0), "pose_only": torch.empty(0), "pose_aug": torch.empty(0)}

        # Remaining frames: detect and align in parallel, one detector per worker.
        def _process_pending(item):
            idx, frame, det_idx = item
            det = self._detect_pose_session(frame, detectors[det_idx])
            if det is None:
                return idx, None, None, None

            pa = align_img(det.frame_rgb, det.pose, align_args)
            mb = pa["bodies"]["candidate"][..., 0] < 0
            mh = pa["hands"][..., 0] < 0
            mf = pa["faces"][..., 0] < 0

            pa["bodies"]["candidate"] = _ensure_xyz(pa["bodies"]["candidate"]) + offset
            pa["hands"] = _ensure_xyz(pa["hands"]) + offset
            pa["faces"] = _ensure_xyz(pa["faces"]) + offset

            pa["bodies"]["candidate"][mb] = -1
            pa["hands"][mh] = -1
            pa["faces"][mf] = -1

            pa["bodies"]["candidate"][:, 0] /= ref_ratio
            pa["hands"][:, :, 0] /= ref_ratio
            pa["faces"][:, :, 0] /= ref_ratio
            return idx, pa, det.frame_rgb, det.pose_map_rgb

        if pending:
            from concurrent.futures import ThreadPoolExecutor, as_completed

            results = []
            det_cycle = itertools.cycle(range(self.detection_workers))
            with ThreadPoolExecutor(max_workers=self.detection_workers) as executor:
                for start in range(0, len(pending), detection_chunk_size * self.detection_workers):
                    chunk = pending[start : start + detection_chunk_size * self.detection_workers]
                    tasks = []
                    for idx, frame in chunk:
                        tasks.append((idx, frame, next(det_cycle)))
                    futures = {executor.submit(_process_pending, t): t[0] for t in tasks}
                    for future in as_completed(futures):
                        res = future.result()
                        if res[1] is None:
                            continue
                        results.append(res)
            for idx, pa, fr, pm in sorted(results, key=lambda x: x[0]):
                if max_frames is not None and len(pose_list) >= max_frames:
                    break
                pose_list.append(pa)
                if include_composite:
                    video_frames.append(fr)
                    video_pose_frames.append(pm)
            t_parallel = time.perf_counter() - t0 - t_preprocess - t_first_stage
            if verbose:
                print(f"[pose_align] parallel detection done in {t_parallel:.2f}s")
        else:
            t_parallel = 0.0

        if not pose_list:
            return {"composite": torch.empty(0), "pose_only": torch.empty(0), "pose_aug": torch.empty(0)}

        body_seq = [_ensure_xyz(pose["bodies"]["candidate"][:18]) for pose in pose_list]
        body_seq_subset = [pose["bodies"]["subset"][:1] for pose in pose_list]
        hands_seq = [_ensure_xyz(pose["hands"][:2]) for pose in pose_list]
        faces_seq = [_ensure_xyz(pose["faces"][:1]) for pose in pose_list]

        ref_H, ref_W = ref_detection.frame_rgb.shape[:2]
        ref_target_H, ref_target_W = ref_detection.orig_hw
        composite_frames: List[np.ndarray] = []
        pose_only_frames: List[np.ndarray] = []
        pose_aug_frames: List[np.ndarray] = []

        aug_cfg = augment_params or {}
        offset_x = aug_cfg.get("offset_x", (-0.2, 0.2))
        offset_y = aug_cfg.get("offset_y", (-0.2, 0.2))
        scale_rng = aug_cfg.get("scale", (0.7, 1.3))
        aspect_rng = aug_cfg.get("aspect_ratio_range", (0.6, 1.4))
        fixed_aug = None
        if augment and augment_mode == "fixed":
            # Draw one jitter for the whole clip so every frame shares the same transform.
            sx = np.random.uniform(scale_rng[0], scale_rng[1])
            aspect = np.random.uniform(aspect_rng[0], aspect_rng[1])
            scale_x = sx * aspect
            scale_y = sx / max(aspect, 1e-6)
            dx = np.random.uniform(offset_x[0], offset_x[1])
            dy = np.random.uniform(offset_y[0], offset_y[1])
            fixed_aug = (dx, dy, scale_x, scale_y)

        for i in range(len(body_seq)):
            pose_t = {
                "bodies": {"candidate": body_seq[i], "subset": body_seq_subset[i]},
                "hands": hands_seq[i],
                "faces": faces_seq[i],
            }

            pose_for_draw = {
                "bodies": {"candidate": _xy_only(body_seq[i]), "subset": body_seq_subset[i]},
                "hands": _xy_only(hands_seq[i]),
                "faces": _xy_only(faces_seq[i]),
            }

            aligned_pose = draw_pose(pose_for_draw, ref_H, ref_W, use_body=True, use_hand=True, use_face=False)
            aligned_pose = cv2.cvtColor(aligned_pose, cv2.COLOR_BGR2RGB)
            pose_only_frames.append(aligned_pose)

            aug_frame = None
            if augment:
                params = None if augment_mode == "per_frame" else fixed_aug
                pose_t_aug = _augment_pose(pose_t, offset_x, offset_y, scale_rng, aspect_rng, fixed_params=params)
                pose_t_aug_draw = {
                    "bodies": {"candidate": _xy_only(pose_t_aug["bodies"]["candidate"]), "subset": pose_t_aug["bodies"]["subset"]},
                    "hands": _xy_only(pose_t_aug["hands"]),
                    "faces": _xy_only(pose_t_aug["faces"]),
                }
                aug_pose = draw_pose(pose_t_aug_draw, ref_H, ref_W, use_body=True, use_hand=True, use_face=False)
                aug_frame = cv2.cvtColor(aug_pose, cv2.COLOR_BGR2RGB)
                pose_aug_frames.append(aug_frame)

            if include_composite:
                video_frame = _resize_to_height(video_frames[i], ref_H)
                video_pose = _resize_to_height(video_pose_frames[i], ref_H)
                ref_pose_resized = _resize_to_height(ref_detection.pose_map_rgb, ref_H)
                ref_img_resized = _resize_to_height(ref_detection.frame_rgb, ref_H)
                parts = [
                    ref_img_resized,
                    ref_pose_resized,
                    aligned_pose,
                ]
                if augment and aug_frame is not None:
                    parts.append(aug_frame)
                parts.extend([video_frame, video_pose])
                comp = np.concatenate(parts, axis=1)
                composite_frames.append(comp)

        # Map pose renders back to the reference image's original resolution if the
        # detection pass had to rescale it.
        if (ref_target_H, ref_target_W) != (ref_H, ref_W):
            def _resize_list(frames: List[np.ndarray]) -> List[np.ndarray]:
                return [cv2.resize(f, (ref_target_W, ref_target_H), interpolation=cv2.INTER_CUBIC) for f in frames]
            pose_only_frames = _resize_list(pose_only_frames)
            pose_aug_frames = _resize_list(pose_aug_frames)
        t_render = time.perf_counter() - t0 - t_preprocess - t_first_stage - t_parallel
        t_total = time.perf_counter() - t0
        if verbose:
            print(f"[pose_align] render done in {t_render:.2f}s; total {t_total:.2f}s")

        return {
            "composite": _frames_to_tensor(composite_frames) if include_composite else torch.empty(0),
            "pose_only": _frames_to_tensor(pose_only_frames),
            "pose_aug": _frames_to_tensor(pose_aug_frames) if augment else torch.empty(0),
        }
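
# Example (sketch): minimal programmatic use of PoseAligner, independent of the
# demo() helper below. The file names are placeholders.
#   aligner = PoseAligner(detection_workers=2)
#   frames, fps = load_video_frames("driving.mp4")
#   out = aligner.align(frames, Image.open("reference.png"), max_frames=16)
#   out["pose_only"]   # tensor shaped (3, F, H, W) in [-1, 1]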


def load_video_frames(path: str) -> Tuple[List[np.ndarray], float]:
    cap = cv2.VideoCapture(path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 24.0
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()
    cv2.destroyAllWindows()
    return frames, fps


def load_mask_frames(path: str) -> List[np.ndarray]:
    frames, _ = load_video_frames(path)
    return frames


def demo(
    ref_image_path: str,
    ref_video_path: str,
    ref_video_mask_path: str | None,
    output_dir: str,
    include_composite: bool = True,
    augment: bool = True,
    augment_mode: str = "per_frame",
    align_frame: int = 0,
    max_frames: int | None = None,
    detection_workers: int = 2,
    cpu_resize_workers: int | None = None,
    resize_ref_video: bool = False,
    detection_chunk_size: int = 8,
    verbose: int = 0,
) -> None:
    frames, fps = load_video_frames(ref_video_path)
    mask_frames = load_mask_frames(ref_video_mask_path) if ref_video_mask_path else None
    aligner = PoseAligner(detection_workers=detection_workers)
    outputs = aligner.align(
        frames,
        Image.open(ref_image_path),
        ref_video_mask=mask_frames,
        align_frame=align_frame,
        max_frames=max_frames,
        include_composite=include_composite,
        augment=augment,
        augment_mode=augment_mode,
        cpu_resize_workers=cpu_resize_workers,
        resize_ref_video=resize_ref_video,
        detection_chunk_size=detection_chunk_size,
        verbose=verbose,
    )

    composite_frames = _tensor_to_frames(outputs["composite"])
    pose_frames = _tensor_to_frames(outputs["pose_only"])
    pose_aug_frames = _tensor_to_frames(outputs["pose_aug"])

    if include_composite and composite_frames:
        _save_video(os.path.join(output_dir, "composite.mp4"), composite_frames, fps)
    if pose_frames:
        _save_video(os.path.join(output_dir, "pose_only.mp4"), pose_frames, fps)
    if pose_aug_frames:
        _save_video(os.path.join(output_dir, "pose_only_aug.mp4"), pose_aug_frames, fps)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Pose alignment demo")
    parser.add_argument("--ref_image", type=str, default=r"e:/steffete.png", help="Path to the reference image")
    parser.add_argument("--ref_video", type=str, default=r"e:\stage.mp4", help="Path to the reference pose video (mp4)")
    parser.add_argument("--ref_video_mask", type=str, default="e:/stagemasked_alpha.mp4", help="Optional path to a mask video (same length as the reference video)")
    parser.add_argument("--output_dir", type=str, default="pose_align_demo", help="Directory to store the demo mp4s")
    parser.add_argument("--no_composite", action="store_true", help="Skip saving the comparison/composite video")
    parser.add_argument("--no_augment", action="store_true", help="Disable the augmented pose output")
    parser.add_argument("--augment_mode", type=str, default="per_frame", choices=["fixed", "per_frame"], help="Augmentation mode (fixed per video or per-frame jitter)")
    parser.add_argument("--align_frame", type=int, default=0, help="Frame index to start alignment from")
    parser.add_argument("--max_frames", type=int, default=48, help="Maximum number of frames to process")
    parser.add_argument("--detection_workers", type=int, default=2, help="Number of parallel detection workers (GPU-backed)")
    parser.add_argument("--cpu_resize_workers", type=int, default=None, help="CPU workers for mask/resize preprocessing")
    parser.add_argument("--resize_ref_video", action="store_true", help="Resize the reference video to the reference image size (Lanczos)")
    parser.add_argument("--detection_chunk_size", type=int, default=8, help="Chunk size for parallel detection batches")
    parser.add_argument("--verbose", type=int, default=1, help="Verbosity level for timing/debug output")
    args = parser.parse_args()

    demo(
        args.ref_image,
        args.ref_video,
        args.ref_video_mask,
        args.output_dir,
        include_composite=not args.no_composite,
        augment=not args.no_augment,
        augment_mode=args.augment_mode,
        align_frame=args.align_frame,
        max_frames=args.max_frames,
        detection_workers=args.detection_workers,
        cpu_resize_workers=args.cpu_resize_workers,
        resize_ref_video=args.resize_ref_video,
        detection_chunk_size=args.detection_chunk_size,
        verbose=args.verbose,
    )