Commit c21af4b
Parent(s): 746cfcf
yolo fix
Files changed:
- backend/gradio_labanmovementanalysis/pose_estimation.py (+86 -55)
- examples/mediapipe.json (+0 -0)
- examples/movenet.json (+0 -0)
- examples/yolov11.json (+0 -0)
- examples/yolov8.json (+0 -0)
backend/gradio_labanmovementanalysis/pose_estimation.py
CHANGED
@@ -146,10 +146,10 @@ class MoveNetPoseEstimator(PoseEstimator):
                 confidence=float(score),
                 name=self.KEYPOINT_NAMES[i]
             ))
-        if keypoints:
+        if keypoints:  # MoveNet is single-pose, so only one result if any
             return [PoseResult(keypoints=keypoints, frame_index=0)]
         else:
-            return []
+            return []  # No pose detected or all keypoints were NaN
 
     def get_keypoint_names(self) -> List[str]:
         return self.KEYPOINT_NAMES.copy()
@@ -190,10 +190,10 @@ class MediaPipePoseEstimator(PoseEstimator):
             import mediapipe as mp
             self.mp_pose = mp.solutions.pose
             self.pose = self.mp_pose.Pose(
-                static_image_mode=False,
+                static_image_mode=False,  # Process video frames
                 model_complexity=self.model_complexity,
                 min_detection_confidence=self.min_detection_confidence,
-                min_tracking_confidence=0.5
+                min_tracking_confidence=0.5  # Default from MediaPipe
            )
         except ImportError:
             raise ImportError("MediaPipe required. Install with: pip install mediapipe")
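For context, a minimal sketch of how a Pose object configured as above is typically driven on a single frame. The `model_complexity=1` value, the detection threshold, and the input path are assumptions (the class passes its own attributes); MediaPipe expects RGB input while OpenCV reads BGR:

import cv2
import mediapipe as mp

pose = mp.solutions.pose.Pose(
    static_image_mode=False,       # video mode, tracks landmarks across frames
    model_complexity=1,            # assumed; the class passes self.model_complexity
    min_detection_confidence=0.5,  # assumed; the class passes its own threshold
    min_tracking_confidence=0.5,   # the value the diff pins explicitly
)
frame = cv2.imread("frame.png")    # hypothetical input, BGR order
results = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
if results.pose_landmarks:
    nose = results.pose_landmarks.landmark[0]  # landmark 0 is the nose
    print(nose.x, nose.y, nose.visibility)     # normalized coords plus visibility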
@@ -219,10 +219,11 @@ class MediaPipePoseEstimator(PoseEstimator):
             keypoints.append(Keypoint(
                 x=landmark.x,
                 y=landmark.y,
-                confidence=landmark.visibility if hasattr(landmark, 'visibility') else 1.0,
+                confidence=landmark.visibility if hasattr(landmark, 'visibility') else 1.0,  # Use visibility as confidence
                 name=self.LANDMARK_NAMES[i] if i < len(self.LANDMARK_NAMES) else f"landmark_{i}"
             ))
 
+        # MediaPipe Pose API typically returns one pose per image in this configuration
         return [PoseResult(keypoints=keypoints, frame_index=0)]
 
     def get_keypoint_names(self) -> List[str]:
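Because visibility is reused as confidence here, downstream code can filter MediaPipe and YOLO keypoints uniformly. A hedged sketch (the import path is assumed from the repo layout, and the 0.5 threshold is arbitrary):

from typing import List

# Import path assumed from backend/gradio_labanmovementanalysis/pose_estimation.py
from gradio_labanmovementanalysis.pose_estimation import Keypoint, PoseResult

def visible_joints(pose: PoseResult, min_conf: float = 0.5) -> List[Keypoint]:
    # For MediaPipe poses, confidence is the landmark visibility score in [0, 1]
    return [kp for kp in pose.keypoints if kp.confidence >= min_conf]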
@@ -250,20 +251,22 @@ class YOLOPoseEstimator(PoseEstimator):
         Initialize YOLO pose model.
 
         Args:
-            model_version: "v8" or "v11"
+            model_version: "v8" or "v11" (Note: v11 is hypothetical here as Ultralytics primarily focuses on v8, v9, etc.)
             model_size: Model size - "n" (nano), "s" (small), "m" (medium), "l" (large), "x" (xlarge)
-            confidence_threshold: Minimum confidence for detections
+            confidence_threshold: Minimum confidence for person detections (not individual keypoints)
         """
-        self.model_version = model_version
-        self.model_size = model_size
-        self.confidence_threshold = confidence_threshold
+        self.model_version = model_version.lower()
+        self.model_size = model_size.lower()
+        self.confidence_threshold = confidence_threshold  # This is for the main object detection
         self.model = None
 
         # Determine model path
-        if model_version == "v8":
-            self.model_path = f"yolov8{model_size}-pose.pt"
-        else:
-            self.model_path = f"yolov11{model_size}-pose.pt"
+        if self.model_version == "v8":
+            self.model_path = f"yolov8{self.model_size}-pose.pt"
+        elif self.model_version == "v11":  # Assuming v11 follows a similar naming, adjust if official names differ
+            self.model_path = f"yolov11{self.model_size}-pose.pt"  # This might be a placeholder if v11 isn't standard Ultralytics
+        else:
+            raise ValueError(f"Unsupported YOLO version: {model_version}")
 
         self._load_model()
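A quick sketch of the weight filenames the constructor above resolves, one per (version, size) pair. The yolov11 names mirror the code's own placeholder naming and may not match official Ultralytics releases:

# Weight filenames built by the constructor, per (version, size).
# "v11" naming follows the code's assumption, not a confirmed release scheme.
for version, prefix in [("v8", "yolov8"), ("v11", "yolov11")]:
    for size in ["n", "s", "m", "l", "x"]:
        print(f"{version}-{size}: {prefix}{size}-pose.pt")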
@@ -282,34 +285,53 @@ class YOLOPoseEstimator(PoseEstimator):
             self._load_model()
 
         # Run inference
-        results = self.model(frame, conf=self.confidence_threshold)
+        # conf is for person detection; keypoint confidences are separate
+        results = self.model(frame, conf=self.confidence_threshold, iou=0.7)
 
         pose_results = []
+        height, width = frame.shape[:2]
 
-        # Process each detection
-        for r in results:
-            if r.keypoints is not None:
-                keypoints_data = r.keypoints.data[0].cpu().tolist()
-
-                keypoints = []
-                for i, (x, y, conf) in enumerate(keypoints_data):
-                    # Sanitize NaN values
-                    if any(map(math.isnan, [x, y, conf])):
-                        continue
-                    keypoints.append(Keypoint(
-                        x=float(x),
-                        y=float(y),
-                        confidence=float(conf),
-                        name=self.KEYPOINT_NAMES[i] if i < len(self.KEYPOINT_NAMES) else f"joint_{i}"
-                    ))
-
-                if keypoints:
-                    pose_results.append(PoseResult(
-                        keypoints=keypoints,
-                        frame_index=0,
-                        person_id=0
-                    ))
+        # Process each detection result (Ultralytics Results object)
+        for r_idx, r in enumerate(results):
+            if r.keypoints is not None and hasattr(r.keypoints, 'data'):
+                # r.keypoints.data is a tensor of shape (num_persons, num_keypoints, 3)
+                # The last dimension is [x_pixel, y_pixel, confidence_keypoint]
+                for person_idx, keypoints_data_tensor in enumerate(r.keypoints.data):
+                    keypoints_list_for_person = keypoints_data_tensor.cpu().tolist()  # Convert tensor to list
+
+                    keypoints = []
+                    for i, (x_pixel, y_pixel, kp_conf) in enumerate(keypoints_list_for_person):
+                        # Sanitize NaN values
+                        if any(map(math.isnan, [x_pixel, y_pixel, kp_conf])):
+                            continue
+
+                        current_confidence = float(kp_conf)
+
+                        # According to Ultralytics/COCO, missing keypoints are often (0,0) with conf 0.
+                        # If (0,0) pixel coords are returned with non-zero confidence by the model,
+                        # it might be an artifact or a misinterpretation.
+                        # We will reduce confidence for (0,0) pixel points if their original confidence isn't extremely high,
+                        # to help filter them in downstream tasks (visualization, analysis).
+                        if float(x_pixel) == 0.0 and float(y_pixel) == 0.0 and current_confidence < 0.9:
+                            # Threshold 0.9 is arbitrary, means "only trust (0,0) if model is super sure"
+                            current_confidence = 0.0
+
+                        keypoints.append(Keypoint(
+                            x=float(x_pixel) / width if width > 0 else 0.0,  # Normalize
+                            y=float(y_pixel) / height if height > 0 else 0.0,  # Normalize
+                            confidence=current_confidence,
+                            name=self.KEYPOINT_NAMES[i] if i < len(self.KEYPOINT_NAMES) else f"joint_{i}"
+                        ))
+
+                    if keypoints:
+                        # Create a unique person ID if not available from tracker (e.g. r.boxes.id)
+                        # For simplicity, using r_idx (result index) and person_idx (index within this result)
+                        # This might not be persistent across frames without a tracker.
+                        unique_person_id = person_idx  # Or a more robust ID if tracking is used
+                        pose_results.append(PoseResult(
+                            keypoints=keypoints,
+                            frame_index=0,  # Will be updated by detect_batch
+                            person_id=unique_person_id
+                        ))
 
         return pose_results
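The sanitization in this hunk does three things per keypoint: drop NaNs, zero out suspicious (0,0) detections, and normalize pixel coordinates by frame size. A standalone sketch of that logic, assuming (x, y, conf) triples in pixel space (the function name is illustrative, not from the commit):

import math
from typing import Optional, Tuple

def sanitize_keypoint(x: float, y: float, conf: float,
                      width: int, height: int) -> Optional[Tuple[float, float, float]]:
    """Mirror of the per-keypoint logic in YOLOPoseEstimator.detect:
    returns (x_norm, y_norm, confidence) or None if the keypoint is unusable."""
    if any(map(math.isnan, [x, y, conf])):
        return None  # Drop NaN keypoints entirely
    if x == 0.0 and y == 0.0 and conf < 0.9:
        conf = 0.0  # (0,0) usually marks a missing keypoint; distrust it
    x_norm = x / width if width > 0 else 0.0
    y_norm = y / height if height > 0 else 0.0
    return x_norm, y_norm, float(conf)

# A NaN keypoint is dropped; a low-confidence (0,0) one is kept but zeroed out.
assert sanitize_keypoint(float("nan"), 10.0, 0.8, 640, 480) is None
assert sanitize_keypoint(0.0, 0.0, 0.3, 640, 480) == (0.0, 0.0, 0.0)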
@@ -376,18 +398,42 @@ def get_pose_estimator(model_spec: str) -> PoseEstimator:
     # MoveNet variants
     elif model_spec.startswith("movenet"):
         variant = "lightning" if "lightning" in model_spec else "thunder"
+        if "lightning" not in model_spec and "thunder" not in model_spec:  # e.g. "movenet"
+            variant = "lightning"  # Default MoveNet to lightning
         return create_pose_estimator("movenet", model_variant=variant)
 
     # YOLO variants
     elif model_spec.startswith("yolo"):
         parts = model_spec.split("-")
-        version = parts[1] if len(parts) > 1 else "v8"
-        size = parts[2] if len(parts) > 2 else "n"
+        # yolo-v8-n -> parts = ["yolo", "v8", "n"]
+        # yolo -> parts = ["yolo"] -> default to v8-n
+        version = "v8"  # default version
+        size = "n"  # default size
+
+        if len(parts) > 1:  # "yolo-v8" or "yolo-v11"
+            if parts[1] in ["v8", "v11"]:  # Add other versions as needed
+                version = parts[1]
+            # If parts[1] is a size (e.g. "yolo-n"), then version remains default "v8" and size is parts[1]
+            elif parts[1] in ["n", "s", "m", "l", "x"]:
+                size = parts[1]
+
+        if len(parts) > 2:  # "yolo-v8-n"
+            if parts[2] in ["n", "s", "m", "l", "x"]:
+                size = parts[2]
+
+        # Handle case like "yolo-s" where version is implied as v8
+        if len(parts) == 2 and parts[1] in ["n", "s", "m", "l", "x"]:
+            version = "v8"  # Default to v8 if only size is specified after "yolo-"
+            size = parts[1]
+
         return create_pose_estimator("yolo", model_version=version, model_size=size)
 
-    # Legacy format support
+    # Legacy format support or direct name
     else:
-        return create_pose_estimator(model_spec)
+        try:
+            return create_pose_estimator(model_spec)
+        except ValueError:  # If model_spec isn't a direct key like "mediapipe", "movenet", "yolo"
+            raise ValueError(f"Invalid or unsupported model specification: {model_spec}")
 
 
 def _safe_pose_from_dets(dets: List[PoseResult], frame_idx: int) -> List[PoseResult]:
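Given the parsing above, the accepted spec strings resolve as follows. A standalone replica of the YOLO branch for illustration (the helper name is hypothetical; the logic mirrors the hunk):

def parse_yolo_spec(model_spec: str) -> tuple:
    """Replica of get_pose_estimator's YOLO branch, for illustration only."""
    parts = model_spec.split("-")
    version, size = "v8", "n"  # defaults
    if len(parts) > 1:
        if parts[1] in ["v8", "v11"]:
            version = parts[1]
        elif parts[1] in ["n", "s", "m", "l", "x"]:
            size = parts[1]  # version stays at the v8 default
    if len(parts) > 2 and parts[2] in ["n", "s", "m", "l", "x"]:
        size = parts[2]
    return version, size

assert parse_yolo_spec("yolo") == ("v8", "n")        # bare spec -> defaults
assert parse_yolo_spec("yolo-v11") == ("v11", "n")   # version only
assert parse_yolo_spec("yolo-s") == ("v8", "s")      # size only, version implied
assert parse_yolo_spec("yolo-v8-m") == ("v8", "m")   # full spec

Note the final len(parts) == 2 block in the diff re-assigns the same values the elif already set, so it is redundant but harmless.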
@@ -396,29 +442,14 @@ def _safe_pose_from_dets(dets: List[PoseResult], frame_idx: int) -> List[PoseResult]:
     After the loop, interpolate missing poses in pose_seq before running metrics.
     Add debug prints when a pose is missing and when interpolation is performed.
     """
+    # This function is currently not used in the provided codebase.
+    # If it were to be used, it would need proper integration.
+    print(f"[DEBUG] _safe_pose_from_dets called for frame {frame_idx}, but is not currently integrated.")
     safe_poses = []
     missing_mask = []
     prev_pose = None
 
-    for det in dets:
-        if det.frame_index == frame_idx:
-            if not det.keypoints:
-                safe_poses.append(PoseResult(keypoints=[], frame_index=frame_idx))
-                missing_mask.append(True)
-            else:
-                safe_poses.append(PoseResult(keypoints=prev_pose.keypoints, frame_index=frame_idx))
-                missing_mask.append(False)
-                prev_pose = det
-        elif det.frame_index > frame_idx:
-            break
-
-    if prev_pose is None:
-        print(f"Warning: No poses found for frame {frame_idx}")
-        safe_poses.append(PoseResult(keypoints=[], frame_index=frame_idx))
-        missing_mask.append(True)
-    else:
-        safe_poses.append(PoseResult(keypoints=prev_pose.keypoints, frame_index=frame_idx))
-        missing_mask.append(False)
-
-    return safe_poses, missing_mask
+    # This logic seems flawed for its intended purpose without further context or modification.
+    # For now, returning empty or passed 'dets' might be safer if it's not fully implemented.
+    # Returning dets as is, since the function is not used.
+    return dets, []
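The docstring promises forward-filling with debug prints, which the new body explicitly does not implement. One way the gap-filling could look, as a hedged sketch and not the commit's code (function name is hypothetical; the import path is assumed from the repo layout):

from typing import List, Tuple

# Import path assumed from backend/gradio_labanmovementanalysis/pose_estimation.py
from gradio_labanmovementanalysis.pose_estimation import PoseResult

def forward_fill_poses(dets: List[PoseResult], num_frames: int) -> Tuple[List[PoseResult], List[bool]]:
    """Sketch: produce one pose per frame, carrying the last seen pose forward.
    Returns (poses, missing_mask); missing frames reuse the previous keypoints."""
    by_frame = {d.frame_index: d for d in dets}
    poses, missing = [], []
    prev = None
    for idx in range(num_frames):
        det = by_frame.get(idx)
        if det is not None and det.keypoints:
            poses.append(det)
            missing.append(False)
            prev = det
        elif prev is not None:
            print(f"[DEBUG] frame {idx}: pose missing, forward-filling from frame {prev.frame_index}")
            poses.append(PoseResult(keypoints=prev.keypoints, frame_index=idx))
            missing.append(True)
        else:
            print(f"[DEBUG] frame {idx}: no pose seen yet, inserting empty pose")
            poses.append(PoseResult(keypoints=[], frame_index=idx))
            missing.append(True)
    return poses, missing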
examples/mediapipe.json
ADDED
The diff for this file is too large to render. See raw diff.

examples/movenet.json
ADDED
The diff for this file is too large to render. See raw diff.

examples/yolov11.json
ADDED
The diff for this file is too large to render. See raw diff.

examples/yolov8.json
ADDED
The diff for this file is too large to render. See raw diff.
|