Spaces:
Runtime error
Runtime error
Commit ·
889bde6
1
Parent(s): bb2dfaa
Update inference.py
Browse files- inference.py +24 -8
inference.py
CHANGED
|
@@ -18,6 +18,7 @@ import torch
|
|
| 18 |
import torchvision
|
| 19 |
|
| 20 |
from lib.fish_eye.tracker import Tracker
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
### Configuration options
|
|
@@ -329,11 +330,24 @@ def do_tracking(all_preds, image_meter_width, image_meter_height, gp=None, max_a
|
|
| 329 |
|
| 330 |
return json_data
|
| 331 |
|
| 332 |
-
def do_associative_tracking(raw_detections, image_meter_width, image_meter_height, gp=None, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
|
| 333 |
|
| 334 |
if (gp): gp(0, "Tracking...")
|
| 335 |
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
|
| 338 |
# Initialize tracker
|
| 339 |
clip_info = {
|
|
@@ -342,14 +356,16 @@ def do_associative_tracking(raw_detections, image_meter_width, image_meter_heigh
|
|
| 342 |
'image_meter_width': image_meter_width,
|
| 343 |
'image_meter_height': image_meter_height
|
| 344 |
}
|
| 345 |
-
tracker = Tracker(clip_info, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
| 346 |
|
| 347 |
# Run tracking
|
| 348 |
-
with tqdm(total=len(
|
| 349 |
-
for i, key in enumerate(sorted(
|
| 350 |
-
if gp: gp(i / len(
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
| 353 |
tracker.update(boxes)
|
| 354 |
else:
|
| 355 |
tracker.update()
|
|
|
|
| 18 |
import torchvision
|
| 19 |
|
| 20 |
from lib.fish_eye.tracker import Tracker
|
| 21 |
+
from lib.fish_eye.associative import Associate
|
| 22 |
|
| 23 |
|
| 24 |
### Configuration options
|
|
|
|
| 330 |
|
| 331 |
return json_data
|
| 332 |
|
| 333 |
+
def do_associative_tracking(raw_detections, image_meter_width, image_meter_height, gp=None, conf_thresh=0.2, max_age=MAX_AGE, iou_thres=IOU_THRES, min_hits=MIN_HITS, min_length=MIN_LENGTH, verbose=True):
|
| 334 |
|
| 335 |
if (gp): gp(0, "Tracking...")
|
| 336 |
|
| 337 |
+
low_dets = []
|
| 338 |
+
high_dets = []
|
| 339 |
+
for batch in raw_detections:
|
| 340 |
+
for frame in batch:
|
| 341 |
+
low_frame = []
|
| 342 |
+
high_frame = []
|
| 343 |
+
for bbox in frame:
|
| 344 |
+
if bbox[4] > conf_thresh:
|
| 345 |
+
high_frame.append(bbox)
|
| 346 |
+
else:
|
| 347 |
+
low_frame.append(bbox)
|
| 348 |
+
low_dets.append(low_frame)
|
| 349 |
+
high_dets.append(high_frame)
|
| 350 |
+
|
| 351 |
|
| 352 |
# Initialize tracker
|
| 353 |
clip_info = {
|
|
|
|
| 356 |
'image_meter_width': image_meter_width,
|
| 357 |
'image_meter_height': image_meter_height
|
| 358 |
}
|
| 359 |
+
tracker = Tracker(clip_info, algorithm=Associate, args={ 'max_age': max_age, 'min_hits': 0, 'iou_threshold': iou_thres}, min_hits=min_hits)
|
| 360 |
|
| 361 |
# Run tracking
|
| 362 |
+
with tqdm(total=len(low_dets), desc="Running tracking", ncols=0, disable=not verbose) as pbar:
|
| 363 |
+
for i, key in enumerate(sorted(low_dets.keys())):
|
| 364 |
+
if gp: gp(i / len(low_dets), pbar.__str__())
|
| 365 |
+
low_boxes = low_dets[key]
|
| 366 |
+
high_boxes = high_dets[key]
|
| 367 |
+
boxes = (low_boxes, high_boxes)
|
| 368 |
+
if len(low_boxes) + len(high_boxes) > 0:
|
| 369 |
tracker.update(boxes)
|
| 370 |
else:
|
| 371 |
tracker.update()
|