Simon9 commited on
Commit
421b656
·
verified ·
1 Parent(s): d31a3f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -50,6 +50,15 @@ CLIENT = InferenceHTTPClient(
50
  PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/11"
51
  FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
52
 
 
 
 
 
 
 
 
 
 
53
  # ==============================================
54
  # SIGLIP MODEL (Embeddings)
55
  # ==============================================
@@ -535,8 +544,7 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
535
  break
536
 
537
  if frame_count % STRIDE == 0:
538
- result = CLIENT.infer(frame, model_id=PLAYER_DETECTION_MODEL_ID)
539
- detections = sv.Detections.from_inference(result)
540
  detections = detections.with_nms(threshold=0.5, class_agnostic=True)
541
  players_detections = detections[detections.class_id == PLAYER_ID]
542
 
@@ -587,8 +595,7 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
587
  desc=f"🎬 Processing frame {frame_count}/{total_frames}")
588
 
589
  # Player and ball detection
590
- result = CLIENT.infer(frame, model_id=PLAYER_DETECTION_MODEL_ID, confidence=0.3)
591
- detections = sv.Detections.from_inference(result)
592
 
593
  if len(detections.xyxy) == 0:
594
  out.write(frame)
@@ -653,7 +660,7 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
653
  # STEP 4: Field Detection & Transformation
654
  # ========================================
655
  try:
656
- result_field = CLIENT.infer(frame, model_id=FIELD_DETECTION_MODEL_ID, confidence=0.3)
657
  key_points = sv.KeyPoints.from_inference(result_field)
658
 
659
  # Filter confident keypoints
 
50
  PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/11"
51
  FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
52
 
53
+ def infer_with_confidence(model_id: str, frame: np.ndarray, confidence_threshold: float = 0.3):
54
+ """Run inference and filter by confidence threshold"""
55
+ result = CLIENT.infer(frame, model_id=model_id)
56
+ detections = sv.Detections.from_inference(result)
57
+ # Filter by confidence
58
+ if len(detections) > 0:
59
+ detections = detections[detections.confidence > confidence_threshold]
60
+ return result, detections
61
+
62
  # ==============================================
63
  # SIGLIP MODEL (Embeddings)
64
  # ==============================================
 
544
  break
545
 
546
  if frame_count % STRIDE == 0:
547
+ _, detections = infer_with_confidence(PLAYER_DETECTION_MODEL_ID, frame, 0.3)
 
548
  detections = detections.with_nms(threshold=0.5, class_agnostic=True)
549
  players_detections = detections[detections.class_id == PLAYER_ID]
550
 
 
595
  desc=f"🎬 Processing frame {frame_count}/{total_frames}")
596
 
597
  # Player and ball detection
598
+ _, detections = infer_with_confidence(PLAYER_DETECTION_MODEL_ID, frame, 0.3)
 
599
 
600
  if len(detections.xyxy) == 0:
601
  out.write(frame)
 
660
  # STEP 4: Field Detection & Transformation
661
  # ========================================
662
  try:
663
+ result_field, _ = infer_with_confidence(FIELD_DETECTION_MODEL_ID, frame, 0.3)
664
  key_points = sv.KeyPoints.from_inference(result_field)
665
 
666
  # Filter confident keypoints