# Provenance note (Hugging Face upload metadata, commented out so the module parses):
#   poolay2's picture
#   Upload folder using huggingface_hub
#   bbc0514 verified
from __future__ import annotations
import torch
import numpy as np
import supervision as sv
from pycocotools import mask as mask_utils
import cv2
import ffmpeg
from PIL import Image
import numpy as np
from typing import List, Iterable
from matplotlib import pyplot as plt
class SAM2Tracker:
    """Tracks box-prompted objects across video frames with a SAM2 predictor.

    Usage: call `prompt_first_frame` once with detections for frame 0, then
    `propagate` for each subsequent frame. `reset` clears the prompted flag
    so a new sequence can be started.
    """

    def __init__(self, predictor) -> None:
        # `predictor` is expected to expose load_first_frame / add_new_prompt /
        # track (SAM2 camera-predictor style API) — not validated here.
        self.predictor = predictor
        self._prompted = False

    def prompt_first_frame(self, frame: np.ndarray, detections: sv.Detections) -> None:
        """Seed the tracker with one box prompt per detection on frame 0.

        Args:
            frame: (H, W, 3) first frame of the sequence.
            detections: supervision Detections with at least one box. If
                `tracker_id` is missing, ids 1..N are assigned in place.

        Raises:
            ValueError: if `detections` is empty.
        """
        if len(detections) == 0:
            raise ValueError("detections must contain at least one box")
        if detections.tracker_id is None:
            # sv.Detections expects tracker_id as an integer ndarray, not a
            # plain Python list (the original assigned a list).
            detections.tracker_id = np.arange(1, len(detections) + 1)
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            self.predictor.load_first_frame(frame)
            for xyxy, obj_id in zip(detections.xyxy, detections.tracker_id):
                bbox = np.asarray([xyxy], dtype=np.float32)
                self.predictor.add_new_prompt(
                    frame_idx=0,
                    obj_id=int(obj_id),
                    bbox=bbox,
                )
        self._prompted = True

    def propagate(self, frame: np.ndarray) -> sv.Detections:
        """Track the prompted objects into `frame`.

        Returns:
            sv.Detections with boxes, boolean masks, and tracker ids.

        Raises:
            RuntimeError: if `prompt_first_frame` was never called.
        """
        if not self._prompted:
            raise RuntimeError("Call prompt_first_frame before propagate")
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            tracker_ids, mask_logits = self.predictor.track(frame)
        tracker_ids = np.asarray(tracker_ids, dtype=np.int32)
        masks = (mask_logits > 0.0).cpu().numpy()
        # Drop singleton dims (e.g. (N, 1, H, W) -> (N, H, W)); restore the
        # leading axis when a single object collapsed to (H, W).
        masks = np.squeeze(masks).astype(bool)
        if masks.ndim == 2:
            masks = masks[None, ...]
        # Drop small stray segments far from each mask's main blob.
        masks = np.array([
            sv.filter_segments_by_distance(mask, relative_distance=0.03, mode="edge")
            for mask in masks
        ])
        xyxy = sv.mask_to_xyxy(masks=masks)
        return sv.Detections(xyxy=xyxy, mask=masks, tracker_id=tracker_ids)

    def reset(self) -> None:
        """Forget the prompt so a new sequence can be started."""
        self._prompted = False
def get_crops_from_masks(frame: np.ndarray, masks: np.ndarray) -> list[np.ndarray]:
    """Extract one masked crop per mask.

    Args:
        frame: (H, W, 3) image.
        masks: (N, H, W) boolean masks.

    Returns:
        A list of N crops. Each crop is the tight bounding box of its mask,
        with pixels outside the mask set to black. An empty mask yields an
        empty (0, 0, 3) array so the list stays aligned with `masks`.
    """
    crops: list[np.ndarray] = []
    for mask in masks:
        rows, cols = np.nonzero(mask)
        if rows.size == 0:
            # Nothing selected by this mask — emit an empty placeholder.
            crops.append(np.zeros((0, 0, 3), dtype=frame.dtype))
            continue
        top, bottom = rows.min(), rows.max() + 1
        left, right = cols.min(), cols.max() + 1
        sub_frame = frame[top:bottom, left:right]
        sub_mask = mask[top:bottom, left:right]
        # Start from black and copy in only the masked pixels.
        result = np.zeros_like(sub_frame)
        np.copyto(result, sub_frame, where=sub_mask[..., None])
        crops.append(result)
    return crops
