# Source: vr-hmr / scripts / demo / hamer_inference.py
# (uploaded by zirobtc via huggingface_hub, revision 7e120dd)
"""
HaMeR (Hand Mesh Recovery) wrapper for integration with GENMO.
Runs HaMeR on detected hand bounding boxes and returns MANO parameters.
"""
import sys
import os
from pathlib import Path
# Add HaMeR to path BEFORE any other imports
# This needs to be at the absolute path level
_SCRIPT_DIR = Path(__file__).resolve().parent
_GENMO_ROOT = _SCRIPT_DIR.parent.parent  # GENMO/scripts/demo -> GENMO
# Location of the vendored HaMeR checkout (third_party/hamer).
HAMER_ROOT = _GENMO_ROOT / "third_party" / "hamer"
# Make the vendored `hamer` package importable; the `import hamer.*`
# statements inside the methods below rely on this path entry.
if str(HAMER_ROOT) not in sys.path:
    sys.path.insert(0, str(HAMER_ROOT))
import torch
import numpy as np
import cv2
import mmcv
from tqdm import tqdm
class HaMeRInference:
    """
    HaMeR wrapper for hand mesh recovery.
    Input: Hand bounding boxes + video frames
    Output: MANO parameters (hand_pose, global_orient, betas)
    """
    def __init__(self, device='cuda:0'):
        """Store the target device; the model itself is loaded lazily.

        Args:
            device: torch device string for inference (e.g. 'cuda:0', 'cpu').
        """
        self.device = torch.device(device)
        # Model and config are populated on first use by _load_model().
        self._model = None
        self._model_cfg = None
