MTerryJack committed on
Commit
89bbcc2
·
verified ·
1 Parent(s): 4d99c28

Upload miner.py

Browse files
Files changed (1) hide show
  1. miner.py +139 -0
miner.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from ultralytics import YOLO
4
+ from numpy import ndarray
5
+ from pydantic import BaseModel
6
+
7
+
8
class BoundingBox(BaseModel):
    """A single detected object box in pixel coordinates with its class and confidence."""

    # Top-left corner of the box (pixels).
    x1: int
    y1: int
    # Bottom-right corner of the box (pixels).
    x2: int
    y2: int
    # Integer class id as emitted by the detection model.
    cls_id: int
    # Detection confidence score (presumably in [0, 1] — TODO confirm against the model).
    conf: float
15
+
16
class TVFrameResult(BaseModel):
    """Structured prediction for one video frame: its boxes and pitch keypoints."""

    # Absolute frame index within the video (batch offset already applied).
    frame_id: int
    # All object detections found in this frame.
    boxes: list[BoundingBox]
    # Flat list of (x, y) pixel keypoints for this frame.
    keypoints: list[tuple[int, int]]
20
+
21
+
22
class Miner:
    """
    This class is responsible for:
    - Loading ML models.
    - Running batched predictions on images.
    - Parsing ML model outputs into structured results (TVFrameResult).

    This class can be modified, but it must have the following to be compatible with the chute:
    - be named `Miner`
    - have a `predict_batch` function with the inputs and outputs specified
    - be stored in a file called `miner.py` which lives in the root of the HFHub repo
    """

    def __init__(self, path_hf_repo: Path) -> None:
        """
        Loads all ML models from the repository.
        -----(Adjust as needed)----

        Args:
            path_hf_repo (Path):
                Path to the downloaded HuggingFace Hub repository

        Returns:
            None
        """
        self.bbox_model = YOLO(path_hf_repo / "football-player-detection.pt")
        print("✅ BBox Model Loaded")
        self.keypoints_model = YOLO(path_hf_repo / "football-pitch-detection.pt")
        print("✅ Keypoints Model Loaded")

    def __repr__(self) -> str:
        """
        Information about miner returned in the health endpoint
        to inspect the loaded ML models (and their types)
        -----(Adjust as needed)----
        """
        return f"BBox Model: {type(self.bbox_model).__name__}\nKeypoints Model: {type(self.keypoints_model).__name__}"

    def _predict_bboxes(
        self,
        batch_images: list[ndarray],
        offset: int,
    ) -> dict[int, list[BoundingBox]]:
        """Run the box model on the batch and map absolute frame number -> BoundingBox list."""
        bboxes: dict[int, list[BoundingBox]] = {}
        bbox_model_results = self.bbox_model.predict(batch_images)
        if bbox_model_results is not None:
            for frame_number_in_batch, detection in enumerate(bbox_model_results):
                # Frames with no `boxes` attribute (or None) are skipped entirely;
                # predict_batch fills those frames with an empty box list.
                if not hasattr(detection, "boxes") or detection.boxes is None:
                    continue
                boxes = []
                # Ultralytics Boxes.data rows are (x1, y1, x2, y2, conf, cls).
                for box in detection.boxes.data:
                    x1, y1, x2, y2, conf, cls_id = box.tolist()
                    boxes.append(
                        BoundingBox(
                            x1=int(x1),
                            y1=int(y1),
                            x2=int(x2),
                            y2=int(y2),
                            cls_id=int(cls_id),
                            conf=float(conf),
                        )
                    )
                bboxes[offset + frame_number_in_batch] = boxes
        print("✅ BBoxes predicted")
        return bboxes

    def _predict_keypoints(
        self,
        batch_images: list[ndarray],
        offset: int,
        n_keypoints: int,
    ) -> dict[int, list[tuple[int, int]]]:
        """Run the pitch model on the batch and map absolute frame number -> exactly
        n_keypoints (x, y) tuples, padding with (0, 0) or truncating as needed."""
        # BUGFIX: the original annotated this as dict[int, tuple[int, int]],
        # but the stored values are lists of (x, y) tuples.
        keypoints: dict[int, list[tuple[int, int]]] = {}
        keypoints_model_results = self.keypoints_model.predict(batch_images)
        if keypoints_model_results is not None:
            for frame_number_in_batch, detection in enumerate(keypoints_model_results):
                if not hasattr(detection, "keypoints") or detection.keypoints is None:
                    continue
                frame_keypoints: list[tuple[int, int]] = []
                # Keypoints.data is (num_instances, num_points, 2 or 3) — the third
                # component, when present, is a visibility/confidence score. Index the
                # first two components instead of tuple-unpacking so (x, y, conf)
                # triples don't raise ValueError.
                for person_points in detection.keypoints.data:
                    for point in person_points:
                        frame_keypoints.append((int(point[0]), int(point[1])))
                # Normalise to exactly n_keypoints entries per frame.
                if len(frame_keypoints) < n_keypoints:
                    frame_keypoints.extend(
                        [(0, 0)] * (n_keypoints - len(frame_keypoints))
                    )
                else:
                    frame_keypoints = frame_keypoints[:n_keypoints]
                keypoints[offset + frame_number_in_batch] = frame_keypoints
        print("✅ Keypoints predicted")
        return keypoints

    def predict_batch(
        self,
        batch_images: list[ndarray],
        offset: int,
        n_keypoints: int,
    ) -> list[TVFrameResult]:
        """
        Miner prediction for a batch of images.
        Handles the orchestration of ML models and any preprocessing and postprocessing
        -----(Adjust as needed)----

        Args:
            batch_images (list[np.ndarray]):
                A list of images (as NumPy arrays) to process in this batch.
            offset (int):
                The frame number corresponding to the first image in the batch.
                Used to correctly index frames in the output results.
            n_keypoints (int):
                The number of keypoints expected for each frame in this challenge type.

        Returns:
            list[TVFrameResult]:
                A list of predictions for each image in the batch
        """
        bboxes = self._predict_bboxes(batch_images, offset)
        keypoints = self._predict_keypoints(batch_images, offset, n_keypoints)

        # Combine per-model outputs; frames a model skipped fall back to an
        # empty box list / all-zero keypoints so every frame gets a result.
        results: list[TVFrameResult] = []
        for frame_number in range(offset, offset + len(batch_images)):
            results.append(
                TVFrameResult(
                    frame_id=frame_number,
                    boxes=bboxes.get(frame_number, []),
                    keypoints=keypoints.get(
                        frame_number, [(0, 0) for _ in range(n_keypoints)]
                    ),
                )
            )
        print("✅ Combined results as TVFrameResult")
        return results