pose-estimation / src /model.py
yuto090612's picture
fix: remove monkey-patch, scipy now properly installed
e1b437d verified
"""ViTPose model loading utilities."""
import logging
import torch
from transformers import AutoProcessor, VitPoseForPoseEstimation
from transformers.models.vitpose.image_processing_vitpose import VitPoseImageProcessor
logger = logging.getLogger(__name__)
MODEL_ID = "usyd-community/vitpose-base-simple"
# COCO 17 keypoints color palette
PALETTE = [
[255, 128, 0],
[255, 153, 51],
[255, 178, 102],
[230, 230, 0],
[255, 153, 255],
[153, 204, 255],
[255, 102, 255],
[255, 51, 255],
[102, 178, 255],
[51, 153, 255],
[255, 153, 153],
[255, 102, 102],
[255, 51, 51],
[153, 255, 153],
[102, 255, 102],
[51, 255, 51],
[0, 255, 0],
[0, 0, 255],
[255, 0, 0],
[255, 255, 255],
]
# Color indices for limb connections (maps to PALETTE)
LINK_COLOR_INDICES = [0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]
# Color indices for keypoints (maps to PALETTE)
KEYPOINT_COLOR_INDICES = [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]
KEYPOINT_SCORE_THRESHOLD = 0.3
KEYPOINT_RADIUS = 4
LIMB_THICKNESS = 2
def load_model(
device: str = "cuda",
) -> tuple[VitPoseForPoseEstimation, VitPoseImageProcessor]:
"""Load ViTPose model and processor from HuggingFace.
Args:
device: Device to load model on ("cuda" or "cpu").
Returns:
Tuple of (model, processor).
"""
resolved_device = device if torch.cuda.is_available() else "cpu"
if resolved_device != device:
logger.warning(
"CUDA not available, falling back to CPU (requested: %s)", device
)
logger.info("Loading ViTPose model from %s on %s", MODEL_ID, resolved_device)
processor: VitPoseImageProcessor = AutoProcessor.from_pretrained(MODEL_ID)
model: VitPoseForPoseEstimation = VitPoseForPoseEstimation.from_pretrained(MODEL_ID)
model = model.to(resolved_device)
model.eval()
logger.info("Model loaded successfully")
return model, processor