def _load_model(self):
"""Lazy load HaMeR model."""
if self._model is not None:
return
# Override CACHE_DIR_HAMER to point to the correct location BEFORE importing
import hamer.configs
hamer.configs.CACHE_DIR_HAMER = str(HAMER_ROOT / "_DATA")
from hamer.configs import CACHE_DIR_HAMER
from hamer.models import load_hamer, DEFAULT_CHECKPOINT
# The checkpoint path also needs to be updated
checkpoint_path = HAMER_ROOT / "_DATA" / "hamer_ckpts" / "checkpoints" / "hamer.ckpt"
if not checkpoint_path.exists():
raise FileNotFoundError(f"HaMeR checkpoint not found at {checkpoint_path}. Run fetch_demo_data.sh in third_party/hamer/")
# Load HaMeR
self._model, self._model_cfg = load_hamer(str(checkpoint_path))
self._model = self._model.to(self.device)
self._model.eval()
print(f"[HaMeR] Loaded model from {checkpoint_path}")
def _prepare_input(self, frame, bbox, is_right):
"""
Prepare input for HaMeR model.
Args:
frame: (H, W, 3) BGR image
bbox: [x1, y1, x2, y2] hand bounding box
is_right: bool, True for right hand
Returns:
batch dict for HaMeR model
"""
from hamer.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
# Validate frame shape - must be (H, W, 3)
if frame is None:
raise ValueError("Frame is None")
if not isinstance(frame, np.ndarray):
raise ValueError(f"Frame must be numpy array, got {type(frame)}")
if frame.ndim != 3:
raise ValueError(f"Frame must be 3D (H, W, C), got {frame.ndim}D with shape {frame.shape}")
if frame.shape[2] != 3:
raise ValueError(f"Frame must have 3 channels, got {frame.shape[2]} with shape {frame.shape}")
# Ensure frame is contiguous and correct dtype
if not frame.flags['C_CONTIGUOUS']:
frame = np.ascontiguousarray(frame)
if frame.dtype != np.uint8:
frame = frame.astype(np.uint8)
# Create dataset for single hand
boxes = np.array([bbox])
right = np.array([1 if is_right else 0])
dataset = ViTDetDataset(
self._model_cfg,
frame, # BGR image
boxes,
right,
rescale_factor=2.0
)
return dataset[0]
    @torch.no_grad()
    def predict_single(self, frame, bbox, is_right):
        """
        Predict MANO parameters for a single hand.
        Args:
            frame: (H, W, 3) BGR image
            bbox: [x1, y1, x2, y2] hand bounding box
            is_right: bool, True for right hand
        Returns:
            dict with:
                - hand_pose: (15, 3) axis-angle for finger joints
                - global_orient: (3,) axis-angle for wrist
                - betas: (10,) shape parameters
                - vertices: (778, 3) mesh vertices
                - keypoints_3d: (21, 3) 3D hand joints
                - cam_t: (3,) full-frame camera translation
                - focal_length: scaled focal length used for cam_t
                - is_right: 1 for right hand, 0 for left
        """
        self._load_model()
        from hamer.utils import recursive_to
        batch = self._prepare_input(frame, bbox, is_right)
        # Add batch dimension to all array-like values (both numpy and torch)
        # ViTDetDataset returns numpy arrays, not torch tensors, so we need to handle both
        processed_batch = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                processed_batch[k] = v.unsqueeze(0)
            elif isinstance(v, np.ndarray):
                # Add batch dimension to numpy array and convert to tensor
                processed_batch[k] = torch.from_numpy(v).unsqueeze(0)
            else:
                # Scalars / metadata pass through untouched.
                processed_batch[k] = v
        batch = recursive_to(processed_batch, self.device)
        out = self._model(batch)
        # Extract MANO parameters
        pred_mano = out['pred_mano_params']
        # Convert rotation matrices to axis-angle
        global_orient_rotmat = pred_mano['global_orient'][0]  # (1, 3, 3)
        hand_pose_rotmat = pred_mano['hand_pose'][0]  # (15, 3, 3)
        betas = pred_mano['betas'][0]  # (10,)
        # Mirror for left hand (HaMeR predicts right-hand rotations):
        # conjugating each rotation by diag(-1, 1, 1) reflects it across
        # the x axis, turning right-hand rotations into left-hand ones.
        if not is_right:
            mirror = torch.diag(torch.tensor([-1.0, 1.0, 1.0], device=global_orient_rotmat.device))
            global_orient_rotmat = mirror @ global_orient_rotmat @ mirror
            hand_pose_rotmat = mirror @ hand_pose_rotmat @ mirror
        global_orient_aa = self._rotmat_to_axis_angle(global_orient_rotmat.reshape(-1, 3, 3))  # (1, 3)
        hand_pose_aa = self._rotmat_to_axis_angle(hand_pose_rotmat.reshape(-1, 3, 3))  # (15, 3)
        # Compute full-frame camera translation for rendering
        from hamer.utils.renderer import cam_crop_to_full
        right_val = 1 if is_right else 0
        multiplier = (2 * right_val - 1)  # +1 for right hand, -1 for left
        pred_cam = out['pred_cam'][0].detach().float().clone()
        # Flip the horizontal camera offset for left hands, matching the
        # left/right mirroring of the crop above.
        pred_cam[1] = multiplier * pred_cam[1]
        box_center = torch.as_tensor(batch["box_center"], device=pred_cam.device, dtype=pred_cam.dtype)
        box_size = torch.as_tensor(batch["box_size"], device=pred_cam.device, dtype=pred_cam.dtype)
        img_size = torch.as_tensor(batch["img_size"], device=pred_cam.device, dtype=pred_cam.dtype)
        # Rescale the model's canonical focal length to the actual frame size.
        scaled_focal_length = (
            self._model_cfg.EXTRA.FOCAL_LENGTH / self._model_cfg.MODEL.IMAGE_SIZE * img_size.max()
        )
        cam_t_full = cam_crop_to_full(
            pred_cam.unsqueeze(0), box_center, box_size, img_size, scaled_focal_length
        )[0].detach().cpu().numpy()
        verts = out['pred_vertices'][0].detach().cpu().numpy()
        # Mirror vertices back for left hands (model outputs right-hand geometry).
        # NOTE(review): keypoints_3d is returned unmirrored — confirm downstream
        # consumers expect that for left hands.
        if not is_right:
            verts[:, 0] *= -1.0
        return {
            'hand_pose': hand_pose_aa.cpu().numpy(),  # (15, 3)
            'global_orient': global_orient_aa.cpu().numpy().squeeze(0),  # (3,)
            'betas': betas.cpu().numpy(),  # (10,)
            'vertices': verts,  # (778, 3)
            'keypoints_3d': out['pred_keypoints_3d'][0].cpu().numpy(),  # (21, 3)
            'cam_t': cam_t_full,  # (3,)
            'focal_length': float(scaled_focal_length),
            'is_right': right_val,
        }
