BladeSzaSza committed on
Commit
c21af4b
·
1 Parent(s): 746cfcf
backend/gradio_labanmovementanalysis/pose_estimation.py CHANGED
@@ -146,10 +146,10 @@ class MoveNetPoseEstimator(PoseEstimator):
146
  confidence=float(score),
147
  name=self.KEYPOINT_NAMES[i]
148
  ))
149
- if keypoints:
150
  return [PoseResult(keypoints=keypoints, frame_index=0)]
151
  else:
152
- return []
153
 
154
  def get_keypoint_names(self) -> List[str]:
155
  return self.KEYPOINT_NAMES.copy()
@@ -190,10 +190,10 @@ class MediaPipePoseEstimator(PoseEstimator):
190
  import mediapipe as mp
191
  self.mp_pose = mp.solutions.pose
192
  self.pose = self.mp_pose.Pose(
193
- static_image_mode=False,
194
  model_complexity=self.model_complexity,
195
  min_detection_confidence=self.min_detection_confidence,
196
- min_tracking_confidence=0.5
197
  )
198
  except ImportError:
199
  raise ImportError("MediaPipe required. Install with: pip install mediapipe")
@@ -219,10 +219,11 @@ class MediaPipePoseEstimator(PoseEstimator):
219
  keypoints.append(Keypoint(
220
  x=landmark.x,
221
  y=landmark.y,
222
- confidence=landmark.visibility if hasattr(landmark, 'visibility') else 1.0,
223
  name=self.LANDMARK_NAMES[i] if i < len(self.LANDMARK_NAMES) else f"landmark_{i}"
224
  ))
225
 
 
226
  return [PoseResult(keypoints=keypoints, frame_index=0)]
227
 
228
  def get_keypoint_names(self) -> List[str]:
@@ -250,20 +251,22 @@ class YOLOPoseEstimator(PoseEstimator):
250
  Initialize YOLO pose model.
251
 
252
  Args:
253
- model_version: "v8" or "v11"
254
  model_size: Model size - "n" (nano), "s" (small), "m" (medium), "l" (large), "x" (xlarge)
255
- confidence_threshold: Minimum confidence for detections
256
  """
257
- self.model_version = model_version
258
- self.model_size = model_size
259
- self.confidence_threshold = confidence_threshold
260
  self.model = None
261
 
262
  # Determine model path
263
- if model_version == "v8":
264
- self.model_path = f"yolov8{model_size}-pose.pt"
265
- else: # v11
266
- self.model_path = f"yolo11{model_size}-pose.pt"
 
 
267
 
268
  self._load_model()
269
 
@@ -282,34 +285,53 @@ class YOLOPoseEstimator(PoseEstimator):
282
  self._load_model()
283
 
284
  # Run inference
285
- results = self.model(frame, conf=self.confidence_threshold, iou=0.7)
 
286
 
287
  pose_results = []
 
288
 
289
- # Process each detection
290
- for r in results:
291
- if r.keypoints is not None:
292
- for person_idx, keypoints_data in enumerate(r.keypoints.data):
293
- keypoints = []
 
 
294
 
295
- # YOLO returns keypoints as [x, y, conf]
296
- height, width = frame.shape[:2]
297
- for i, (x, y, conf) in enumerate(keypoints_data):
298
  # Sanitize NaN values
299
- if any(map(math.isnan, [x, y, conf])):
300
  continue
 
 
 
 
 
 
 
 
 
 
 
 
301
  keypoints.append(Keypoint(
302
- x=float(x) / width, # Normalize to 0-1
303
- y=float(y) / height, # Normalize to 0-1
304
- confidence=float(conf),
305
  name=self.KEYPOINT_NAMES[i] if i < len(self.KEYPOINT_NAMES) else f"joint_{i}"
306
  ))
307
 
308
  if keypoints:
 
 
 
 
309
  pose_results.append(PoseResult(
310
  keypoints=keypoints,
311
- frame_index=0,
312
- person_id=person_idx
313
  ))
314
 
315
  return pose_results
@@ -376,18 +398,42 @@ def get_pose_estimator(model_spec: str) -> PoseEstimator:
376
  # MoveNet variants
377
  elif model_spec.startswith("movenet"):
378
  variant = "lightning" if "lightning" in model_spec else "thunder"
 
 
379
  return create_pose_estimator("movenet", model_variant=variant)
380
 
381
  # YOLO variants
382
  elif model_spec.startswith("yolo"):
383
  parts = model_spec.split("-")
384
- version = "v8" if "v8" in model_spec else "v11"
385
- size = parts[-1] if len(parts) > 2 else "n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  return create_pose_estimator("yolo", model_version=version, model_size=size)
387
 
388
- # Legacy format support
389
  else:
390
- return create_pose_estimator(model_spec)
 
 
 
391
 
392
 
393
  def _safe_pose_from_dets(dets: List[PoseResult], frame_idx: int) -> List[PoseResult]:
@@ -396,29 +442,14 @@ def _safe_pose_from_dets(dets: List[PoseResult], frame_idx: int) -> List[PoseRes
396
  After the loop, interpolate missing poses in pose_seq before running metrics.
397
  Add debug prints when a pose is missing and when interpolation is performed.
398
  """
 
 
 
