Spaces:
Runtime error
Runtime error
Commit ·
4a7fc58
1
Parent(s): cb90e6f
Update inference.py
Browse files- inference.py +8 -5
inference.py
CHANGED
|
@@ -155,11 +155,6 @@ def do_detection(dataloader, model, device, gp=None, batch_size=BATCH_SIZE, verb
|
|
| 155 |
inf_out, _ = model(img, augment=False)
|
| 156 |
|
| 157 |
|
| 158 |
-
32, 30240, 6
|
| 159 |
-
print(inf_out.shape)
|
| 160 |
-
print(inf_out[1, 1, :])
|
| 161 |
-
w = inf_out[:, :, 2] - inf_out[:, :, 0]
|
| 162 |
-
print(w.shape)
|
| 163 |
|
| 164 |
# Save shapes for resizing to original shape
|
| 165 |
batch_shape = []
|
|
@@ -399,9 +394,11 @@ def json_dump_round_float(some_object, out_path, num_digits=4):
|
|
| 399 |
|
| 400 |
def non_max_suppression(
|
| 401 |
prediction,
|
|
|
|
| 402 |
conf_thres=0.25,
|
| 403 |
iou_thres=0.45,
|
| 404 |
max_det=300,
|
|
|
|
| 405 |
):
|
| 406 |
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
|
| 407 |
|
|
@@ -423,6 +420,12 @@ def non_max_suppression(
|
|
| 423 |
prediction = prediction.cpu()
|
| 424 |
bs = prediction.shape[0] # batch size
|
| 425 |
xc = prediction[..., 4] > conf_thres # candidates
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
|
| 427 |
# Settings
|
| 428 |
# min_wh = 2 # (pixels) minimum box width and height
|
|
|
|
| 155 |
inf_out, _ = model(img, augment=False)
|
| 156 |
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
# Save shapes for resizing to original shape
|
| 160 |
batch_shape = []
|
|
|
|
| 394 |
|
| 395 |
def non_max_suppression(
|
| 396 |
prediction,
|
| 397 |
+
pix2w,
|
| 398 |
conf_thres=0.25,
|
| 399 |
iou_thres=0.45,
|
| 400 |
max_det=300,
|
| 401 |
+
max_length=1.5
|
| 402 |
):
|
| 403 |
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
|
| 404 |
|
|
|
|
| 420 |
prediction = prediction.cpu()
|
| 421 |
bs = prediction.shape[0] # batch size
|
| 422 |
xc = prediction[..., 4] > conf_thres # candidates
|
| 423 |
+
print(xc.shape)
|
| 424 |
+
wc = (prediction[..., 2] - prediction[..., 0])*pix2w < max_length
|
| 425 |
+
print(wc.shape)
|
| 426 |
+
xc = xc and wc
|
| 427 |
+
print(xc.shape)
|
| 428 |
+
|
| 429 |
|
| 430 |
# Settings
|
| 431 |
# min_wh = 2 # (pixels) minimum box width and height
|