"""ViTPose model loading utilities."""

import logging

import torch
from transformers import AutoProcessor, VitPoseForPoseEstimation
from transformers.models.vitpose.image_processing_vitpose import VitPoseImageProcessor

logger = logging.getLogger(__name__)

# Default HuggingFace Hub checkpoint loaded by `load_model`.
MODEL_ID = "usyd-community/vitpose-base-simple"

# Color palette (20 three-channel colors) used when drawing COCO-17 poses.
# Entries are referenced by index from LINK_COLOR_INDICES and
# KEYPOINT_COLOR_INDICES below.
# NOTE(review): channel order (RGB vs BGR) depends on the drawing code — confirm.
PALETTE = [
    [255, 128, 0],
    [255, 153, 51],
    [255, 178, 102],
    [230, 230, 0],
    [255, 153, 255],
    [153, 204, 255],
    [255, 102, 255],
    [255, 51, 255],
    [102, 178, 255],
    [51, 153, 255],
    [255, 153, 153],
    [255, 102, 102],
    [255, 51, 51],
    [153, 255, 153],
    [102, 255, 102],
    [51, 255, 51],
    [0, 255, 0],
    [0, 0, 255],
    [255, 0, 0],
    [255, 255, 255],
]

# Color indices for limb connections (maps to PALETTE); one entry per link
# (19 entries — presumably the standard COCO skeleton order, verify in caller).
LINK_COLOR_INDICES = [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]

# Color indices for keypoints (maps to PALETTE); one entry per COCO keypoint (17).
KEYPOINT_COLOR_INDICES = [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]

# Drawing parameters. Presumably keypoints scoring below the threshold are
# skipped when rendering — confirm against the visualization code.
KEYPOINT_SCORE_THRESHOLD = 0.3
KEYPOINT_RADIUS = 4
LIMB_THICKNESS = 2


def _resolve_device(device: str) -> str:
    """Return the device to actually use for *device*.

    A CUDA device request falls back to ``"cpu"`` (with a warning) when CUDA
    is unavailable. Any other device (``"cpu"``, ``"mps"``, ...) is returned
    unchanged, so non-CUDA accelerators are not silently downgraded.
    """
    # torch.device normalizes forms like "cuda", "cuda:1" to type "cuda".
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        logger.warning(
            "CUDA not available, falling back to CPU (requested: %s)", device
        )
        return "cpu"
    return device


def load_model(
    device: str = "cuda",
    model_id: str = MODEL_ID,
) -> tuple[VitPoseForPoseEstimation, VitPoseImageProcessor]:
    """Load ViTPose model and processor from HuggingFace.

    Args:
        device: Device to load the model on (e.g. "cuda", "cuda:1", "cpu").
            A CUDA request falls back to CPU when CUDA is unavailable; other
            device types are used as given.
        model_id: HuggingFace Hub checkpoint to load. Defaults to MODEL_ID.

    Returns:
        Tuple of (model, processor). The model is moved to the resolved
        device and put in eval mode.
    """
    resolved_device = _resolve_device(device)

    logger.info("Loading ViTPose model from %s on %s", model_id, resolved_device)

    processor: VitPoseImageProcessor = AutoProcessor.from_pretrained(model_id)
    model: VitPoseForPoseEstimation = VitPoseForPoseEstimation.from_pretrained(
        model_id
    )
    model = model.to(resolved_device)
    # Inference only: disable dropout / use running norm statistics.
    model.eval()

    logger.info("Model loaded successfully")
    return model, processor