399
  safe_poses = []
400
  missing_mask = []
401
  prev_pose = None
402
 
403
- for det in dets:
404
- if det.frame_index == frame_idx:
405
- if prev_pose is None:
406
- print(f"Warning: No previous pose found for frame {frame_idx}")
407
- safe_poses.append(PoseResult(keypoints=[], frame_index=frame_idx))
408
- missing_mask.append(True)
409
- else:
410
- safe_poses.append(PoseResult(keypoints=prev_pose.keypoints, frame_index=frame_idx))
411
- missing_mask.append(False)
412
- prev_pose = det
413
- elif det.frame_index > frame_idx:
414
- break
415
-
416
- if prev_pose is None:
417
- print(f"Warning: No poses found for frame {frame_idx}")
418
- safe_poses.append(PoseResult(keypoints=[], frame_index=frame_idx))
419
- missing_mask.append(True)
420
- else:
421
- safe_poses.append(PoseResult(keypoints=prev_pose.keypoints, frame_index=frame_idx))
422
- missing_mask.append(False)
423
-
424
- return safe_poses, missing_mask
 
146
  confidence=float(score),
147
  name=self.KEYPOINT_NAMES[i]
148
  ))
149
+ if keypoints: # MoveNet is single-pose, so only one result if any
150
  return [PoseResult(keypoints=keypoints, frame_index=0)]
151
  else:
152
+ return [] # No pose detected or all keypoints were NaN
153
 
154
  def get_keypoint_names(self) -> List[str]:
155
  return self.KEYPOINT_NAMES.copy()
 
190
  import mediapipe as mp
191
  self.mp_pose = mp.solutions.pose
192
  self.pose = self.mp_pose.Pose(
193
+ static_image_mode=False, # Process video frames
194
  model_complexity=self.model_complexity,
195
  min_detection_confidence=self.min_detection_confidence,
196
+ min_tracking_confidence=0.5 # Default from MediaPipe
197
  )
198
  except ImportError:
199
  raise ImportError("MediaPipe required. Install with: pip install mediapipe")
 
219
  keypoints.append(Keypoint(
220
  x=landmark.x,
221
  y=landmark.y,
222
+ confidence=landmark.visibility if hasattr(landmark, 'visibility') else 1.0, # Use visibility as confidence
223
  name=self.LANDMARK_NAMES[i] if i < len(self.LANDMARK_NAMES) else f"landmark_{i}"
224
  ))
225
 
226
+ # MediaPipe Pose API typically returns one pose per image in this configuration
227
  return [PoseResult(keypoints=keypoints, frame_index=0)]
228
 
229
  def get_keypoint_names(self) -> List[str]:
 
251
  Initialize YOLO pose model.
252
 
253
  Args:
254
+ model_version: "v8" or "v11" (Note: v11 is hypothetical here as Ultralytics primarily focuses on v8, v9, etc.)
255
  model_size: Model size - "n" (nano), "s" (small), "m" (medium), "l" (large), "x" (xlarge)
