File size: 3,201 Bytes
346f830
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import sys

import numpy as np
import pytest
from PIL import Image

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from model.drawing_classifier import DrawingClassifier  # pylint: disable=wrong-import-position


@pytest.fixture
def sample_image():
    """Provide a uniform mid-grey (value 128) 28x28 single-channel PIL image in 'L' mode."""
    width = height = 28
    return Image.new(mode='L', size=(width, height), color=128)


@pytest.fixture
def sample_numpy_image():
    """Provide a reproducible 28x28 uint8 array of random pixel values in 0..255.

    A fixed-seed generator keeps the input identical across runs, so any
    failure it triggers can be reproduced.
    """
    rng = np.random.default_rng(0)
    # The upper bound is exclusive, so 256 is needed for 255 (pure white) to be possible.
    return rng.integers(0, 256, size=(28, 28), dtype=np.uint8)


def test_drawing_classifier_init():
    """A freshly constructed classifier has no model and a non-empty list of classes.

    Skipped when the constructor raises FileNotFoundError (presumably a
    missing classes.json, per the skip message).
    """
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        # Guard only construction; a FileNotFoundError anywhere else is a real failure.
        pytest.skip("classes.json not found")
    assert classifier.model is None
    assert isinstance(classifier.classes, list)
    assert len(classifier.classes) > 0


def test_preprocess_image_pil(sample_image):  # pylint: disable=redefined-outer-name
    """preprocess_image turns a PIL image into a (1, 28, 28, 1) float32 batch scaled to [0, 1].

    Skipped when the constructor raises FileNotFoundError (presumably a
    missing classes.json, per the skip message).
    """
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        # Guard only construction so errors from preprocess_image are not masked as skips.
        pytest.skip("classes.json not found")
    processed = classifier.preprocess_image(sample_image)
    assert processed.shape == (1, 28, 28, 1)
    assert processed.dtype == np.float32
    assert np.all(processed >= 0) and np.all(processed <= 1)


def test_preprocess_image_numpy(sample_numpy_image):  # pylint: disable=redefined-outer-name
    """preprocess_image turns a 2-D uint8 array into a (1, 28, 28, 1) float32 batch in [0, 1].

    Skipped when the constructor raises FileNotFoundError (presumably a
    missing classes.json, per the skip message).
    """
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        # Guard only construction so errors from preprocess_image are not masked as skips.
        pytest.skip("classes.json not found")
    processed = classifier.preprocess_image(sample_numpy_image)
    assert processed.shape == (1, 28, 28, 1)
    assert processed.dtype == np.float32
    assert np.all(processed >= 0) and np.all(processed <= 1)


def test_create_simple_model():
    """create_simple_model builds a non-empty model mapping (28, 28, 1) inputs to one score per class.

    Skipped when the constructor raises FileNotFoundError (presumably a
    missing classes.json, per the skip message).
    """
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        # Guard only construction so errors from model building are not masked as skips.
        pytest.skip("classes.json not found")
    model = classifier.create_simple_model()
    assert model is not None
    assert len(model.layers) > 0
    # The leading None is the unconstrained batch dimension.
    assert model.input_shape == (None, 28, 28, 1)
    assert model.output_shape == (None, len(classifier.classes))


def test_predict_without_model(sample_image):  # pylint: disable=redefined-outer-name
    """predict with no model returns a list whose first entry is an 'error' result at 0.0 confidence.

    Skipped when the constructor raises FileNotFoundError (presumably a
    missing classes.json, per the skip message).
    """
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        # Guard only construction so errors raised by predict are not masked as skips.
        pytest.skip("classes.json not found")
    classifier.model = None
    predictions = classifier.predict(sample_image)
    assert isinstance(predictions, list)
    assert len(predictions) > 0
    assert predictions[0]['class'] == 'error'
    assert predictions[0]['confidence'] == 0.0


def test_predict_with_simple_model(sample_image):  # pylint: disable=redefined-outer-name
    """predict with an untrained simple model returns at most three well-formed predictions.

    Each prediction must carry a 'class' key and a numeric 'confidence' in
    [0, 100]. Skipped when the constructor raises FileNotFoundError
    (presumably a missing classes.json, per the skip message).
    """
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        # Guard only construction so errors from model building or predict are not masked as skips.
        pytest.skip("classes.json not found")
    classifier.model = classifier.create_simple_model()
    predictions = classifier.predict(sample_image)
    assert isinstance(predictions, list)
    assert len(predictions) <= 3
    for pred in predictions:
        assert 'class' in pred
        assert 'confidence' in pred
        assert isinstance(pred['confidence'], (int, float))
        # Confidence is checked on a 0-100 scale, not 0-1.
        assert 0 <= pred['confidence'] <= 100


if __name__ == '__main__':
    # Forward pytest's exit code so running this file directly exits non-zero on failures.
    sys.exit(pytest.main([__file__]))