Spaces:
Runtime error
Runtime error
Commit ·
93b9874
1
Parent(s): 6bfcb22
Update inference.py
Browse files- inference.py +15 -9
inference.py
CHANGED
|
@@ -64,11 +64,11 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
| 64 |
if config.associative_tracker == TrackerType.BYTETRACK:
|
| 65 |
|
| 66 |
# Find low confidence detections
|
| 67 |
-
low_outputs = do_suppression(inference, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
|
| 68 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
| 69 |
|
| 70 |
# Find high confidence detections
|
| 71 |
-
high_outputs = do_suppression(inference, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
|
| 72 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
| 73 |
|
| 74 |
# Perform associative tracking (ByteTrack)
|
|
@@ -80,7 +80,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
| 80 |
else:
|
| 81 |
|
| 82 |
# Find confident detections
|
| 83 |
-
outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
| 84 |
|
| 85 |
if config.associative_tracker == TrackerType.CONF_BOOST:
|
| 86 |
|
|
@@ -88,7 +88,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
| 88 |
do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
|
| 89 |
|
| 90 |
# Find confident detections from boosted list
|
| 91 |
-
outputs = do_suppression(inference, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
| 92 |
|
| 93 |
# Format confident detections
|
| 94 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
|
@@ -169,7 +169,7 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verb
|
|
| 169 |
|
| 170 |
return inference, image_shapes, width, height
|
| 171 |
|
| 172 |
-
def do_suppression(inference, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, verbose=True):
|
| 173 |
"""
|
| 174 |
Args:
|
| 175 |
frames_dir: a directory containing frames to be evaluated
|
|
@@ -188,7 +188,7 @@ def do_suppression(inference, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_TH
|
|
| 188 |
if gp: gp(batch_i / len(inference), pbar.__str__())
|
| 189 |
|
| 190 |
with torch.no_grad():
|
| 191 |
-
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
|
| 192 |
|
| 193 |
|
| 194 |
outputs.append(output)
|
|
@@ -422,13 +422,13 @@ def filter_detection_size(inference, image_meter_width, width, max_length):
|
|
| 422 |
print(wc.shape)
|
| 423 |
bs = batch.shape[0] # batches
|
| 424 |
|
| 425 |
-
output =
|
| 426 |
print("wc")
|
| 427 |
print(batch.shape)
|
| 428 |
for xi, x in enumerate(batch):
|
| 429 |
x = x[wc[xi]] # confidence
|
| 430 |
print(x.shape)
|
| 431 |
-
output[xi] = x
|
| 432 |
|
| 433 |
output = torch.tensor(output)
|
| 434 |
print("output len", output.shape)
|
|
@@ -439,6 +439,9 @@ def filter_detection_size(inference, image_meter_width, width, max_length):
|
|
| 439 |
|
| 440 |
def non_max_suppression(
|
| 441 |
prediction,
|
|
|
|
|
|
|
|
|
|
| 442 |
conf_thres=0.25,
|
| 443 |
iou_thres=0.45,
|
| 444 |
max_det=300
|
|
@@ -463,6 +466,9 @@ def non_max_suppression(
|
|
| 463 |
prediction = prediction.cpu()
|
| 464 |
bs = prediction.shape[0] # batch size
|
| 465 |
xc = prediction[..., 4] > conf_thres # candidates
|
|
|
|
|
|
|
|
|
|
| 466 |
|
| 467 |
|
| 468 |
# Settings
|
|
@@ -476,7 +482,7 @@ def non_max_suppression(
|
|
| 476 |
|
| 477 |
|
| 478 |
# Keep boxes that pass confidence threshold
|
| 479 |
-
x = x[xc[xi]] # confidence
|
| 480 |
|
| 481 |
# If none remain process next image
|
| 482 |
if not x.shape[0]:
|
|
|
|
| 64 |
if config.associative_tracker == TrackerType.BYTETRACK:
|
| 65 |
|
| 66 |
# Find low confidence detections
|
| 67 |
+
low_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_low_conf, iou_thres=config.nms_iou, gp=gp)
|
| 68 |
low_preds, real_width, real_height = format_predictions(image_shapes, low_outputs, width, height, gp=gp)
|
| 69 |
|
| 70 |
# Find high confidence detections
|
| 71 |
+
high_outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.byte_high_conf, iou_thres=config.nms_iou, gp=gp)
|
| 72 |
high_preds, real_width, real_height = format_predictions(image_shapes, high_outputs, width, height, gp=gp)
|
| 73 |
|
| 74 |
# Perform associative tracking (ByteTrack)
|
|
|
|
| 80 |
else:
|
| 81 |
|
| 82 |
# Find confident detections
|
| 83 |
+
outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
| 84 |
|
| 85 |
if config.associative_tracker == TrackerType.CONF_BOOST:
|
| 86 |
|
|
|
|
| 88 |
do_confidence_boost(inference, outputs, boost_power=config.boost_power, boost_decay=config.boost_decay, gp=gp)
|
| 89 |
|
| 90 |
# Find confident detections from boosted list
|
| 91 |
+
outputs = do_suppression(inference, image_meter_width, width, conf_thres=config.conf_thresh, iou_thres=config.nms_iou, gp=gp)
|
| 92 |
|
| 93 |
# Format confident detections
|
| 94 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
|
|
|
| 169 |
|
| 170 |
return inference, image_shapes, width, height
|
| 171 |
|
| 172 |
+
def do_suppression(inference, image_meter_width, image_pixel_width, gp=None, batch_size=BATCH_SIZE, conf_thres=CONF_THRES, iou_thres=NMS_IOU, max_length=1.5, verbose=True):
|
| 173 |
"""
|
| 174 |
Args:
|
| 175 |
frames_dir: a directory containing frames to be evaluated
|
|
|
|
| 188 |
if gp: gp(batch_i / len(inference), pbar.__str__())
|
| 189 |
|
| 190 |
with torch.no_grad():
|
| 191 |
+
output = non_max_suppression(inf_out, image_meter_width, image_pixel_width, conf_thres=conf_thres, iou_thres=iou_thres, max_length=max_length)
|
| 192 |
|
| 193 |
|
| 194 |
outputs.append(output)
|
|
|
|
| 422 |
print(wc.shape)
|
| 423 |
bs = batch.shape[0] # batches
|
| 424 |
|
| 425 |
+
output = torch.zeros((bs, 0, 6), device=batch.device)
|
| 426 |
print("wc")
|
| 427 |
print(batch.shape)
|
| 428 |
for xi, x in enumerate(batch):
|
| 429 |
x = x[wc[xi]] # confidence
|
| 430 |
print(x.shape)
|
| 431 |
+
output[xi, :, :] = x
|
| 432 |
|
| 433 |
output = torch.tensor(output)
|
| 434 |
print("output len", output.shape)
|
|
|
|
| 439 |
|
| 440 |
def non_max_suppression(
|
| 441 |
prediction,
|
| 442 |
+
image_meter_width,
|
| 443 |
+
image_pixel_width,
|
| 444 |
+
max_length=1.5,
|
| 445 |
conf_thres=0.25,
|
| 446 |
iou_thres=0.45,
|
| 447 |
max_det=300
|
|
|
|
| 466 |
prediction = prediction.cpu()
|
| 467 |
bs = prediction.shape[0] # batch size
|
| 468 |
xc = prediction[..., 4] > conf_thres # candidates
|
| 469 |
+
pix2width = image_meter_width/width
|
| 470 |
+
width = prediction[..., 2]*pix2width
|
| 471 |
+
wc = width < max_length
|
| 472 |
|
| 473 |
|
| 474 |
# Settings
|
|
|
|
| 482 |
|
| 483 |
|
| 484 |
# Keep boxes that pass confidence threshold
|
| 485 |
+
x = x[xc[xi] * wc[xi]] # confidence
|
| 486 |
|
| 487 |
# If none remain process next image
|
| 488 |
if not x.shape[0]:
|