256
+ confidence_threshold: Minimum confidence for person detections (not individual keypoints)
257
  """
258
+ self.model_version = model_version.lower()
259
+ self.model_size = model_size.lower()
260
+ self.confidence_threshold = confidence_threshold # This is for the main object detection
261
  self.model = None
262
 
263
  # Determine model path
264
+ if self.model_version == "v8":
265
+ self.model_path = f"yolov8{self.model_size}-pose.pt"
266
+ elif self.model_version == "v11": # Assuming v11 follows a similar naming, adjust if official names differ
267
+ self.model_path = f"yolov11{self.model_size}-pose.pt" # This might be a placeholder if v11 isn't standard Ultralytics
268
+ else:
269
+ raise ValueError(f"Unsupported YOLO version: {model_version}")
270
 
271
  self._load_model()
272
 
 
285
  self._load_model()
286
 
287
  # Run inference
288
+ # conf is for person detection; keypoint confidences are separate
289
+ results = self.model(frame, conf=self.confidence_threshold, iou=0.7)
290
 
291
  pose_results = []
292
+ height, width = frame.shape[:2]
293
 
294
+ # Process each detection result (Ultralytics Results object)
295
+ for r_idx, r in enumerate(results):
296
+ if r.keypoints is not None and hasattr(r.keypoints, 'data'):
297
+ # r.keypoints.data is a tensor of shape (num_persons, num_keypoints, 3)
298
+ # The last dimension is [x_pixel, y_pixel, confidence_keypoint]
299
+ for person_idx, keypoints_data_tensor in enumerate(r.keypoints.data):
300
+ keypoints_list_for_person = keypoints_data_tensor.cpu().tolist() # Convert tensor to list
301
 
302
+ keypoints = []
303
+ for i, (x_pixel, y_pixel, kp_conf) in enumerate(keypoints_list_for_person):
 
304
  # Sanitize NaN values
305
+ if any(map(math.isnan, [x_pixel, y_pixel, kp_conf])):
306
  continue
307
+
308
+ current_confidence = float(kp_conf)
309
+
310
+ # According to Ultralytics/COCO, missing keypoints are often (0,0) with conf 0.
311
+ # If (0,0) pixel coords are returned with non-zero confidence by the model,
312
+ # it might be an artifact or a misinterpretation.
313
+ # We will reduce confidence for (0,0) pixel points if their original confidence isn't extremely high,
314
+ # to help filter them in downstream tasks (visualization, analysis).
315
+ if float(x_pixel) == 0.0 and float(y_pixel) == 0.0 and current_confidence < 0.9:
316
+ # Threshold 0.9 is arbitrary, means "only trust (0,0) if model is super sure"
317
+ current_confidence = 0.0
318
+
319
  keypoints.append(Keypoint(
320
+ x=float(x_pixel) / width if width > 0 else 0.0, # Normalize
321
+ y=float(y_pixel) / height if height > 0 else 0.0, # Normalize
322
+ confidence=current_confidence,
323
  name=self.KEYPOINT_NAMES[i] if i < len(self.KEYPOINT_NAMES) else f"joint_{i}"
324
  ))
325
 
326
  if keypoints:
327
+ # Create a unique person ID if not available from tracker (e.g. r.boxes.id)
328
+ # For simplicity, using r_idx (result index) and person_idx (index within this result)
329
+ # This might not be persistent across frames without a tracker.
330
+ unique_person_id = person_idx # Or a more robust ID if tracking is used
331
  pose_results.append(PoseResult(
332
  keypoints=keypoints,
333
+ frame_index=0, # Will be updated by detect_batch
334
+ person_id=unique_person_id
335
  ))
336
 
337
  return pose_results
 
398
  # MoveNet variants
399
  elif model_spec.startswith("movenet"):
400
  variant = "lightning" if "lightning" in model_spec else "thunder"
401
+ if "lightning" not in model_spec and "thunder" not in model_spec: # e.g. "movenet"
402
+ variant = "lightning" # Default MoveNet to lightning
403
  return create_pose_estimator("movenet", model_variant=variant)
404
 
405
  # YOLO variants
406
  elif model_spec.startswith("yolo"):
407
  parts = model_spec.split("-")
408
+ # yolo-v8-n -> parts = ["yolo", "v8", "n"]
409
+ # yolo -> parts = ["yolo"] -> default to v8-n
410
+ version = "v8" # default version
411
+ size = "n" # default size
412
+
413
+ if len(parts) > 1: # "yolo-v8" or "yolo-v11"
414
+ if parts[1] in ["v8", "v11"]: # Add other versions as needed
415
+ version = parts[1]
416
+ # If parts[1] is a size (e.g. "yolo-n"), then version remains default "v8" and size is parts[1]
417
+ elif parts[1] in ["n", "s", "m", "l", "x"]:
418
+ size = parts[1]
419
+
420
+ if len(parts) > 2: # "yolo-v8-n"
421
+ if parts[2] in ["n", "s", "m", "l", "x"]:
422
+ size = parts[2]
423
+
424
+ # Handle case like "yolo-s" where version is implied as v8
425
+ if len(parts) == 2 and parts[1] in ["n","s","m","l","x"]:
426
+ version = "v8" # Default to v8 if only size is specified after "yolo-"
427
+ size = parts[1]
428
+
429
  return create_pose_estimator("yolo", model_version=version, model_size=size)
430
 
431
+ # Legacy format support or direct name
432
  else:
433
+ try:
434
+ return create_pose_estimator(model_spec)
435
+ except ValueError: # If model_spec isn't a direct key like "mediapipe", "movenet", "yolo"
436
+ raise ValueError(f"Invalid or unsupported model specification: {model_spec}")
437
 
438
 
439
  def _safe_pose_from_dets(dets: List[PoseResult], frame_idx: int) -> List[PoseResult]:
 
442
  After the loop, interpolate missing poses in pose_seq before running metrics.
443
  Add debug prints when a pose is missing and when interpolation is performed.
444
  """
445
+ # This function is currently not used in the provided codebase.
446
+ # If it were to be used, it would need proper integration.
447
+ print(f"[DEBUG] _safe_pose_from_dets called for frame {frame_idx}, but is not currently integrated.")
448
  safe_poses = []
449
  missing_mask = []
450
  prev_pose = None
451
 
452
+ # This logic seems flawed for its intended purpose without further context or modification.
453
+ # For now, returning empty or passed 'dets' might be safer if it's not fully implemented.
454
+ # Returning dets as is, since the function is not used.
455
+ return dets, []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/mediapipe.json ADDED
The diff for this file is too large to render. See raw diff
 
examples/movenet.json ADDED
The diff for this file is too large to render. See raw diff
 
examples/yolov11.json ADDED
The diff for this file is too large to render. See raw diff
 
examples/yolov8.json ADDED
The diff for this file is too large to render. See raw diff