poolay2 committed
Commit 68a34d4 · verified · 1 parent: 6eb0dec

Delete utils.py

Files changed (1)
  1. utils.py +0 -347
utils.py DELETED
@@ -1,347 +0,0 @@
- from __future__ import annotations
- import torch
- import numpy as np
- import supervision as sv
- from pycocotools import mask as mask_utils
- import cv2
- import ffmpeg
- from PIL import Image
- from collections.abc import Iterable
- from matplotlib import pyplot as plt
-
- class SAM2Tracker:
-     """Track objects through a video with a streaming SAM2 predictor:
-     prompt once with first-frame boxes, then propagate frame by frame."""
-
-     def __init__(self, predictor) -> None:
-         self.predictor = predictor
-         self._prompted = False
-
-     def prompt_first_frame(self, frame: np.ndarray, detections: sv.Detections) -> None:
-         if len(detections) == 0:
-             raise ValueError("detections must contain at least one box")
-
-         if detections.tracker_id is None:
-             detections.tracker_id = np.arange(1, len(detections) + 1)
-
-         with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
-             self.predictor.load_first_frame(frame)
-             for xyxy, obj_id in zip(detections.xyxy, detections.tracker_id):
-                 bbox = np.asarray([xyxy], dtype=np.float32)
-                 self.predictor.add_new_prompt(
-                     frame_idx=0,
-                     obj_id=int(obj_id),
-                     bbox=bbox,
-                 )
-
-         self._prompted = True
-
-     def propagate(self, frame: np.ndarray) -> sv.Detections:
-         if not self._prompted:
-             raise RuntimeError("Call prompt_first_frame before propagate")
-
-         with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
-             tracker_ids, mask_logits = self.predictor.track(frame)
-
-         tracker_ids = np.asarray(tracker_ids, dtype=np.int32)
-         masks = (mask_logits > 0.0).cpu().numpy()
-         masks = np.squeeze(masks).astype(bool)
-
-         # np.squeeze drops the leading axis when only one object is tracked
-         if masks.ndim == 2:
-             masks = masks[None, ...]
-
-         # Filter out stray mask segments far from each mask's main region
-         masks = np.array([
-             sv.filter_segments_by_distance(mask, relative_distance=0.03, mode="edge")
-             for mask in masks
-         ])
-
-         xyxy = sv.mask_to_xyxy(masks=masks)
-         return sv.Detections(xyxy=xyxy, mask=masks, tracker_id=tracker_ids)
-
-     def reset(self) -> None:
-         self._prompted = False
-
- def get_crops_from_masks(frame: np.ndarray, masks: np.ndarray) -> list[np.ndarray]:
-     """
-     Args:
-         frame: (H, W, 3) image
-         masks: (N, H, W) binary masks
-
-     Returns:
-         List of cropped images, one per mask. Each crop is a rectangular
-         bounding box around the mask, with black pixels outside the mask.
-     """
-     crops = []
-
-     for mask in masks:
-         # Find the bounding box of the mask
-         ys, xs = np.where(mask)
-         if len(xs) == 0 or len(ys) == 0:
-             # Empty mask → append an empty crop so output stays aligned with input
-             crops.append(np.zeros((0, 0, 3), dtype=frame.dtype))
-             continue
-
-         y_min, y_max = ys.min(), ys.max() + 1
-         x_min, x_max = xs.min(), xs.max() + 1
-
-         # Crop the frame and mask
-         frame_crop = frame[y_min:y_max, x_min:x_max]
-         mask_crop = mask[y_min:y_max, x_min:x_max]
-
-         # Apply mask: keep pixels where mask is True, else black
-         crop = np.zeros_like(frame_crop)
-         crop[mask_crop] = frame_crop[mask_crop]
-
-         crops.append(crop)
-
-     return crops
-
- def f(detections: sv.Detections, track_history: dict, frame_index):
-     """Append each detection's RLE-encoded mask to its track's history.
-     track_history maps tracker_id -> list of (frame_index, rle_counts),
-     e.g. a collections.defaultdict(list)."""
-     for i in range(len(detections)):
-         mask = detections.mask[i]
-         # pycocotools expects a Fortran-contiguous uint8 array
-         rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
-         track_history[int(detections.tracker_id[i])].append((frame_index, rle['counts']))
-
-
- def toRGB(img: np.ndarray):
-     return cv2.cvtColor(img, code=cv2.COLOR_BGR2RGB)
-
- def read_frame_from_video(in_filename, frame_num):
-     # Decode a single frame; assumes a 1920x1080 source video
-     raw_bytes, err = (
-         ffmpeg
-         .input(in_filename)
-         .filter('select', 'gte(n,{})'.format(frame_num))
-         .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
-         .global_args('-loglevel', 'error')
-         .run(capture_stdout=True)
-     )
-     assert len(raw_bytes) == 1080 * 1920 * 3
-     return np.frombuffer(raw_bytes, np.uint8).reshape(1, 1080, 1920, 3).copy()
-
- def read_consecutive_frames_from_video(in_filename, start_frame, num_frames) -> np.ndarray:
-     out, err = ffmpeg.input(in_filename)\
-         .output(
-             'pipe:1',
-             vf=f'select=between(n\\,{start_frame}\\,{start_frame + num_frames - 1})',
-             vsync=0,
-             vframes=num_frames,
-             format='rawvideo',
-             pix_fmt='rgb24'
-         ).global_args('-loglevel', 'error')\
-         .run(capture_stdout=True, capture_stderr=True)
-
-     W, H = 1920, 1080  # assumes a 1080p source video
-     frame_size = W * H * 3
-     frames = np.frombuffer(out, np.uint8)
-
-     if frames.size != num_frames * frame_size:
-         raise RuntimeError(
-             f'Expected {num_frames * frame_size} bytes, got {frames.size}\n'
-             f'ffmpeg stderr:\n{err.decode()}'
-         )
-
-     # .copy() yields a writable array (np.frombuffer returns a read-only view)
-     return frames.reshape(num_frames, H, W, 3).copy()
-
- def xywhn_to_xywh(xywhn: list, height: int, width: int):
-     # Convert normalized [x, y, w, h] in [0, 1] to pixel coordinates
-     x, y, w, h = xywhn
-     return [int(x * width), int(y * height), int(w * width), int(h * height)]
-
- def crop_frame_at_mask_from_bbox(frame: np.ndarray, mask: np.ndarray, bbox: list) -> np.ndarray:
-     x, y, w, h = bbox
-     # Copy so that blacking out pixels below does not mutate the input frame
-     crop = frame[y: y + h, x: x + w].copy()
-     cropped_mask = mask[y: y + h, x: x + w]
-     crop[~cropped_mask] = np.array([0, 0, 0], dtype=np.uint8)
-
-     return crop
-
- def find_consecutive_streaks(nums: Iterable):
-     nums = list(nums)
-     if not nums:
-         return []
-
-     streaks = []
-     start = nums[0]
-     for i in range(1, len(nums)):
-         if nums[i] != nums[i - 1] + 1:
-             stop = nums[i - 1]
-             streaks.append(range(start, stop + 1))
-             start = nums[i]
-
-     streaks.append(range(start, nums[-1] + 1))
-     return streaks
-
- def save_loss_history(fpath, loss: float):
-     with open(fpath, "a") as f:
-         f.write(f"{loss:.6f}\n")
-
- def save_loss_history_plot(loss_history: list[float], fpath):
-     # Fresh figure each call so repeated saves do not draw over one another
-     plt.figure()
-     plt.plot(loss_history)
-     plt.savefig(fpath)
-     plt.close()
-
- def save_checkpoint(
-     path,
-     model,
-     optimizer,
-     epoch,
-     step,
- ):
-     ckpt = {
-         "model": model.state_dict(),
-         "optimizer": optimizer.state_dict(),
-         "epoch": epoch,
-         "step": step,
-     }
-     torch.save(ckpt, path)
-
- def load_checkpoint(
-     path,
-     model,
-     optimizer,
-     device="cuda"
- ):
-     ckpt = torch.load(path, map_location=device)
-
-     model.load_state_dict(ckpt["model"])
-     optimizer.load_state_dict(ckpt["optimizer"])
-
-     epoch = ckpt.get("epoch", 0)
-     step = ckpt.get("step", 0)
-
-     return epoch, step
-
- def mask_iou_pair(m1, m2):
-     inter = np.logical_and(m1, m2).sum()
-     if inter == 0:
-         return 0.0
-     union = m1.sum() + m2.sum() - inter
-     return inter / (union + 1e-6)
-
-
- def mask_nms(masks, scores, iou_thresh=0.6):
-     """Greedy non-maximum suppression over binary masks."""
-     order = np.argsort(-scores)
-     keep = []
-     suppressed = np.zeros(len(masks), dtype=bool)
-
-     for pos, i in enumerate(order):
-         if suppressed[i]:
-             continue
-
-         keep.append(i)
-
-         # Only masks ranked lower in score order can be suppressed; comparing
-         # raw indices here (as `j <= i` did) skips lower-scored masks by accident
-         for j in order[pos + 1:]:
-             if suppressed[j]:
-                 continue
-
-             if mask_iou_pair(masks[i], masks[j]) > iou_thresh:
-                 suppressed[j] = True
-
-     return keep
-
- def mask_iou(masks_t: np.ndarray, masks_t1):
-     # Flatten
-     N, H, W = masks_t.shape
-     M = masks_t1.shape[0]
-
-     masks_t = masks_t.reshape(N, -1).astype(float)    # (N, HW)
-     masks_t1 = masks_t1.reshape(M, -1).astype(float)  # (M, HW)
-
-     # Intersection: (N, M)
-     intersection = masks_t @ masks_t1.T
-
-     # Areas
-     area_t = masks_t.sum(1, keepdims=True)    # (N, 1)
-     area_t1 = masks_t1.sum(1, keepdims=True)  # (M, 1)
-
-     # Union
-     union = area_t + area_t1.T - intersection
-
-     iou = intersection / (union + 1e-6)
-     return iou  # (N, M)
-
- # Court keypoints as (x, y) positions; the ranges match a 94 x 50 ft NBA court
- COURT_KEYPOINT_COORDINATES = np.array([
-     (0.0, 0.0),
-     (0.0, 2.99),
-     (0.0, 17.0),
-     (0.0, 33.01),
-     (0.0, 47.02),
-     (0.0, 50.0),
-     (5.25, 25.0),
-     (13.92, 2.99),
-     (13.92, 47.02),
-     (19.0, 17.0),
-     (19.0, 25.0),
-     (19.0, 33.01),
-     (27.4, 0.0),
-     (29.01, 25.0),
-     (27.4, 50.0),
-     (46.99, 0.0),
-     (46.99, 25.0),
-     (46.99, 50.0),
-     (66.61, 0.0),
-     (65.0, 25.0),
-     (66.61, 50.0),
-     (75.0, 17.0),
-     (75.0, 25.0),
-     (75.0, 33.01),
-     (80.09, 2.99),
-     (80.09, 47.02),
-     (88.75, 25.0),
-     (94.0, 0.0),
-     (94.0, 2.99),
-     (94.0, 17.0),
-     (94.0, 33.01),
-     (94.0, 47.02),
-     (94.0, 50.0)
- ])
-
- def get_distance_cost_matrix(arr1: np.ndarray, arr2: np.ndarray, ord=1):
-     # Pairwise L^ord distances between rows of arr1 and rows of arr2
-     cost_matrix = np.empty(shape=(len(arr1), len(arr2)), dtype=np.float64)
-
-     for i in range(len(arr1)):
-         cost_matrix[i] = np.linalg.norm(arr1[i] - arr2, ord=ord, axis=-1)
-
-     return torch.tensor(cost_matrix)
-
- def matcher_probs_custom_argmax(probs: np.ndarray, confidence_threshold=0.7):
-     probs = probs.squeeze(0)
-     pred = probs.argmax()
-     # If the matcher predicts the null class but is not confident about it,
-     if pred == len(probs) - 1 and probs[pred] < confidence_threshold:
-         # fall back to the best non-null class when it carries enough weight
-         second_best = probs[:-1].argmax()
-         if probs[second_best] > 1.0 - confidence_threshold - 0.05:
-             pred = second_best
-
-     return pred
-
- def annotate_frame(frame_, detections_):
-     annotated_frame = frame_.copy()
-     annotated_frame = sv.MaskAnnotator(color_lookup=sv.ColorLookup.TRACK).annotate(annotated_frame, detections_)
-     annotated_frame = sv.LabelAnnotator(smart_position=True).annotate(
-         annotated_frame, detections_, labels=[str(i) for i in detections_.tracker_id]
-     )
-     return annotated_frame
-
- def show_annotations(frame_, detections_):
-     # Same annotation pipeline, returned as a PIL image for display
-     return Image.fromarray(annotate_frame(frame_, detections_))
-
- if __name__ == "__main__":
-     # Quick manual check: read one frame, then drop into an interactive shell
-     from code import interact
-     frames = read_consecutive_frames_from_video("nba_sample_videos/batch2/SAC_LAL_1.mp4", 199, 1)
-     interact(local=locals())
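
For reference, the sketch below (not part of the deleted file) shows how these utilities composed into a tracking loop: read a clip, prompt the tracker on the first frame, then propagate and annotate. It is a minimal sketch under stated assumptions: `build_predictor` and `detect_first_frame` are hypothetical stand-ins for whatever SAM2 predictor factory and first-frame detector the project used, and the video path and frame counts are illustrative.

# Hypothetical usage of the module above (before its deletion).
from utils import SAM2Tracker, read_consecutive_frames_from_video, annotate_frame

predictor = build_predictor()      # hypothetical: a SAM2 streaming predictor with
tracker = SAM2Tracker(predictor)   # load_first_frame / add_new_prompt / track

# Read a short clip (the readers above assume a 1080p source)
frames = read_consecutive_frames_from_video("video.mp4", 0, 30)

# Prompt once with first-frame boxes, then propagate through the rest
detections = detect_first_frame(frames[0])  # hypothetical: sv.Detections with xyxy boxes
tracker.prompt_first_frame(frames[0], detections)

annotated = []
for frame in frames[1:]:
    dets = tracker.propagate(frame)         # masks + stable tracker_id per object
    annotated.append(annotate_frame(frame, dets))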