File size: 3,208 Bytes
799c02f
3072e1b
 
799c02f
 
 
26ab37f
68e54f6
799c02f
26ab37f
68e54f6
 
799c02f
 
 
 
 
 
 
26ab37f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
799c02f
 
 
 
 
 
 
26ab37f
799c02f
26ab37f
d00fd73
26ab37f
799c02f
26ab37f
799c02f
 
26ab37f
d00fd73
 
 
 
26ab37f
799c02f
26ab37f
d00fd73
799c02f
 
 
 
79b792e
26ab37f
799c02f
3072e1b
 
d00fd73
 
acbe316
26ab37f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b792e
26ab37f
 
 
 
 
 
 
 
 
 
 
acbe316
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# -*- coding: utf-8 -*-
from os.path import abspath, exists, join

import onnx


def test_wic_onnx_load_phase1():
    """Fetch the 'phase1' WIC ONNX model, validate it, and pin its graph size."""
    from scoutbot.wic import fetch

    onnx_model = fetch(config='phase1')
    # Check the file exists *before* parsing it, so a missing/failed download
    # is reported by this assertion rather than by an opaque onnx.load error.
    assert exists(onnx_model)
    model = onnx.load(onnx_model)

    onnx.checker.check_model(model)

    # Pin the exact line count of the human-readable graph dump; any
    # structural change to the fetched model makes this assertion fail.
    graph = onnx.helper.printable_graph(model.graph)
    assert graph.count('\n') == 1334


def test_wic_onnx_load_mvp():
    """Fetch the 'mvp' WIC ONNX model, validate it, and pin its graph size."""
    from scoutbot.wic import fetch

    onnx_model = fetch(config='mvp')
    # Check the file exists *before* parsing it, so a missing/failed download
    # is reported by this assertion rather than by an opaque onnx.load error.
    assert exists(onnx_model)
    model = onnx.load(onnx_model)

    onnx.checker.check_model(model)

    # Pin the exact line count of the human-readable graph dump; any
    # structural change to the fetched model makes this assertion fail.
    graph = onnx.helper.printable_graph(model.graph)
    assert graph.count('\n') == 237


def test_wic_onnx_pipeline_phase1():
    """Exercise the 'phase1' WIC pipeline stage by stage on one example image.

    Each stage (pre, pre -> predict, pre -> predict -> post) is driven from a
    fresh ``pre`` generator and checked for output shape, config tag and the
    expected negative/positive scores.
    """
    from scoutbot.wic import CONFIGS, INPUT_SIZE, post, pre, predict

    cfg_name = 'phase1'
    want_neg, want_pos = 0.00001503, 0.99998497
    tol = 1e-4

    image = abspath(join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg'))
    images = [image]
    assert exists(image)

    # Stage 1: preprocessing alone yields (batch, config) pairs.
    batch, tag = next(pre(images, config=cfg_name))
    assert batch.shape == (1, 3, INPUT_SIZE, INPUT_SIZE)
    assert tag == cfg_name

    # Stage 2: raw model scores, one row per image, columns (negative, positive).
    scores, tag = next(predict(pre(images, config=cfg_name)))
    assert scores.shape == (1, 2)
    neg_score, pos_score = scores[0][0], scores[0][1]
    assert pos_score > neg_score
    assert abs(neg_score - want_neg) < tol
    assert abs(pos_score - want_pos) < tol
    assert tag == cfg_name

    # Stage 3: post-processing maps scores onto class-name keyed dicts.
    results = post(predict(pre(images, config=cfg_name)))
    assert len(results) == 1
    result = results[0]
    assert result.keys() == set(CONFIGS[cfg_name]['classes'])
    assert result['positive'] > result['negative']
    assert abs(result['negative'] - want_neg) < tol
    assert abs(result['positive'] - want_pos) < tol
    for label in ('negative', 'positive'):
        assert isinstance(result[label], float)


def test_wic_onnx_pipeline_mvp():
    """Exercise the 'mvp' WIC pipeline stage by stage on one example image.

    Each stage (pre, pre -> predict, pre -> predict -> post) is driven from a
    fresh ``pre`` generator and checked for output shape, config tag and the
    expected negative/positive scores.
    """
    from scoutbot.wic import CONFIGS, INPUT_SIZE, post, pre, predict

    cfg_name = 'mvp'
    want_neg, want_pos = 0.0, 1.0
    tol = 1e-4

    image = abspath(join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg'))
    images = [image]
    assert exists(image)

    # Stage 1: preprocessing alone yields (batch, config) pairs.
    batch, tag = next(pre(images, config=cfg_name))
    assert batch.shape == (1, 3, INPUT_SIZE, INPUT_SIZE)
    assert tag == cfg_name

    # Stage 2: raw model scores, one row per image, columns (negative, positive).
    scores, tag = next(predict(pre(images, config=cfg_name)))
    assert scores.shape == (1, 2)
    neg_score, pos_score = scores[0][0], scores[0][1]
    assert pos_score > neg_score
    assert abs(neg_score - want_neg) < tol
    assert abs(pos_score - want_pos) < tol
    assert tag == cfg_name

    # Stage 3: post-processing maps scores onto class-name keyed dicts.
    results = post(predict(pre(images, config=cfg_name)))
    assert len(results) == 1
    result = results[0]
    assert result.keys() == set(CONFIGS[cfg_name]['classes'])
    assert result['positive'] > result['negative']
    assert abs(result['negative'] - want_neg) < tol
    assert abs(result['positive'] - want_pos) < tol
    for label in ('negative', 'positive'):
        assert isinstance(result[label], float)


def test_wic_onnx_pipeline_empty():
    """An empty input list flows through pre -> predict -> post and yields nothing."""
    from scoutbot.wic import post, pre, predict

    # pre() is called with no config argument, so it uses its own default.
    outputs = post(predict(pre([])))
    assert len(outputs) == 0