File size: 3,822 Bytes
73f6108
 
 
 
 
 
 
 
26ab37f
73f6108
 
 
 
 
 
 
26ab37f
73f6108
 
 
26ab37f
79b792e
26ab37f
 
73f6108
 
 
 
 
26ab37f
73f6108
 
 
26ab37f
73f6108
6fe328f
73f6108
 
6fe328f
 
 
 
73f6108
 
 
 
 
 
 
 
 
 
79b792e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fe328f
79b792e
 
 
 
 
 
 
 
6fe328f
79b792e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-
from os.path import abspath, join

import utool as ut

from scoutbot import agg, loc, tile, wic


def test_agg_compute_phase1():
    """End-to-end regression test of the 'phase1' pipeline on one example image.

    Tiles the image, scores every tile with WIC, keeps the tiles whose
    'positive' score reaches ``wic.CONFIGS['phase1']['thresh']``, runs the
    localizer on those tiles, aggregates the per-tile localizer outputs into
    image-level detections, and compares them against reference values.

    The expected counts and reference detections are regression values for
    this specific example image and the 'phase1' configuration; they will need
    updating if the models or configuration change.
    """
    # Resolved against the current working directory, so the test must be run
    # from the directory that contains 'examples/'.
    img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))

    # Run tiling
    img_shape, tile_grids, tile_filepaths = tile.compute(img_filepath)
    assert len(tile_filepaths) == 1252

    # Run WIC
    wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths, config='phase1')))
    assert len(wic_outputs) == len(tile_filepaths)

    # Threshold for WIC: only tiles at or above the configured threshold are
    # passed on to the localizer.
    flags = [
        wic_output.get('positive') >= wic.CONFIGS['phase1']['thresh']
        for wic_output in wic_outputs
    ]
    loc_tile_grids = ut.compress(tile_grids, flags)
    loc_tile_filepaths = ut.compress(tile_filepaths, flags)
    assert sum(flags) == 15

    # Run localizer
    loc_outputs = loc.post(loc.predict(loc.pre(loc_tile_filepaths, config='phase1')))
    assert len(loc_tile_grids) == len(loc_outputs)

    # Aggregate
    detects = agg.compute(img_shape, loc_tile_grids, loc_outputs, config='phase1')

    assert len(detects) in [3, 4]

    # Reference detections, listed by descending 'c' (presumably the
    # confidence); x/y/w/h presumably describe the box in image pixels.
    targets = [
        {'l': 'elephant', 'c': 0.9299, 'x': 4597, 'y': 2322, 'w': 72, 'h': 149},
        {'l': 'elephant', 'c': 0.8739, 'x': 4865, 'y': 2422, 'w': 97, 'h': 109},
        {'l': 'elephant', 'c': 0.7115, 'x': 4806, 'y': 2476, 'w': 66, 'h': 119},
        {'l': 'elephant', 'c': 0.5236, 'x': 3511, 'y': 1228, 'w': 47, 'h': 78},
    ]

    # Strict absolute tolerances for the numeric comparisons below.
    conf_tol = 1e-2
    coord_tol = 3

    # Detections are compared positionally, so this relies on agg.compute
    # returning them in the same order as `targets`. zip() stops at the
    # shorter sequence: with only 3 detections the 4th reference is skipped.
    for index, (output, target) in enumerate(zip(detects, targets)):
        for key, expected in target.items():
            actual = output.get(key)
            # Fail with a clear message instead of a TypeError on None below.
            assert actual is not None, f'detection {index}: missing key {key!r}'
            if key == 'l':
                assert actual == expected, (
                    f'detection {index}: {key}={actual!r}, expected {expected!r}'
                )
            else:
                tol = conf_tol if key == 'c' else coord_tol
                assert abs(actual - expected) < tol, (
                    f'detection {index}: {key}={actual!r}, '
                    f'expected {expected!r} +/- {tol}'
                )