def f(detections: sv.Detections, track_history: dict, frame_index):
    """Append each detection's RLE-encoded mask to its track's history.

    Each entry appended to track_history[tracker_id] is a tuple of
    (frame_index, rle_counts_bytes), using COCO run-length encoding.
    """
    for mask, tid in zip(detections.mask, detections.tracker_id):
        encoded = mask_utils.encode(np.asfortranarray(mask))
        track_history[int(tid)].append((frame_index, encoded['counts']))
def toRGB(img: np.ndarray):
    """Convert a BGR image (OpenCV channel order) to RGB."""
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def read_frame_from_video(in_filename, frame_num, width: int = 1920, height: int = 1080):
    """Decode a single frame from a video via ffmpeg.

    Args:
        in_filename: path to the video file.
        frame_num: 0-based index of the frame to extract.
        width, height: expected decoded resolution. Defaults keep the
            original hard-coded 1080p behavior.

    Returns:
        uint8 array of shape (1, height, width, 3) in RGB order (writable copy).

    Raises:
        RuntimeError: if ffmpeg did not return exactly one frame of the
            expected size.
    """
    raw_bytes, _ = (
        ffmpeg
        .input(in_filename)
        .filter('select', 'gte(n,{})'.format(frame_num))
        .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
        .global_args('-loglevel', 'error')
        .run(capture_stdout=True)
    )
    expected = height * width * 3
    if len(raw_bytes) != expected:
        # Not an `assert`: asserts vanish under `python -O`.
        raise RuntimeError(
            f'Expected {expected} bytes for frame {frame_num}, got {len(raw_bytes)}'
        )
    # np.frombuffer is read-only; copy so callers can modify the frame.
    return np.frombuffer(raw_bytes, np.uint8).reshape(1, height, width, 3).copy()
def read_consecutive_frames_from_video(in_filename, start_frame, num_frames,
                                       width: int = 1920, height: int = 1080) -> np.ndarray:
    """Decode `num_frames` consecutive frames starting at `start_frame`.

    Args:
        in_filename: path to the video file.
        start_frame: 0-based index of the first frame.
        num_frames: number of consecutive frames to decode.
        width, height: expected decoded resolution. Defaults keep the
            original hard-coded 1080p behavior.

    Returns:
        uint8 array of shape (num_frames, height, width, 3) in RGB order
        (writable copy).

    Raises:
        RuntimeError: if ffmpeg returned a different number of bytes than
            expected; the message includes ffmpeg's stderr.
    """
    out, err = (
        ffmpeg
        .input(in_filename)
        .output(
            'pipe:1',
            vf=f'select=between(n\\,{start_frame}\\,{start_frame + num_frames - 1})',
            vsync=0,
            vframes=num_frames,
            format='rawvideo',
            pix_fmt='rgb24',
        )
        .global_args('-loglevel', 'error')
        .run(capture_stdout=True, capture_stderr=True)
    )
    frame_size = width * height * 3
    frames = np.frombuffer(out, np.uint8)
    if frames.size != num_frames * frame_size:
        raise RuntimeError(
            f'Expected {num_frames * frame_size} bytes, got {frames.size}\n'
            f'ffmpeg stderr:\n{err.decode()}'
        )
    # np.frombuffer is read-only; copy so callers can modify the frames.
    return frames.reshape(num_frames, height, width, 3).copy()
def xywhn_to_xywh(xywhn: list, height: int, width: int):
    """Scale a normalized [x, y, w, h] box (values in 0-1) to integer pixels."""
    scales = (width, height, width, height)
    return [int(value * scale) for value, scale in zip(xywhn, scales)]
def crop_frame_at_mask_from_bbox(frame: np.ndarray, mask: np.ndarray, bbox: list) -> np.array:
    """Crop `frame` to `bbox` and black out pixels outside `mask`.

    Args:
        frame: (H, W, 3) image.
        mask: (H, W) boolean mask aligned with `frame`.
        bbox: [x, y, w, h] in pixels.

    Returns:
        A new (h, w, 3) array; pixels where the mask is False are zeroed.
        The input `frame` is left unmodified.
    """
    x, y, w, h = bbox
    # Copy the slice before writing: zeroing a plain slice view would mutate
    # the caller's frame in place (the original implementation had this bug).
    crop = frame[y: y + h, x: x + w].copy()
    cropped_mask = mask[y: y + h, x: x + w]
    crop[~cropped_mask] = np.array([0, 0, 0], dtype=np.uint8)
    return crop
