EiMon724 committed on
Commit
1365e57
·
verified ·
1 Parent(s): 45f7cd2

scorevision: push artifact

Browse files
Files changed (1) hide show
  1. miner.py +175 -0
miner.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ # NOTE:
4
+ # - This is copied from `example_miner/miner.py` as a starting point.
5
+ # - This version shows how to use a SAM-style segmentation model as your detector.
6
+ # - SAM gives masks (segmentation). This subnet expects boxes, so we convert masks -> boxes.
7
+ # - SAM does NOT give 32 pitch keypoints; you likely need a separate keypoint model.
8
+
9
+ import os
10
+ from typing import Any
11
+
12
+ import cv2
13
+ import numpy as np
14
+ from numpy import ndarray
15
+ from pydantic import BaseModel
16
+
17
+
18
class BoundingBox(BaseModel):
    """Axis-aligned detection box in pixel coordinates (x1, y1) top-left, (x2, y2) bottom-right."""

    x1: int
    y1: int
    x2: int
    y2: int
    # Detection class id; predict_batch currently emits 2 ("player") for every mask.
    cls_id: int
    # Confidence in [0, 1]; predict_batch fills this from SAM's predicted_iou (0.5 fallback).
    conf: float
25
+
26
+
27
class TVFrameResult(BaseModel):
    """Per-frame prediction payload: detection boxes plus pitch keypoints."""

    # Absolute frame index (batch offset + position within the batch).
    frame_id: int
    # All detections kept for this frame after area filtering.
    boxes: list[BoundingBox]
    # (x, y) pixel coordinates; predict_batch emits exactly n_keypoints entries.
    keypoints: list[tuple[int, int]]
