File size: 3,208 Bytes
799c02f 3072e1b 799c02f 26ab37f 68e54f6 799c02f 26ab37f 68e54f6 799c02f 26ab37f 799c02f 26ab37f 799c02f 26ab37f d00fd73 26ab37f 799c02f 26ab37f 799c02f 26ab37f d00fd73 26ab37f 799c02f 26ab37f d00fd73 799c02f 79b792e 26ab37f 799c02f 3072e1b d00fd73 acbe316 26ab37f 79b792e 26ab37f acbe316 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
# -*- coding: utf-8 -*-
from os.path import abspath, exists, join
import onnx
def test_wic_onnx_load_phase1():
    """Fetch the ``phase1`` WIC model and validate it as an ONNX graph.

    Checks that :func:`scoutbot.wic.fetch` returns a path that exists on
    disk, that the loaded model passes ``onnx.checker.check_model``, and that
    its printable graph has the expected number of lines.
    """
    from scoutbot.wic import fetch

    onnx_model = fetch(config='phase1')
    # Check existence *before* loading: if the file were missing, onnx.load
    # would raise first and this assertion could never be the one to fail.
    assert exists(onnx_model)

    model = onnx.load(onnx_model)
    onnx.checker.check_model(model)

    graph = onnx.helper.printable_graph(model.graph)
    # The line count serves as a cheap fingerprint of the graph; presumably it
    # only changes when the phase1 model file itself changes.
    assert graph.count('\n') == 1334
def test_wic_onnx_load_mvp():
    """Fetch the ``mvp`` WIC model and validate it as an ONNX graph.

    Checks that :func:`scoutbot.wic.fetch` returns a path that exists on
    disk, that the loaded model passes ``onnx.checker.check_model``, and that
    its printable graph has the expected number of lines.
    """
    from scoutbot.wic import fetch

    onnx_model = fetch(config='mvp')
    # Check existence *before* loading: if the file were missing, onnx.load
    # would raise first and this assertion could never be the one to fail.
    assert exists(onnx_model)

    model = onnx.load(onnx_model)
    onnx.checker.check_model(model)

    graph = onnx.helper.printable_graph(model.graph)
    # The line count serves as a cheap fingerprint of the graph; presumably it
    # only changes when the mvp model file itself changes.
    assert graph.count('\n') == 237
def test_wic_onnx_pipeline_phase1():
    """Run one example image through the ``phase1`` WIC pipeline, stage by stage.

    Each stage gets a fresh ``pre`` generator:

    1. ``pre`` alone yields a ``(1, 3, INPUT_SIZE, INPUT_SIZE)`` batch tagged
       with the ``'phase1'`` config name.
    2. ``pre`` -> ``predict`` yields a ``(1, 2)`` score array whose second
       column dominates and matches the pinned reference scores.
    3. ``pre`` -> ``predict`` -> ``post`` returns a single dict keyed by the
       configured class names, whose values match the same reference scores
       and are Python floats.
    """
    from scoutbot.wic import CONFIGS, INPUT_SIZE, post, pre, predict

    cfg_name = 'phase1'
    tolerance = 1e-4
    # Reference scores for the example image, as (negative, positive).
    want_negative, want_positive = 0.00001503, 0.99998497

    image_path = abspath(
        join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg')
    )
    images = [image_path]
    assert exists(image_path)

    # Stage 1: preprocessing only.
    batch, batch_config = next(pre(images, config=cfg_name))
    assert batch.shape == (1, 3, INPUT_SIZE, INPUT_SIZE)
    assert batch_config == cfg_name

    # Stage 2: preprocessing followed by inference.
    scores, score_config = next(predict(pre(images, config=cfg_name)))
    assert scores.shape == (1, 2)
    negative_score, positive_score = scores[0][0], scores[0][1]
    assert positive_score > negative_score
    assert abs(negative_score - want_negative) < tolerance
    assert abs(positive_score - want_positive) < tolerance
    assert score_config == cfg_name

    # Stage 3: the full pipeline, including postprocessing.
    results = post(predict(pre(images, config=cfg_name)))
    assert len(results) == 1
    result = results[0]
    assert result.keys() == set(CONFIGS[cfg_name]['classes'])
    assert result['positive'] > result['negative']
    assert abs(result['negative'] - want_negative) < tolerance
    assert abs(result['positive'] - want_positive) < tolerance
    for label in ('negative', 'positive'):
        assert isinstance(result[label], float)
def test_wic_onnx_pipeline_mvp():
    """Run one example image through the ``mvp`` WIC pipeline, stage by stage.

    Each stage gets a fresh ``pre`` generator:

    1. ``pre`` alone yields a ``(1, 3, INPUT_SIZE, INPUT_SIZE)`` batch tagged
       with the ``'mvp'`` config name.
    2. ``pre`` -> ``predict`` yields a ``(1, 2)`` score array whose second
       column dominates and matches the pinned reference scores.
    3. ``pre`` -> ``predict`` -> ``post`` returns a single dict keyed by the
       configured class names, whose values match the same reference scores
       and are Python floats.
    """
    from scoutbot.wic import CONFIGS, INPUT_SIZE, post, pre, predict

    cfg_name = 'mvp'
    tolerance = 1e-4
    # Reference scores for the example image, as (negative, positive).
    want_negative, want_positive = 0.0, 1.0

    image_path = abspath(
        join('examples', '1e8372e4-357d-26e6-d7fd-0e0ae402463a.true.jpg')
    )
    images = [image_path]
    assert exists(image_path)

    # Stage 1: preprocessing only.
    batch, batch_config = next(pre(images, config=cfg_name))
    assert batch.shape == (1, 3, INPUT_SIZE, INPUT_SIZE)
    assert batch_config == cfg_name

    # Stage 2: preprocessing followed by inference.
    scores, score_config = next(predict(pre(images, config=cfg_name)))
    assert scores.shape == (1, 2)
    negative_score, positive_score = scores[0][0], scores[0][1]
    assert positive_score > negative_score
    assert abs(negative_score - want_negative) < tolerance
    assert abs(positive_score - want_positive) < tolerance
    assert score_config == cfg_name

    # Stage 3: the full pipeline, including postprocessing.
    results = post(predict(pre(images, config=cfg_name)))
    assert len(results) == 1
    result = results[0]
    assert result.keys() == set(CONFIGS[cfg_name]['classes'])
    assert result['positive'] > result['negative']
    assert abs(result['negative'] - want_negative) < tolerance
    assert abs(result['positive'] - want_positive) < tolerance
    for label in ('negative', 'positive'):
        assert isinstance(result[label], float)
def test_wic_onnx_pipeline_empty():
    """Feed an empty input list through ``pre`` -> ``predict`` -> ``post``.

    The chained stages must not raise, and the final result must be empty.
    ``pre`` is called without an explicit ``config``, so its default is used.
    """
    from scoutbot.wic import post, pre, predict

    outputs = post(predict(pre([])))
    assert len(outputs) == 0
|