Gege24 committed on
Commit
c46342c
·
verified ·
1 Parent(s): fc9d334

scorevision: push artifact

Browse files
Files changed (1) hide show
  1. miner.py +126 -0
miner.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from ultralytics import YOLO
4
+ from numpy import ndarray
5
+ from pydantic import BaseModel
6
+
7
+
8
class BoundingBox(BaseModel):
    """
    One detected object in a single frame.

    Instances are built in ``Miner.predict_batch`` from rows of the
    detection model output, which that method reads as
    ``x1, y1, x2, y2, conf, cls``.
    """

    # Box corner coordinates, truncated to integers by the producer.
    # Presumably (x1, y1) is the top-left and (x2, y2) the bottom-right
    # corner in pixel units -- TODO confirm against the consumer.
    x1: int
    y1: int
    x2: int
    y2: int
    # Class index reported by the detection model.
    cls_id: int
    # Confidence score reported by the detection model.
    conf: float
15
+
16
+
17
class TVFrameResult(BaseModel):
    """
    Combined prediction result for one frame.

    ``Miner.predict_batch`` returns one of these per input image.
    """

    # Frame identifier: the batch offset plus the image's index in the batch.
    frame_id: int
    # Detections for the frame; empty when nothing was detected.
    boxes: list[BoundingBox]
    # (x, y) keypoint pairs; the producer pads with (0, 0) entries so the
    # list has the requested number of keypoints.
    keypoints: list[tuple[int, int]]
21
+
22
+
23
class Miner:
    """
    This class is responsible for:
    - Loading ML models.
    - Running batched predictions on images.
    - Parsing ML model outputs into structured results (TVFrameResult).

    MODIFIED FOR TESTING: Uses standard yolov8n.pt and yolov8n-pose.pt
    """

    def __init__(self, path_hf_repo: Path) -> None:
        """
        Loads all ML models.

        Args:
            path_hf_repo: Path to the model repository. It is accepted to
                keep the constructor signature stable but is NOT used by
                this testing build, which loads stock checkpoints by name.
        """
        # Using standard YOLOv8 nano models that will be automatically
        # downloaded if not present. This avoids the need for custom .pt
        # files for testing.
        self.bbox_model = YOLO("yolov8n.pt")
        print("✅ BBox Model Loaded (yolov8n)")

        self.keypoints_model = YOLO("yolov8n-pose.pt")
        print("✅ Keypoints Model Loaded (yolov8n-pose)")

    def __repr__(self) -> str:
        """Return the class names of the two loaded model objects."""
        return (
            f"BBox Model: {type(self.bbox_model).__name__}\n"
            f"Keypoints Model: {type(self.keypoints_model).__name__}"
        )

    @staticmethod
    def _parse_boxes(detection) -> list[BoundingBox]:
        """Convert one detection result into a list of BoundingBox objects."""
        raw_boxes = getattr(detection, "boxes", None)
        if raw_boxes is None:
            return []
        # YOLOv8 standard output rows: x1, y1, x2, y2, conf, cls.
        # A single tolist() per frame avoids one conversion per box.
        return [
            BoundingBox(
                x1=int(x1),
                y1=int(y1),
                x2=int(x2),
                y2=int(y2),
                cls_id=int(cls_id),
                conf=float(conf),
            )
            for x1, y1, x2, y2, conf, cls_id in raw_boxes.data.tolist()
        ]

    @staticmethod
    def _parse_keypoints(detection, n_keypoints: int) -> list[tuple[int, int]]:
        """Return exactly ``n_keypoints`` (x, y) pairs for one pose result."""
        points: list[tuple[int, int]] = []
        raw_keypoints = getattr(detection, "keypoints", None)
        data = None if raw_keypoints is None else raw_keypoints.data
        if data is not None and len(data) > 0:
            # Taking the first person detected (simplification for testing).
            # YOLO pose output is typically [num_people, num_kpts, 3]
            # (x, y, conf); only x and y are kept.
            points = [(int(row[0]), int(row[1])) for row in data[0].tolist()]

        # Truncate, then pad with (0, 0), so the length is always n_keypoints.
        points = points[:n_keypoints]
        points.extend([(0, 0)] * (n_keypoints - len(points)))
        return points

    def predict_batch(
        self,
        batch_images: list[ndarray],
        offset: int,
        n_keypoints: int,
    ) -> list[TVFrameResult]:
        """
        Miner prediction for a batch of images.

        Args:
            batch_images: Images to run both models on.
            offset: Frame id assigned to the first image of the batch;
                image ``i`` gets frame id ``offset + i``.
            n_keypoints: Number of keypoints every result must contain.
                Shorter outputs are padded with ``(0, 0)``, longer ones are
                truncated. Negative values are treated as 0.

        Returns:
            One TVFrameResult per input image, in input order. An empty
            batch yields an empty list without invoking the models.
        """
        if len(batch_images) == 0:
            # Nothing to predict; skip calling the models on an empty source.
            return []

        # Guard against a negative count, which would otherwise slice
        # keypoints from the end and give inconsistent lengths per frame.
        n_keypoints = max(n_keypoints, 0)

        # Run BBox prediction
        bboxes: dict[int, list[BoundingBox]] = {}
        bbox_model_results = self.bbox_model.predict(batch_images, verbose=False)
        if bbox_model_results is not None:
            for frame_number_in_batch, detection in enumerate(bbox_model_results):
                bboxes[offset + frame_number_in_batch] = self._parse_boxes(detection)
        print("✅ BBoxes predicted")

        # Run Pose/Keypoints prediction
        keypoints: dict[int, list[tuple[int, int]]] = {}
        keypoints_model_results = self.keypoints_model.predict(
            batch_images, verbose=False
        )
        if keypoints_model_results is not None:
            for frame_number_in_batch, detection in enumerate(keypoints_model_results):
                keypoints[offset + frame_number_in_batch] = self._parse_keypoints(
                    detection, n_keypoints
                )
        print("✅ Keypoints predicted")

        # Frames missing from either model's output fall back to an empty
        # box list / all-zero keypoints so every image still gets a result.
        results: list[TVFrameResult] = []
        for frame_number in range(offset, offset + len(batch_images)):
            frame_keypoints = keypoints.get(frame_number)
            if frame_keypoints is None:
                frame_keypoints = [(0, 0)] * n_keypoints
            results.append(
                TVFrameResult(
                    frame_id=frame_number,
                    boxes=bboxes.get(frame_number, []),
                    keypoints=frame_keypoints,
                )
            )
        print("✅ Combined results as TVFrameResult")
        return results