def _rotmat_to_axis_angle(self, rotmat):
"""Convert rotation matrices to axis-angle representation."""
from pytorch3d.transforms import matrix_to_axis_angle
return matrix_to_axis_angle(rotmat)
    @torch.no_grad()
    def predict_video(self, video_path, left_bboxes, right_bboxes, masks=None):
        """
        Predict MANO parameters for all frames in video.
        Args:
            video_path: Path to video file, or an already-opened indexable
                frame source (anything supporting frame = source[i]).
            left_bboxes: List of left hand bboxes (None if not visible)
            right_bboxes: List of right hand bboxes (None if not visible)
            masks: Optional list of SAM masks (per-frame, H x W)
        Returns:
            left_hand_params: List of dicts with MANO params (or None)
            right_hand_params: List of dicts with MANO params (or None)
        """
        self._load_model()
        if isinstance(video_path, str):
            video = mmcv.VideoReader(video_path)
        else:
            # Caller passed an indexable sequence of frames directly.
            video = video_path
        # Number of frames to process is driven by the bbox list length.
        L = len(left_bboxes)
        left_results = []
        right_results = []
        for i in tqdm(range(L), desc="HaMeR Hands"):
            frame = video[i]
            # Validate frame read from video; a bad frame yields (None, None)
            # for this index so the output lists stay aligned with the input.
            if frame is None:
                print(f"[HaMeR] Warning: frame {i} is None, skipping")
                left_results.append(None)
                right_results.append(None)
                continue
            if not isinstance(frame, np.ndarray):
                print(f"[HaMeR] Warning: frame {i} has unexpected type {type(frame)}, skipping")
                left_results.append(None)
                right_results.append(None)
                continue
            if frame.ndim != 3 or frame.shape[2] != 3:
                print(f"[HaMeR] Warning: frame {i} has unexpected shape {frame.shape}, expected (H, W, 3), skipping")
                left_results.append(None)
                right_results.append(None)
                continue
            # Apply mask if available: keep masked pixels, paint the rest
            # mid-gray (128) so background clutter cannot distract the model.
            if masks is not None and i < len(masks) and masks[i] is not None:
                mask = masks[i]
                if isinstance(mask, torch.Tensor):
                    # NOTE(review): assumes the mask tensor lives on CPU —
                    # .numpy() would raise for a CUDA tensor; confirm callers.
                    mask = mask.numpy()
                frame_h, frame_w = frame.shape[:2]
                # Resize mask to the frame resolution if they disagree
                # (nearest-neighbor keeps the mask binary).
                if mask.shape[0] != frame_h or mask.shape[1] != frame_w:
                    mask = cv2.resize(mask.astype(np.uint8), (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)
                gray_bg = np.full_like(frame, 128)
                mask_3ch = mask[:, :, None].astype(bool)
                frame = np.where(mask_3ch, frame, gray_bg)
                # Validate post-mask frame shape
                if frame.ndim != 3 or frame.shape[2] != 3:
                    print(f"[HaMeR] Warning: frame {i} after masking has unexpected shape {frame.shape}, skipping")
                    left_results.append(None)
                    right_results.append(None)
                    continue
            # Left hand: per-frame failures are logged and recorded as None
            # rather than aborting the whole video.
            if left_bboxes[i] is not None:
                try:
                    left_result = self.predict_single(frame, left_bboxes[i], is_right=False)
                except Exception as e:
                    print(f"[HaMeR] Left hand frame {i} failed: {e}")
                    left_result = None
            else:
                left_result = None
            left_results.append(left_result)
            # Right hand (same best-effort policy as the left hand).
            if right_bboxes[i] is not None:
                try:
                    right_result = self.predict_single(frame, right_bboxes[i], is_right=True)
                except Exception as e:
                    print(f"[HaMeR] Right hand frame {i} failed: {e}")
                    right_result = None
            else:
                right_result = None
            right_results.append(right_result)
        return left_results, right_results
def mano_to_smplx_hands(
    left_results,
    right_results,
    num_frames,
    smooth_alpha=None,
    median_window=7,
    mean_window=5,
    max_delta=0.1,
):
    """
    Convert HaMeR MANO results to SMPL-X hand pose format.

    SMPL-X expects:
        - left_hand_pose: (L, 15, 3) axis-angle
        - right_hand_pose: (L, 15, 3) axis-angle

    Args:
        left_results: List of dicts from HaMeR (or None per frame). Lists
            shorter than num_frames are tolerated; missing tail frames are
            treated as undetected.
        right_results: List of dicts from HaMeR (or None per frame).
        num_frames: Total number of frames
        smooth_alpha: Optional EMA coefficient in (0, 1]; None/0 disables EMA.
        median_window: Centered median-filter width in frames; <=1 disables.
        mean_window: Centered moving-average width in frames; <=1 disables.
        max_delta: Max per-joint axis-angle change (radians) allowed between
            consecutive frames; None/0 disables clamping.
    Returns:
        left_hand_pose: (L, 15, 3) numpy array
        right_hand_pose: (L, 15, 3) numpy array
    """
    left_hand_pose = np.zeros((num_frames, 15, 3), dtype=np.float32)
    right_hand_pose = np.zeros((num_frames, 15, 3), dtype=np.float32)
    left_valid = np.zeros(num_frames, dtype=bool)
    right_valid = np.zeros(num_frames, dtype=bool)
    # Gather per-frame detections. Guard the index so result lists shorter
    # than num_frames do not raise IndexError (original behavior crashed).
    for i in range(num_frames):
        if i < len(left_results) and left_results[i] is not None:
            left_hand_pose[i] = left_results[i]['hand_pose']
            left_valid[i] = True
        if i < len(right_results) and right_results[i] is not None:
            right_hand_pose[i] = right_results[i]['hand_pose']
            right_valid[i] = True
    # Forward-fill missing frames to avoid jitter on occlusions
    # (frames before the first detection remain zero).
    for i in range(1, num_frames):
        if not left_valid[i]:
            left_hand_pose[i] = left_hand_pose[i - 1]
        if not right_valid[i]:
            right_hand_pose[i] = right_hand_pose[i - 1]
    # Median filter to suppress outliers (window shrinks at sequence edges).
    if median_window is not None and median_window > 1:
        half = median_window // 2
        left_filtered = left_hand_pose.copy()
        right_filtered = right_hand_pose.copy()
        for i in range(num_frames):
            s = max(0, i - half)
            e = min(num_frames, i + half + 1)
            left_filtered[i] = np.median(left_hand_pose[s:e], axis=0)
            right_filtered[i] = np.median(right_hand_pose[s:e], axis=0)
        left_hand_pose = left_filtered
        right_hand_pose = right_filtered
    # Centered moving average to remove high-frequency jitter without lag
    if mean_window is not None and mean_window > 1:
        half = mean_window // 2
        left_smoothed = left_hand_pose.copy()
        right_smoothed = right_hand_pose.copy()
        for i in range(num_frames):
            s = max(0, i - half)
            e = min(num_frames, i + half + 1)
            left_smoothed[i] = left_hand_pose[s:e].mean(axis=0)
            right_smoothed[i] = right_hand_pose[s:e].mean(axis=0)
        left_hand_pose = left_smoothed
        right_hand_pose = right_smoothed
    # Optional EMA smoothing (disabled by default)
    if smooth_alpha is not None and smooth_alpha > 0.0:
        for i in range(1, num_frames):
            left_hand_pose[i] = (1.0 - smooth_alpha) * left_hand_pose[i - 1] + smooth_alpha * left_hand_pose[i]
            right_hand_pose[i] = (1.0 - smooth_alpha) * right_hand_pose[i - 1] + smooth_alpha * right_hand_pose[i]
    # Clamp per-frame rotation change to suppress jitter spikes; applied
    # sequentially so each frame is clamped against the already-clamped
    # previous frame.
    if max_delta is not None and max_delta > 0.0:
        for i in range(1, num_frames):
            left_diff = left_hand_pose[i] - left_hand_pose[i - 1]
            right_diff = right_hand_pose[i] - right_hand_pose[i - 1]
            left_norm = np.linalg.norm(left_diff, axis=-1, keepdims=True)
            right_norm = np.linalg.norm(right_diff, axis=-1, keepdims=True)
            left_scale = np.minimum(1.0, max_delta / (left_norm + 1e-8))
            right_scale = np.minimum(1.0, max_delta / (right_norm + 1e-8))
            left_hand_pose[i] = left_hand_pose[i - 1] + left_diff * left_scale
            right_hand_pose[i] = right_hand_pose[i - 1] + right_diff * right_scale
    return left_hand_pose, right_hand_pose