Hussein El-Hadidy
fix path
29f3df9
from ultralytics import YOLO
from CPR_Module.Common.logging_config import cpr_logger
class PoseEstimator:
    """Human pose estimation using YOLO.

    Wraps an Ultralytics ``YOLO`` pose model and exposes helpers to run
    detection on a frame, pull out one person's keypoints, and render the
    detections back onto the frame. Every public method logs failures via
    ``cpr_logger`` and returns a fallback value (``None`` or the input frame)
    instead of raising.
    """

    def __init__(
        self,
        min_confidence,
        model_path="CPR_Module/Common/yolo11n-pose.pt",
        *,
        device="cuda",
        iou_threshold=0.2,
    ):
        """Load the pose model and move it to ``device``.

        Args:
            min_confidence: Detection confidence threshold, forwarded to the
                model as ``conf`` on every call to ``detect_poses``.
            model_path: Path to the YOLO pose weights.
            device: Device the model is moved to. Defaults to ``"cuda"``,
                which keeps the previous hard-coded behaviour; pass e.g.
                ``"cpu"`` on machines without a GPU (``"cuda"`` raises there).
            iou_threshold: IoU threshold forwarded to the model as ``iou``
                (defaults to the previously hard-coded 0.2).
        """
        self.device = device
        self.model = YOLO(model_path).to(device)

        # Check where the weights actually ended up, rather than trusting the
        # requested device string.
        on_cuda = next(self.model.model.parameters()).is_cuda
        wants_cuda = str(device).startswith("cuda")
        if on_cuda:
            cpr_logger.info("YOLO model loaded on CUDA (GPU).")
        elif wants_cuda:
            # CUDA was requested but the weights are elsewhere.
            cpr_logger.warning("YOLO model is not on CUDA. Check your setup.")
        else:
            # A non-CUDA device was explicitly requested; not a setup problem.
            cpr_logger.info(f"YOLO model loaded on {device}.")

        self.min_confidence = min_confidence
        self.iou_threshold = iou_threshold

    def detect_poses(self, frame):
        """Detect human poses in a frame.

        Returns:
            The first Ultralytics result object when at least one set of
            keypoints was found. Returns ``None`` when nothing was detected
            or when inference raised; the error is logged in that case.
        """
        try:
            results = self.model(
                frame,
                verbose=False,
                conf=self.min_confidence,
                show=False,
                iou=self.iou_threshold,
            )
            # One image in, so only results[0] is relevant. If the model
            # yields no keypoints (e.g. it is not a pose model), accessing
            # `.xy` raises and is reported by the handler below.
            if not results or len(results[0].keypoints.xy) == 0:
                return None
            return results[0]
        except Exception as e:
            cpr_logger.error(f"Pose detection error: {e}")
            return None

    def get_keypoints(self, results, person_idx=0):
        """Extract the keypoints of one detected person as a NumPy array.

        Args:
            results: A result object as returned by ``detect_poses`` (may be
                ``None``).
            person_idx: Index of the detected person to extract.

        Returns:
            ``results.keypoints.xy[person_idx]`` moved to CPU and converted
            with ``.numpy()``. Presumably these are that person's (x, y)
            keypoint coordinates; confirm against the Ultralytics
            ``Keypoints.xy`` docs. Returns ``None`` when ``results`` is
            empty/``None``, when ``person_idx`` is past the last detected
            person, or on error (logged).
        """
        try:
            if not results or len(results.keypoints.xy) <= person_idx:
                return None
            return results.keypoints.xy[person_idx].cpu().numpy()
        except Exception as e:
            cpr_logger.error(f"Keypoint extraction error: {e}")
            return None

    def draw_keypoints(self, frame, results):
        """Draw detected keypoints on a frame.

        Args:
            frame: The original frame; returned unchanged as a fallback.
            results: A result object from ``detect_poses``, or ``None`` when
                nothing was detected.

        Returns:
            The annotated image from ``results.plot()``. Returns ``frame``
            unchanged when ``results`` is ``None`` (the normal no-detection
            case, which is not logged) or when plotting fails (logged).
        """
        if results is None:
            # detect_poses returns None for "no poses". Without this guard,
            # calling .plot() here would log a bogus error on every empty
            # frame.
            return frame
        try:
            return results.plot()
        except Exception as e:
            cpr_logger.error(f"Keypoint drawing error: {e}")
            return frame