def find_consecutive_streaks(nums: list | Iterable):
    """Split a sorted sequence of integers into runs of consecutive values.

    Returns a list of `range` objects, one per maximal streak, e.g.
    [1, 2, 3, 5] -> [range(1, 4), range(5, 6)]. An empty input yields [].
    """
    if isinstance(nums, Iterable):
        nums = list(nums)
    if not nums:
        return []
    streaks = []
    streak_start = nums[0]
    for prev, cur in zip(nums, nums[1:]):
        if cur != prev + 1:
            # Gap found: close the current streak and open a new one.
            streaks.append(range(streak_start, prev + 1))
            streak_start = cur
    streaks.append(range(streak_start, nums[-1] + 1))
    return streaks
def save_loss_history(fpath, loss: float):
    """Append one loss value to `fpath` (6 decimal places, one per line)."""
    with open(fpath, "a+") as history_file:
        history_file.write(f"{loss:.6f}\n")
def save_loss_history_plot(loss_history: list[float], fpath):
    """Plot the loss curve and save it to `fpath`.

    Uses a dedicated figure so repeated calls don't draw on top of each
    other or leak figures: the original drew on the implicit current
    pyplot figure, which accumulates lines and memory across calls.
    """
    fig, ax = plt.subplots()
    ax.plot(loss_history)
    fig.savefig(fpath)
    plt.close(fig)
def save_checkpoint(
    path,
    model,
    optimizer,
    epoch,
    step,
):
    """Serialize model/optimizer state plus training progress to `path`."""
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "step": step,
        },
        path,
    )
def load_checkpoint(
    path,
    model,
    optimizer,
    device="cuda"
):
    """Restore model/optimizer state saved by `save_checkpoint`.

    Args:
        path: checkpoint file path.
        model, optimizer: objects whose state dicts are loaded in place.
        device: map_location passed to torch.load (default "cuda").

    Returns:
        (epoch, step) from the checkpoint, defaulting to (0, 0) when absent.
    """
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt.get("epoch", 0), ckpt.get("step", 0)
def mask_iou_pair(m1, m2):
    """IoU of two boolean masks; returns 0.0 when they don't overlap."""
    overlap = np.logical_and(m1, m2).sum()
    if overlap == 0:
        return 0.0
    total = m1.sum() + m2.sum() - overlap
    # Epsilon kept from the original to guard the division.
    return overlap / (total + 1e-6)
def mask_nms(masks, scores, iou_thresh=0.6):
    """Greedy non-maximum suppression over masks.

    Args:
        masks: sequence of (H, W) boolean masks.
        scores: per-mask confidence scores (array-like).
        iou_thresh: masks whose IoU with a kept mask exceeds this are dropped.

    Returns:
        List of kept mask indices, in descending score order.
    """
    order = np.argsort(-scores)
    keep = []
    suppressed = np.zeros(len(masks), dtype=bool)
    for rank, i in enumerate(order):
        if suppressed[i]:
            continue
        keep.append(i)
        # Only suppress masks ranked BELOW the current one in score order.
        # The original compared raw indices (`j <= i`), which is wrong
        # whenever score order differs from index order.
        for j in order[rank + 1:]:
            if suppressed[j]:
                continue
            if mask_iou_pair(masks[i], masks[j]) > iou_thresh:
                suppressed[j] = True
    return keep
def mask_iou(masks_t: np.ndarray, masks_t1):
    """Pairwise IoU between two stacks of masks.

    Args:
        masks_t: (N, H, W) masks.
        masks_t1: (M, H, W) masks.

    Returns:
        (N, M) float IoU matrix.
    """
    n = masks_t.shape[0]
    m = masks_t1.shape[0]
    # Flatten to (N, HW) / (M, HW) so intersection is a single matmul.
    flat_t = masks_t.reshape(n, -1).astype(float)
    flat_t1 = masks_t1.reshape(m, -1).astype(float)
    intersection = flat_t @ flat_t1.T              # (N, M)
    areas_t = flat_t.sum(1, keepdims=True)         # (N, 1)
    areas_t1 = flat_t1.sum(1, keepdims=True)       # (M, 1)
    union = areas_t + areas_t1.T - intersection    # (N, M)
    return intersection / (union + 1e-6)
