Spaces:
Runtime error
Runtime error
Commit ·
cef04ce
1
Parent(s): a77f5fd
Update inference.py
Browse files- inference.py +6 -6
inference.py
CHANGED
|
@@ -52,7 +52,7 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
| 52 |
|
| 53 |
# Load hyperparameters
|
| 54 |
if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
|
| 55 |
-
if 'conf_thresh' not in hyperparams: hyperparams['
|
| 56 |
if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
|
| 57 |
if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
|
| 58 |
if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
|
|
@@ -87,13 +87,13 @@ def do_full_inference(dataloader, image_meter_width, image_meter_height, gp=None
|
|
| 87 |
return
|
| 88 |
|
| 89 |
|
| 90 |
-
outputs = do_suppression(inference, conf_thres=hyperparams['
|
| 91 |
|
| 92 |
if hyperparams['use_associative_tracking']:
|
| 93 |
-
|
| 94 |
do_confidence_boost(inference, outputs, gp=gp)
|
| 95 |
|
| 96 |
-
outputs = do_suppression(inference, conf_thres=hyperparams['
|
| 97 |
|
| 98 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
| 99 |
|
|
@@ -288,13 +288,13 @@ def do_confidence_boost(inference, safe_preds, gp=None, batch_size=BATCH_SIZE, v
|
|
| 288 |
pbar.update(1*batch_size)
|
| 289 |
|
| 290 |
|
| 291 |
-
def boost_frame(safe_frame, base_frame, dt):
|
| 292 |
safe_boxes = safe_frame[:, :4]
|
| 293 |
boxes = xywh2xyxy(base_frame[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
| 294 |
ious = box_iou(boxes, safe_boxes)
|
| 295 |
score = torch.matmul(ious, safe_frame[:, 4])
|
| 296 |
# score = iou(safe_box, base_box) * confidence(safe_box)
|
| 297 |
-
base_frame[:, 4] *= 1 + (score)*math.exp(-dt*dt)
|
| 298 |
return base_frame
|
| 299 |
|
| 300 |
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
|
|
|
|
| 52 |
|
| 53 |
# Load hyperparameters
|
| 54 |
if 'model' not in hyperparams: hyperparams['model'] = WEIGHTS
|
| 55 |
+
if 'conf_thresh' not in hyperparams: hyperparams['conf_thresh'] = CONF_THRES
|
| 56 |
if 'iou_thresh' not in hyperparams: hyperparams['iou_thresh'] = NMS_IOU
|
| 57 |
if 'min_hits' not in hyperparams: hyperparams['min_hits'] = MIN_HITS
|
| 58 |
if 'max_age' not in hyperparams: hyperparams['max_age'] = MAX_AGE
|
|
|
|
| 87 |
return
|
| 88 |
|
| 89 |
|
| 90 |
+
outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
| 91 |
|
| 92 |
if hyperparams['use_associative_tracking']:
|
| 93 |
+
|
| 94 |
do_confidence_boost(inference, outputs, gp=gp)
|
| 95 |
|
| 96 |
+
outputs = do_suppression(inference, conf_thres=hyperparams['conf_thresh'], iou_thres=hyperparams['iou_thresh'], gp=gp)
|
| 97 |
|
| 98 |
all_preds, real_width, real_height = format_predictions(image_shapes, outputs, width, height, gp=gp)
|
| 99 |
|
|
|
|
| 288 |
pbar.update(1*batch_size)
|
| 289 |
|
| 290 |
|
| 291 |
+
def boost_frame(safe_frame, base_frame, dt, decay=1):
|
| 292 |
safe_boxes = safe_frame[:, :4]
|
| 293 |
boxes = xywh2xyxy(base_frame[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
| 294 |
ious = box_iou(boxes, safe_boxes)
|
| 295 |
score = torch.matmul(ious, safe_frame[:, 4])
|
| 296 |
# score = iou(safe_box, base_box) * confidence(safe_box)
|
| 297 |
+
base_frame[:, 4] *= 1 + (score)*math.exp(-decay*dt*dt)
|
| 298 |
return base_frame
|
| 299 |
|
| 300 |
def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
|