File size: 1,986 Bytes
7e1ca89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""ViTPose model loading utilities."""

import logging

import torch
from transformers import AutoProcessor, VitPoseForPoseEstimation
from transformers.models.vitpose.image_processing_vitpose import VitPoseImageProcessor

logger = logging.getLogger(__name__)

# HuggingFace Hub id of the ViTPose checkpoint loaded by load_model().
MODEL_ID = "usyd-community/vitpose-base-simple"

# RGB drawing palette (20 colors). Not one-color-per-keypoint: the two
# *_COLOR_INDICES lists below select colors from this table for the 17
# COCO keypoints and their limb connections.
PALETTE = [
    [255, 128, 0],
    [255, 153, 51],
    [255, 178, 102],
    [230, 230, 0],
    [255, 153, 255],
    [153, 204, 255],
    [255, 102, 255],
    [255, 51, 255],
    [102, 178, 255],
    [51, 153, 255],
    [255, 153, 153],
    [255, 102, 102],
    [255, 51, 51],
    [153, 255, 153],
    [102, 255, 102],
    [51, 255, 51],
    [0, 255, 0],
    [0, 0, 255],
    [255, 0, 0],
    [255, 255, 255],
]

# PALETTE index for each limb connection (19 entries — presumably one per
# COCO skeleton link; the skeleton itself is not defined in this file).
LINK_COLOR_INDICES = [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]

# PALETTE index for each of the 17 COCO keypoints.
KEYPOINT_COLOR_INDICES = [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]

# Minimum confidence score for a keypoint to be drawn/kept.
KEYPOINT_SCORE_THRESHOLD = 0.3
# Drawing sizes in pixels.
KEYPOINT_RADIUS = 4
LIMB_THICKNESS = 2


def load_model(
    device: str = "cuda",
) -> tuple[VitPoseForPoseEstimation, VitPoseImageProcessor]:
    """Load ViTPose model and processor from HuggingFace.

    Args:
        device: Device to load model on (e.g. "cuda", "cuda:0", "cpu").

    Returns:
        Tuple of (model, processor); the model is in eval mode on the
        resolved device.
    """
    # Fall back to CPU only when a CUDA device was actually requested but
    # CUDA is unavailable. (The previous check ran unconditionally, which
    # would also rewrite non-CUDA requests such as "mps" to "cpu" on any
    # machine without CUDA.)
    resolved_device = device
    if device.startswith("cuda") and not torch.cuda.is_available():
        resolved_device = "cpu"
        logger.warning(
            "CUDA not available, falling back to CPU (requested: %s)", device
        )

    logger.info("Loading ViTPose model from %s on %s", MODEL_ID, resolved_device)

    processor: VitPoseImageProcessor = AutoProcessor.from_pretrained(MODEL_ID)
    model: VitPoseForPoseEstimation = VitPoseForPoseEstimation.from_pretrained(MODEL_ID)
    model = model.to(resolved_device)
    model.eval()

    logger.info("Model loaded successfully")
    return model, processor