|
|
from transformers import ( |
|
|
VitPoseForPoseEstimation, |
|
|
AutoProcessor, |
|
|
RTDetrForObjectDetection, |
|
|
) |
|
|
from PIL import Image |
|
|
import torch |
|
|
|
|
|
|
|
|
# Hugging Face Hub checkpoints for the two-stage pipeline: an RT-DETR object
# detector used to find people, and a ViTPose model that estimates keypoints
# inside each detected box.
DETECTOR_CHECKPOINT = "PekingU/rtdetr_r50vd_coco_o365"
POSE_CHECKPOINT = "usyd-community/vitpose-base-simple"

# Stage 1: detector processor and model, switched to evaluation mode.
det_proc = AutoProcessor.from_pretrained(DETECTOR_CHECKPOINT)
det_model = RTDetrForObjectDetection.from_pretrained(DETECTOR_CHECKPOINT).eval()

# Stage 2: pose-estimation processor and model, switched to evaluation mode.
pose_proc = AutoProcessor.from_pretrained(POSE_CHECKPOINT)
pose_model = VitPoseForPoseEstimation.from_pretrained(POSE_CHECKPOINT).eval()

# Use the GPU when one is visible; both models are placed on the same device
# so inputs only need to be moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
det_model.to(device)
pose_model.to(device)
|
|
|
|
|
|
|
|
def _detect_person_boxes(image, threshold):
    """Return the person boxes found in *image* as an (N, 4) numpy array in COCO (x, y, w, h) format."""
    det_inputs = det_proc(images=image, return_tensors="pt").to(device)
    # Inference only: without no_grad the forward pass records an autograd
    # graph and keeps activations alive for nothing.
    with torch.no_grad():
        det_outputs = det_model(**det_inputs)

    result = det_proc.post_process_object_detection(
        det_outputs,
        threshold=threshold,
        target_sizes=[(image.height, image.width)],
    )[0]

    # Label id 0 is treated as "person" (the COCO label map); confirm against
    # det_model.config.id2label if the checkpoint changes.
    boxes = result["boxes"][result["labels"] == 0]

    # post_process_object_detection yields VOC-style corners (x1, y1, x2, y2),
    # while the ViTPose processor expects COCO-style (x, y, width, height).
    boxes = boxes.cpu().numpy()
    boxes[:, 2] -= boxes[:, 0]
    boxes[:, 3] -= boxes[:, 1]
    return boxes


def predict(inputs: dict, *, threshold: float = 0.5) -> dict:
    """Detect people in an image and estimate a 2-D pose for each of them.

    Two-stage pipeline: the module-level RT-DETR detector finds boxes with
    label id 0, then the module-level ViTPose model estimates keypoints
    inside each of those boxes.

    Args:
        inputs: Mapping whose ``"image"`` key holds a ``PIL.Image.Image``.
        threshold: Minimum detector confidence for a box to be kept
            (keyword-only; defaults to the previously hard-coded 0.5).

    Returns:
        ``{"poses": poses}`` where ``poses`` is this image's entry from
        ``pose_proc.post_process_pose_estimation`` (one item per detected
        person), or an empty list when no person is detected.
    """
    image = inputs["image"]
    # The image processors expect 3-channel input; normalize RGBA,
    # grayscale and palette images up front.
    if image.mode != "RGB":
        image = image.convert("RGB")

    person_boxes = _detect_person_boxes(image, threshold)
    if len(person_boxes) == 0:
        # No people: skip the pose stage instead of feeding it an empty box list.
        return {"poses": []}

    pose_inputs = pose_proc(image, boxes=[person_boxes], return_tensors="pt").to(device)
    with torch.no_grad():
        pose_outputs = pose_model(**pose_inputs)

    poses = pose_proc.post_process_pose_estimation(pose_outputs, boxes=[person_boxes])
    return {"poses": poses[0]}
|
|
|