Spaces:
Runtime error
Runtime error
Commit ·
1ae5c71
1
Parent(s): 9ab5dcd
Update inference.py
Browse files- inference.py +97 -22
inference.py
CHANGED
|
@@ -330,45 +330,36 @@ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_a
|
|
| 330 |
|
| 331 |
return json_data
|
| 332 |
|
| 333 |
-
def do_associative_tracking(
|
| 334 |
|
| 335 |
if (gp): gp(0, "Tracking...")
|
| 336 |
|
| 337 |
print("Preprocessing")
|
| 338 |
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
high_frame = []
|
| 346 |
-
for bbox in frame:
|
| 347 |
-
if bbox[4] > conf_thresh:
|
| 348 |
-
high_frame.append(bbox)
|
| 349 |
-
else:
|
| 350 |
-
low_frame.append(bbox)
|
| 351 |
-
low_dets.append(low_frame)
|
| 352 |
-
high_dets.append(high_frame)
|
| 353 |
-
pbar.update(1)
|
| 354 |
|
| 355 |
print("Preprocess done")
|
| 356 |
|
| 357 |
# Initialize tracker
|
| 358 |
clip_info = {
|
| 359 |
'start_frame': 0,
|
| 360 |
-
'end_frame': len(
|
| 361 |
'image_meter_width': image_meter_width,
|
| 362 |
'image_meter_height': image_meter_height
|
| 363 |
}
|
| 364 |
tracker = Tracker(clip_info, algorithm=Associate, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
| 365 |
|
| 366 |
# Run tracking
|
| 367 |
-
with tqdm(total=len(
|
| 368 |
-
for i in range(len(
|
| 369 |
-
if gp: gp(i / len(
|
| 370 |
-
low_boxes =
|
| 371 |
-
high_boxes =
|
| 372 |
boxes = (low_boxes, high_boxes)
|
| 373 |
if len(low_boxes) + len(high_boxes) > 0:
|
| 374 |
tracker.update(boxes)
|
|
@@ -455,6 +446,90 @@ def non_max_suppression(
|
|
| 455 |
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
|
| 456 |
|
| 457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
# Check shape
|
| 459 |
n = x.shape[0] # number of boxes
|
| 460 |
if not n: # no boxes
|
|
|
|
| 330 |
|
| 331 |
return json_data
|
| 332 |
|
| 333 |
+
def do_associative_tracking(inference, image_shapes, width, height, image_meter_width, image_meter_height, gp=None, low_thresh=0.001, high_threshold=0.2, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, batch_size=BATCH_SIZE, verbose=True):
|
| 334 |
|
| 335 |
if (gp): gp(0, "Tracking...")
|
| 336 |
|
| 337 |
print("Preprocessing")
|
| 338 |
|
| 339 |
+
|
| 340 |
+
low_outputs = do_suppression(inference, conf_thres=low_thresh, iou_thres=iou_thres, gp=gp)
|
| 341 |
+
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, verbose=verbose)
|
| 342 |
+
|
| 343 |
+
high_outputs = do_suppression(inference, conf_thres=high_threshold, iou_thres=iou_thres, gp=gp)
|
| 344 |
+
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, verbose=verbose)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
print("Preprocess done")
|
| 347 |
|
| 348 |
# Initialize tracker
|
| 349 |
clip_info = {
|
| 350 |
'start_frame': 0,
|
| 351 |
+
'end_frame': len(low_preds),
|
| 352 |
'image_meter_width': image_meter_width,
|
| 353 |
'image_meter_height': image_meter_height
|
| 354 |
}
|
| 355 |
tracker = Tracker(clip_info, algorithm=Associate, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
| 356 |
|
| 357 |
# Run tracking
|
| 358 |
+
with tqdm(total=len(low_preds), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
|
| 359 |
+
for i in range(len(low_preds)):
|
| 360 |
+
if gp: gp(i / len(low_preds), pbar.__str__())
|
| 361 |
+
low_boxes = low_preds[i]
|
| 362 |
+
high_boxes = high_preds[i]
|
| 363 |
boxes = (low_boxes, high_boxes)
|
| 364 |
if len(low_boxes) + len(high_boxes) > 0:
|
| 365 |
tracker.update(boxes)
|
|
|
|
| 446 |
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
|
| 447 |
|
| 448 |
|
| 449 |
+
# Check shape
|
| 450 |
+
n = x.shape[0] # number of boxes
|
| 451 |
+
if not n: # no boxes
|
| 452 |
+
continue
|
| 453 |
+
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
|
| 454 |
+
|
| 455 |
+
# Batched NMS
|
| 456 |
+
boxes = x[:, :4] # boxes (offset by class), scores
|
| 457 |
+
scores = x[:, 4]
|
| 458 |
+
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
| 459 |
+
|
| 460 |
+
i = i[:max_det] # limit detections
|
| 461 |
+
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
| 462 |
+
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
| 463 |
+
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
| 464 |
+
weights = iou * scores[None] # box weights
|
| 465 |
+
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
|
| 466 |
+
if redundant:
|
| 467 |
+
i = i[iou.sum(1) > 1] # require redundancy
|
| 468 |
+
|
| 469 |
+
output[xi] = x[i]
|
| 470 |
+
if mps:
|
| 471 |
+
output[xi] = output[xi].to(device)
|
| 472 |
+
|
| 473 |
+
logging = False
|
| 474 |
+
|
| 475 |
+
return output
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def no_suppression(
|
| 479 |
+
prediction,
|
| 480 |
+
conf_thres=0.25,
|
| 481 |
+
iou_thres=0.45,
|
| 482 |
+
max_det=300,
|
| 483 |
+
):
|
| 484 |
+
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
|
| 485 |
+
|
| 486 |
+
Returns:
|
| 487 |
+
list of detections, one (n,6) tensor per image [xyxy, conf, cls]
|
| 488 |
+
"""
|
| 489 |
+
|
| 490 |
+
# Checks
|
| 491 |
+
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
|
| 492 |
+
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
|
| 493 |
+
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
|
| 494 |
+
prediction = prediction[0] # select only inference output
|
| 495 |
+
|
| 496 |
+
device = prediction.device
|
| 497 |
+
mps = 'mps' in device.type # Apple MPS
|
| 498 |
+
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
| 499 |
+
prediction = prediction.cpu()
|
| 500 |
+
bs = prediction.shape[0] # batch size
|
| 501 |
+
xc = prediction[..., 4] > conf_thres # candidates
|
| 502 |
+
|
| 503 |
+
# Settings
|
| 504 |
+
# min_wh = 2 # (pixels) minimum box width and height
|
| 505 |
+
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
| 506 |
+
redundant = True # require redundant detections
|
| 507 |
+
merge = False # use merge-NMS
|
| 508 |
+
|
| 509 |
+
output = [torch.zeros((0, 6), device=prediction.device)] * bs
|
| 510 |
+
for xi, x in enumerate(prediction): # image index, image inference
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
# Keep boxes that pass confidence threshold
|
| 514 |
+
x = x[xc[xi]] # confidence
|
| 515 |
+
|
| 516 |
+
# If none remain process next image
|
| 517 |
+
if not x.shape[0]:
|
| 518 |
+
continue
|
| 519 |
+
|
| 520 |
+
# Compute conf
|
| 521 |
+
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
# Box/Mask
|
| 525 |
+
box = xywh2xyxy(x[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
| 526 |
+
mask = x[:, 6:] # zero columns if no masks
|
| 527 |
+
|
| 528 |
+
# Detections matrix nx6 (xyxy, conf, cls)
|
| 529 |
+
conf, j = x[:, 5:6].max(1, keepdim=True)
|
| 530 |
+
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
|
| 531 |
+
|
| 532 |
+
|
| 533 |
# Check shape
|
| 534 |
n = x.shape[0] # number of boxes
|
| 535 |
if not n: # no boxes
|