def test_agg_compute_mvp():
    """End-to-end regression test of the 'mvp' pipeline on one example image.

    Tiles the image, scores every tile with WIC, keeps the tiles whose
    'positive' score reaches ``wic.CONFIGS['mvp']['thresh']``, runs the
    localizer on those tiles, aggregates the per-tile localizer outputs into
    image-level detections, and compares them against reference values.

    The expected counts and reference detections are regression values for
    this specific example image and the 'mvp' configuration; they will need
    updating if the models or configuration change.
    """
    # Resolved against the current working directory, so the test must be run
    # from the directory that contains 'examples/'.
    img_filepath = abspath(join('examples', '1be4d40a-6fd0-42ce-da6c-294e45781f41.jpg'))

    # Run tiling
    img_shape, tile_grids, tile_filepaths = tile.compute(img_filepath)
    assert len(tile_filepaths) == 1252

    # Run WIC
    wic_outputs = wic.post(wic.predict(wic.pre(tile_filepaths, config='mvp')))
    assert len(wic_outputs) == len(tile_filepaths)

    # Threshold for WIC: only tiles at or above the configured threshold are
    # passed on to the localizer.
    flags = [
        wic_output.get('positive') >= wic.CONFIGS['mvp']['thresh']
        for wic_output in wic_outputs
    ]
    loc_tile_grids = ut.compress(tile_grids, flags)
    loc_tile_filepaths = ut.compress(tile_filepaths, flags)
    # NOTE(review): two counts are accepted here; the reason (presumably
    # small numeric differences between environments) is not visible here.
    assert sum(flags) in [123, 125]

    # Run localizer
    loc_outputs = loc.post(loc.predict(loc.pre(loc_tile_filepaths, config='mvp')))
    assert len(loc_tile_grids) == len(loc_outputs)

    # Aggregate
    detects = agg.compute(img_shape, loc_tile_grids, loc_outputs, config='mvp')

    assert len(detects) in [7, 8]

    # Reference detections, listed by descending 'c' (presumably the
    # confidence); x/y/w/h presumably describe the box in image pixels.
    # fmt: off
    targets = [
        {'l': 'elephant', 'c': 0.6795, 'x': 4593, 'y': 2300, 'w': 78, 'h': 201},
        {'l': 'elephant', 'c': 0.6126, 'x': 4813, 'y': 2452, 'w': 54, 'h': 87},
        {'l': 'kob',      'c': 0.6058, 'x': 3391, 'y': 1076, 'w': 33, 'h': 32},
        {'l': 'elephant', 'c': 0.5933, 'x': 4873, 'y': 2428, 'w': 80, 'h': 99},
        {'l': 'kob',      'c': 0.4767, 'x': 1601, 'y': 1729, 'w': 53, 'h': 55},
        {'l': 'warthog',  'c': 0.4571, 'x': 4199, 'y': 2109, 'w': 31, 'h': 45},
        {'l': 'kob',      'c': 0.4193, 'x': 1441, 'y': 3377, 'w': 30, 'h': 38},
        {'l': 'elephant', 'c': 0.4178, 'x': 3891, 'y': 3641, 'w': 60, 'h': 84},
    ]
    # fmt: on

    # Strict absolute tolerances for the numeric comparisons below.
    conf_tol = 1e-2
    coord_tol = 3

    # Detections are compared positionally, so this relies on agg.compute
    # returning them in the same order as `targets`. zip() stops at the
    # shorter sequence: with only 7 detections the 8th reference is skipped.
    for index, (output, target) in enumerate(zip(detects, targets)):
        for key, expected in target.items():
            actual = output.get(key)
            # Fail with a clear message instead of a TypeError on None below.
            assert actual is not None, f'detection {index}: missing key {key!r}'
            if key == 'l':
                assert actual == expected, (
                    f'detection {index}: {key}={actual!r}, expected {expected!r}'
                )
            else:
                tol = conf_tol if key == 'c' else coord_tol
                assert abs(actual - expected) < tol, (
                    f'detection {index}: {key}={actual!r}, '
                    f'expected {expected!r} +/- {tol}'
                )