usiddiquee
hi
e1832f4
# Mikel Broström 🔥 RFDETR Tracking 🧾 AGPL-3.0 license
import numpy as np
import torch
import cv2
from PIL import Image
from rfdetr import RFDETRBase
from rfdetr.util.coco_classes import COCO_CLASSES
from ultralytics.engine.results import Results
from ultralytics.utils import ops
from ultralytics.models.yolo.detect import DetectionPredictor
from boxmot.utils import logger as LOGGER
from tracking.detectors.yolo_interface import YoloInterface
class RFDETRStrategy(YoloInterface):
    """Run an RF-DETR detector behind the ``YoloInterface`` used for tracking.

    Flow: ``preprocess`` picks the frame from the batch, ``__call__`` runs
    RF-DETR on it and returns a ``(1, N, 6)`` tensor of
    ``[x1, y1, x2, y2, confidence, class_id]`` rows, and ``postprocess`` wraps
    each prediction in an ultralytics ``Results`` object.
    """

    # Backend flags/metadata mirroring an ultralytics model backend;
    # presumably read by YoloInterface / the predictor — TODO confirm.
    pt = False
    stride = 32
    fp16 = False
    triton = False
    names = COCO_CLASSES

    def __init__(self, model, device, args):
        """Load the RF-DETR base model.

        Args:
            model: Unused here; accepted for interface compatibility.
            device: Unused here; the model is always created on CPU.
            args: Predictor arguments. ``args.conf`` is used as the detection
                threshold and ``args.classes`` as an optional class filter.
        """
        self.args = args
        LOGGER.info("Loading RFDETR model")
        # NOTE(review): ``device`` is ignored and the model is pinned to CPU;
        # confirm whether the requested device should be forwarded here.
        self.model = RFDETRBase(device='cpu')

    @torch.no_grad()
    def __call__(self, im, augment, visualize, embed):
        """Run RF-DETR on a single frame.

        Args:
            im: Frame as returned by ``preprocess``. It is converted with
                ``cv2.COLOR_BGR2RGB``, so it is expected in BGR channel order.
            augment, visualize, embed: Unused; accepted for compatibility with
                the ultralytics predictor call signature.

        Returns:
            torch.Tensor: Shape ``(1, N, 6)`` with rows
            ``[x1, y1, x2, y2, confidence, class_id]``.
        """
        # RF-DETR consumes RGB images: feed it the converted PIL image, not the
        # raw BGR frame (which would swap the red and blue channels).
        frame_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(frame_rgb)
        # The decorator already disables autograd; no inner no_grad needed.
        detections = self.model.predict(image, threshold=self.args.conf)
        preds = np.column_stack(
            [
                detections.xyxy,
                detections.confidence[:, np.newaxis],
                detections.class_id[:, np.newaxis],
            ]
        )
        # Add a leading batch dimension: one frame per call.
        return torch.from_numpy(preds).unsqueeze(0)

    def warmup(self, imgsz):
        """Do nothing; present only to satisfy the predictor interface."""
        pass

    def update_im_paths(self, predictor: DetectionPredictor):
        """Store the image paths of the current batch.

        Meant to be registered as the ``on_predict_batch_start`` callback;
        ``postprocess`` attaches the stored paths to its ``Results``.

        Raises:
            TypeError: If ``predictor`` is not an ultralytics
                ``DetectionPredictor``.
        """
        # Explicit check: ``assert (cond, msg)`` tests a non-empty tuple and
        # can never fail, and a bare assert would be stripped under ``-O``.
        if not isinstance(predictor, DetectionPredictor):
            raise TypeError("Only ultralytics predictors are supported")
        self.im_paths = predictor.batch[0]

    def preprocess(self, im) -> torch.Tensor:
        """Return the first frame of the batch ``im`` (a list).

        ``__call__`` handles a single frame, so only ``im[0]`` is used.
        """
        # NOTE(review): annotated as a Tensor, but ``__call__`` passes the
        # result to ``cv2.cvtColor``, so it is presumably a numpy image.
        assert isinstance(im, list)
        return im[0]

    def postprocess(self, preds, im, im0s):
        """Wrap raw predictions in ultralytics ``Results`` objects.

        Args:
            preds: Iterable of per-image ``(N, 6)`` tensors with rows
                ``[x1, y1, x2, y2, confidence, class_id]`` (as produced by
                ``__call__``). ``None`` or empty entries give empty results.
            im: Unused; accepted for interface compatibility.
            im0s: Original images, indexed in step with ``preds``.

        Returns:
            list[Results]: One ``Results`` per entry of ``preds``.
        """
        # Paths exist only after ``update_im_paths`` ran; tolerate their absence.
        im_paths = getattr(self, "im_paths", None)
        if im_paths is None:
            im_paths = []
        classes = self.args.classes

        results = []
        for i, pred in enumerate(preds):
            im_path = im_paths[i] if i < len(im_paths) else ""
            if pred is None or len(pred) == 0:
                pred = torch.empty((0, 6))
            elif classes is not None:
                # Build the allow-list with pred's dtype/device so torch.isin
                # compares like with like; an empty selection means no filter.
                wanted = torch.as_tensor(
                    classes, dtype=pred.dtype, device=pred.device
                ).reshape(-1)
                if wanted.numel():
                    pred = pred[torch.isin(pred[:, 5], wanted)]
            results.append(
                Results(
                    path=im_path,
                    boxes=pred,
                    orig_img=im0s[i],
                    names=self.names,
                )
            )
        return results