import os
import sys
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm
from ultralytics import YOLO

# Dynamically add the HaWoR path for local imports
current_file_dir = os.path.dirname(os.path.abspath(__file__))
hawor_path = os.path.abspath(os.path.join(current_file_dir, '..', '..', 'thirdparty', 'HaWoR'))
if hawor_path not in sys.path:
    sys.path.insert(0, hawor_path)

from thirdparty.HaWoR.lib.models.hawor import HAWOR
from thirdparty.HaWoR.lib.pipeline.tools import parse_chunks
from thirdparty.HaWoR.lib.eval_utils.custom_utils import interpolate_bboxes
from thirdparty.HaWoR.hawor.utils.rotation import angle_axis_to_rotation_matrix, rotation_matrix_to_angle_axis
from thirdparty.HaWoR.hawor.configs import get_config


def load_hawor(checkpoint_path: str):
    """
    Loads the HAWOR model and its configuration from a checkpoint.

    Args:
        checkpoint_path (str): Path to the model checkpoint file or a
            HuggingFace repo ID (e.g., 'username/model-name').

    Returns:
        tuple: (HAWOR model instance, model configuration object)
    """
    from huggingface_hub import hf_hub_download

    # Treat checkpoint_path as a HuggingFace repo ID if it contains a '/'
    # but does not exist on the local filesystem.
    if '/' in checkpoint_path and not os.path.exists(checkpoint_path):
        # Download checkpoint and config from the HuggingFace Hub
        print(f"Downloading model from HuggingFace: {checkpoint_path}")
        checkpoint_file = hf_hub_download(repo_id=checkpoint_path, filename="checkpoints/hawor.ckpt")
        config_file = hf_hub_download(repo_id=checkpoint_path, filename="config.yaml")
        print(f"Downloaded checkpoint to: {checkpoint_file}")
        print(f"Downloaded config to: {config_file}")
        print(f"Checkpoint exists: {os.path.exists(checkpoint_file)}")
        model_cfg_path = Path(config_file)
    else:
        # Local checkpoint path; config.yaml is expected two directories up
        checkpoint_file = checkpoint_path
        model_cfg_path = Path(checkpoint_path).parent.parent / 'config.yaml'
        print(f"Using local checkpoint: {checkpoint_file}")
        print(f"Using local config: {model_cfg_path}")

    print(f"Loading config from: {model_cfg_path}")
    model_cfg = get_config(str(model_cfg_path), update_cachedir=True)

    # Override config for correct bbox cropping when using a ViT backbone
    if (model_cfg.MODEL.BACKBONE.TYPE == 'vit') and ('BBOX_SHAPE' not in model_cfg.MODEL):
        model_cfg.defrost()
        assert model_cfg.MODEL.IMAGE_SIZE == 256, \
            f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
        model_cfg.MODEL.BBOX_SHAPE = [192, 256]
        model_cfg.freeze()

    # Load the model with Lightning; load to CPU first, then move to the
    # target device later.
    print(f"Loading HAWOR model from checkpoint: {checkpoint_file}")
    model = HAWOR.load_from_checkpoint(
        checkpoint_file,
        strict=False,
        cfg=model_cfg,
        map_location='cpu'
    )
    return model, model_cfg
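

# Minimal usage sketch for load_hawor (illustration only: the local path and
# the HuggingFace repo ID below are hypothetical placeholders, not shipped
# assets of this repo).
def _demo_load_hawor():
    # Local checkpoint: config.yaml is resolved two directories above the
    # checkpoint, e.g. weights/checkpoints/hawor.ckpt -> weights/config.yaml
    model, model_cfg = load_hawor('weights/checkpoints/hawor.ckpt')

    # A repo ID (contains '/', does not exist locally) would instead trigger
    # the hf_hub_download branch:
    # model, model_cfg = load_hawor('some-user/hawor')

    print(type(model).__name__, model_cfg.MODEL.IMAGE_SIZE)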
""" self.device = device self.detector_path = detector_path self._checkpoint_path = model_path # Store for reloading self._original_device = device # Store for reloading model, model_cfg = load_hawor(model_path) model = model.to(device) model.eval() self.model = model self.model_cfg = model_cfg # Store config for buffer reinitialization def recon( self, images: list, img_focal: float, thresh: float = 0.2, single_image: bool = False ) -> dict: """ Performs hand detection, tracking, and HAWOR-based 3D reconstruction. Args: images (list): List of consecutive input image frames (cv2/numpy format). img_focal (float): Focal length of the camera in pixels. thresh (float): Confidence threshold for hand detection. single_image (bool): Flag for single-image processing mode. Returns: dict: Dictionary of reconstruction results for 'left' and 'right' hands. """ # Load detector and perform detection/tracking hand_det_model = YOLO(self.detector_path) _, tracks = detect_track(images, hand_det_model, thresh=thresh) # Perform HAWOR motion estimation recon_results = hawor_motion_estimation( images, tracks, self.model, img_focal, single_image=single_image ) # delete the YOLO detector to avoid accumulation of tracking history del hand_det_model return recon_results # Adapted from https://github.com/ThunderVVV/HaWoR/blob/main/scripts/scripts_test_video/detect_track_video.py def detect_track(imgfiles: list, hand_det_model: YOLO, thresh: float = 0.5) -> tuple: """ Detects and tracks hands across a sequence of images using YOLO. Args: imgfiles (list): List of image frames. hand_det_model (YOLO): The initialized YOLO hand detection model. thresh (float): Confidence threshold for detection. Returns: tuple: (list of boxes (unused in original logic), dict of tracks) """ boxes_ = [] tracks = {} for t, img_cv2 in enumerate(tqdm(imgfiles)): ### --- Detection --- with torch.no_grad(): with torch.amp.autocast('cuda'): results = hand_det_model.track(img_cv2, conf=thresh, persist=True, verbose=False) boxes = results[0].boxes.xyxy.cpu().numpy() confs = results[0].boxes.conf.cpu().numpy() handedness = results[0].boxes.cls.cpu().numpy() if not results[0].boxes.id is None: track_id = results[0].boxes.id.cpu().numpy() else: track_id = [-1] * len(boxes) boxes = np.hstack([boxes, confs[:, None]]) find_right = False find_left = False for idx, box in enumerate(boxes): if track_id[idx] == -1: if handedness[[idx]] > 0: id = int(10000) else: id = int(5000) else: id = track_id[idx] subj = dict() subj['frame'] = t subj['det'] = True subj['det_box'] = boxes[[idx]] subj['det_handedness'] = handedness[[idx]] if (not find_right and handedness[[idx]] > 0) or (not find_left and handedness[[idx]]==0): if id in tracks: tracks[id].append(subj) else: tracks[id] = [subj] if handedness[[idx]] > 0: find_right = True elif handedness[[idx]] == 0: find_left = True return boxes_, tracks # Adapted from https://github.com/ThunderVVV/HaWoR/blob/main/scripts/scripts_test_video/hawor_video.py def hawor_motion_estimation( imgfiles: list, tracks: dict, model: HAWOR, img_focal: float, single_image: bool = False ) -> dict: """ Performs HAWOR 3D hand reconstruction on detected and tracked hand regions. Args: imgfiles (list): List of image frames. tracks (dict): Dictionary mapping track ID to a list of detection objects. model (HAWOR): The initialized HAWOR model. img_focal (float): Camera focal length. single_image (bool): Flag for single-image processing mode. Returns: dict: Reconstructed parameters ('left' and 'right' hand results). 
""" left_results = {} right_results = {} tid = np.array([tr for tr in tracks]) left_trk = [] right_trk = [] for k, idx in enumerate(tid): trk = tracks[idx] valid = np.array([t['det'] for t in trk]) is_right = np.concatenate([t['det_handedness'] for t in trk])[valid] if is_right.sum() / len(is_right) < 0.5: left_trk.extend(trk) else: right_trk.extend(trk) left_trk = sorted(left_trk, key=lambda x: x['frame']) right_trk = sorted(right_trk, key=lambda x: x['frame']) final_tracks = { 0: left_trk, 1: right_trk } tid = [0, 1] img = imgfiles[0] img_center = [img.shape[1] / 2, img.shape[0] / 2]# w/2, h/2 H, W = img.shape[:2] for idx in tid: print(f"tracklet {idx}:") trk = final_tracks[idx] # interp bboxes valid = np.array([t['det'] for t in trk]) if not single_image: if valid.sum() < 2: continue else: if valid.sum() < 1: continue boxes = np.concatenate([t['det_box'] for t in trk]) non_zero_indices = np.where(np.any(boxes != 0, axis=1))[0] first_non_zero = non_zero_indices[0] last_non_zero = non_zero_indices[-1] boxes[first_non_zero:last_non_zero+1] = interpolate_bboxes(boxes[first_non_zero:last_non_zero+1]) valid[first_non_zero:last_non_zero+1] = True boxes = boxes[first_non_zero:last_non_zero+1] is_right = np.concatenate([t['det_handedness'] for t in trk])[valid] frame = np.array([t['frame'] for t in trk])[valid] if is_right.sum() / len(is_right) < 0.5: is_right = np.zeros((len(boxes), 1)) else: is_right = np.ones((len(boxes), 1)) frame_chunks, boxes_chunks = parse_chunks(frame, boxes, min_len=1) if len(frame_chunks) == 0: continue for frame_ck, boxes_ck in zip(frame_chunks, boxes_chunks): print(f"inference from frame {frame_ck[0]} to {frame_ck[-1]}") img_ck = [imgfiles[i] for i in frame_ck] if is_right[0] > 0: do_flip = False else: do_flip = True results = model.inference(img_ck, boxes_ck, img_focal=img_focal, img_center=img_center, do_flip=do_flip) data_out = { "init_root_orient": results["pred_rotmat"][None, :, 0], # (B, T, 3, 3) "init_hand_pose": results["pred_rotmat"][None, :, 1:], # (B, T, 15, 3, 3) "init_trans": results["pred_trans"][None, :, 0], # (B, T, 3) "init_betas": results["pred_shape"][None, :] # (B, T, 10) } # flip left hand init_root = rotation_matrix_to_angle_axis(data_out["init_root_orient"]) init_hand_pose = rotation_matrix_to_angle_axis(data_out["init_hand_pose"]) if do_flip: init_root[..., 1] *= -1 init_root[..., 2] *= -1 data_out["init_root_orient"] = angle_axis_to_rotation_matrix(init_root) data_out["init_hand_pose"] = angle_axis_to_rotation_matrix(init_hand_pose) s_frame = frame_ck[0] e_frame = frame_ck[-1] for frame_id in range(s_frame, e_frame+1): result = {} result['beta'] = data_out['init_betas'][0, frame_id-s_frame].cpu().numpy() result['hand_pose'] = data_out['init_hand_pose'][0, frame_id-s_frame].cpu().numpy() result['global_orient'] = data_out['init_root_orient'][0, frame_id-s_frame].cpu().numpy() result['transl'] = data_out['init_trans'][0, frame_id-s_frame].cpu().numpy() if idx == 0: left_results[frame_id] = result else: right_results[frame_id] = result reformat_results = {'left': left_results, 'right': right_results} return reformat_results