31
+
32
+
33
class Miner:
    """
    Your miner engine.

    Requirements (must keep):
    - file name: `miner.py` (repo root)
    - class name: `Miner`
    - method: `predict_batch(batch_images, offset, n_keypoints)`
    """

    def __init__(self, path_hf_repo: Path) -> None:
        """
        Load your models from the HuggingFace repo snapshot directory.

        For SAM-based detection:
        - Put your SAM checkpoint file in this repo folder (same folder as miner.py)
        - Set SAM_CHECKPOINT env var (optional) to choose the filename.

        Raises:
            FileNotFoundError: if the SAM checkpoint is not present in the repo folder.
            ImportError: if the segment-anything package is not installed.
        """
        self.path_hf_repo = path_hf_repo

        # ---------------- SAM setup ----------------
        # IMPORTANT: "SAM 3" can mean different things. This skeleton uses the common
        # Segment Anything API shape (sam_model_registry + SamAutomaticMaskGenerator).
        # If your SAM3 is different, keep the structure and replace the loading/inference.
        ckpt_name = os.getenv("SAM_CHECKPOINT", "sam_vit_h_4b8939.pth")
        ckpt_path = (path_hf_repo / ckpt_name).resolve()
        if not ckpt_path.is_file():
            raise FileNotFoundError(
                f"SAM checkpoint not found: {ckpt_path}. Put the checkpoint in your HF repo "
                f"and/or set SAM_CHECKPOINT to the correct filename."
            )

        model_type = os.getenv("SAM_MODEL_TYPE", "vit_h")  # vit_h / vit_l / vit_b (depends on checkpoint)
        try:
            from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
        except Exception as e:
            raise ImportError(
                "segment-anything is not installed in the Chutes image. "
                "Add it to chute_config.yml (pip install segment-anything)."
            ) from e

        # FIX: checking CUDA_VISIBLE_DEVICES alone is unreliable — the variable is often
        # unset on GPU machines (we would wrongly pick cpu) and may be set to "" precisely
        # to *disable* GPUs. Ask torch directly; torch is guaranteed importable here
        # because segment-anything depends on it.
        import torch

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.sam = sam_model_registry[model_type](checkpoint=str(ckpt_path))
        self.sam.to(device=device)

        # Tunables: lower points_per_side => faster, fewer masks.
        self.mask_generator = SamAutomaticMaskGenerator(
            self.sam,
            points_per_side=int(os.getenv("SAM_POINTS_PER_SIDE", "16")),
            pred_iou_thresh=float(os.getenv("SAM_PRED_IOU_THRESH", "0.88")),
            stability_score_thresh=float(os.getenv("SAM_STABILITY_THRESH", "0.90")),
            min_mask_region_area=int(os.getenv("SAM_MIN_REGION_AREA", "200")),
        )

        # ---------------- Keypoints ----------------
        # Placeholder: output all zeros unless you add a keypoint detector.
        self.enable_keypoints = os.getenv("ENABLE_KEYPOINTS", "0").lower() in ("1", "true", "yes")
        self._kp_model: Any | None = None
        # If you have a keypoint model, load it here from path_hf_repo.

    def __repr__(self) -> str:
        return (
            f"SAM: {type(self.sam).__name__}\n"
            f"Keypoints enabled: {self.enable_keypoints}"
        )

    def predict_batch(
        self,
        batch_images: list[ndarray],
        offset: int,
        n_keypoints: int,
    ) -> list[TVFrameResult]:
        """
        Run SAM-based detection on a batch of frames.

        Args:
            batch_images: BGR (OpenCV-style) frames; a None entry yields an empty box list.
            offset: frame id of the first image in the batch.
            n_keypoints: number of pitch keypoints expected per frame (placeholder zeros).

        Returns:
            One TVFrameResult per input image, frame ids offset .. offset+len(batch_images)-1.
        """
        # ------------------ Boxes (SAM masks -> boxes) ------------------
        # SAM returns masks for "things" but does not label them (player/ref/ball).
        # For a first working miner, we mark everything as "player" (cls_id=2).
        # To score well, you will later need classification (ball/ref/goalkeeper/team).

        # FIX: hoist the env-var reads + float conversions out of the per-mask loop —
        # they are invariant for the whole batch.
        min_box_area = float(os.getenv("MIN_BOX_AREA", "250"))
        max_box_area_frac = float(os.getenv("MAX_BOX_AREA_FRAC", "0.25"))

        bboxes: dict[int, list[BoundingBox]] = {}

        for i, img in enumerate(batch_images):
            frame_id = offset + i

            if img is None:
                bboxes[frame_id] = []
                continue
            # Convert BGR(OpenCV) -> RGB(SAM)
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            masks = self.mask_generator.generate(rgb)  # list[dict]

            # Filter out giant masks (often the grass/background) and tiny noise.
            H, W = rgb.shape[:2]
            area_frame = float(H * W)
            out_boxes: list[BoundingBox] = []
            for m in masks:
                # segment-anything returns bbox as [x, y, w, h]
                x, y, w, h = m.get("bbox") or (0, 0, 0, 0)
                x1, y1 = int(x), int(y)
                x2, y2 = int(x + w), int(y + h)
                if x2 <= x1 or y2 <= y1:
                    continue  # degenerate box

                box_area = float((x2 - x1) * (y2 - y1))
                if box_area < min_box_area:
                    continue  # tiny noise
                if box_area / area_frame > max_box_area_frac:
                    continue  # likely background (grass, stands)

                conf = float(m.get("predicted_iou") or 0.5)
                out_boxes.append(
                    BoundingBox(
                        x1=x1,
                        y1=y1,
                        x2=x2,
                        y2=y2,
                        cls_id=2,  # default: player
                        conf=conf,
                    )
                )

            bboxes[frame_id] = out_boxes

        # ---------------- Keypoints (length = n_keypoints) ----------------
        keypoints: dict[int, list[tuple[int, int]]] = {}
        # Placeholder (zeros). Replace with your own keypoint detector when ready.
        for i in range(len(batch_images)):
            frame_id = offset + i
            keypoints[frame_id] = [(0, 0) for _ in range(n_keypoints)]

        # ---------------- Combine ------------------
        results: list[TVFrameResult] = []
        for frame_number in range(offset, offset + len(batch_images)):
            results.append(
                TVFrameResult(
                    frame_id=frame_number,
                    boxes=bboxes.get(frame_number, []),
                    keypoints=keypoints.get(
                        frame_number, [(0, 0) for _ in range(n_keypoints)]
                    ),
                )
            )
        return results
174
+
175
+