# Canonical court keypoint positions as (x, y) pairs. Values span
# x in [0, 94] and y in [0, 50] — consistent with a 94 ft x 50 ft NBA
# basketball court, presumably in feet with the origin at one corner
# (TODO confirm units and keypoint ordering against the keypoint model).
COURT_KEYPOINT_COORDINATES = np.array([
    (0.0, 0.0),
    (0.0, 2.99),
    (0.0, 17.0),
    (0.0, 33.01),
    (0.0, 47.02),
    (0.0, 50.0),
    (5.25, 25.0),
    (13.92, 2.99),
    (13.92, 47.02),
    (19.0, 17.0),
    (19.0, 25.0),
    (19.0, 33.01),
    (27.4, 0.0),
    (29.01, 25.0),
    (27.4, 50.0),
    (46.99, 0.0),
    (46.99, 25.0),
    (46.99, 50.0),
    (66.61, 0.0),
    (65.0, 25.0),
    (66.61, 50.0),
    (75.0, 17.0),
    (75.0, 25.0),
    (75.0, 33.01),
    (80.09, 2.99),
    (80.09, 47.02),
    (88.75, 25.0),
    (94.0, 0.0),
    (94.0, 2.99),
    (94.0, 17.0),
    (94.0, 33.01),
    (94.0, 47.02),
    (94.0, 50.0)
])
def get_distance_cost_matrix(arr1: np.ndarray, arr2: np.ndarray, ord=1):
    """Pairwise distance matrix between two point sets.

    Args:
        arr1: (N, D) points.
        arr2: (M, D) points.
        ord: vector norm order passed to np.linalg.norm (default 1 = L1).

    Returns:
        torch.Tensor of shape (N, M), dtype float64, where entry (i, j)
        is the distance between arr1[i] and arr2[j].
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)
    # Broadcast to (N, M, D) and reduce over the last axis — one vectorized
    # call instead of the original per-row Python loop; values are identical.
    cost_matrix = np.linalg.norm(
        arr1[:, None, :] - arr2[None, :, :], ord=ord, axis=-1
    ).astype(np.float64)
    return torch.tensor(cost_matrix)
def matcher_probs_custom_argmax(probs: np.ndarray, confidence_threshold=0.7):
    """Argmax with a fallback when the null class wins without confidence.

    The last entry of `probs` is treated as the "no match" class. When it
    wins but its probability is below `confidence_threshold`, fall back to
    the best non-null class — provided that class carries enough weight.
    """
    probs = probs.squeeze(0)
    null_index = len(probs) - 1
    pred = probs.argmax()
    if pred == null_index and probs[pred] < confidence_threshold:
        runner_up = probs[:-1].argmax()
        # Accept the runner-up only if it holds a meaningful share of mass.
        if probs[runner_up] > 1.0 - confidence_threshold - 0.05:
            pred = runner_up
    return pred
def show_annotations(frame_, detections_):
    """Annotate a frame with masks and tracker-id labels, returned as a PIL Image.

    Delegates the drawing to `annotate_frame` (defined in this module), which
    performs the identical annotation steps — the original duplicated that
    code; only the PIL conversion differs here.
    """
    return Image.fromarray(annotate_frame(frame_, detections_))
def annotate_frame(frame_, detections_):
    """Return a copy of `frame_` with masks and tracker-id labels drawn on."""
    labels = [str(tid) for tid in detections_.tracker_id]
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.TRACK)
    label_annotator = sv.LabelAnnotator(smart_position=True)
    canvas = mask_annotator.annotate(frame_.copy(), detections_)
    canvas = label_annotator.annotate(canvas, detections_, labels=labels)
    return canvas
if __name__ == "__main__":
    # Ad-hoc debugging entry point: decode one frame (index 199) from a
    # local sample clip and drop into an interactive shell with it in scope.
    from code import interact
    frames = read_consecutive_frames_from_video("nba_sample_videos/batch2/SAC_LAL_1.mp4", 199, 1)
    # crop_frame_at_mask_from_bbox(np.zeros((1080, 1920, 3)), )
    interact(local=locals())