# doodleai / tests/test_model.py
# Uploaded by alanoee via huggingface_hub ("Upload folder using huggingface_hub").
# Commit: 346f830 (verified)
import os
import sys
import numpy as np
import pytest
from PIL import Image
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model.drawing_classifier import DrawingClassifier # pylint: disable=wrong-import-position
@pytest.fixture
def sample_image():
    """Provide a uniform 28x28 grayscale ("L" mode) PIL image filled with value 128."""
    side = 28
    fill_value = 128
    return Image.new(mode='L', size=(side, side), color=fill_value)
@pytest.fixture
def sample_numpy_image():
    """Provide a reproducible 28x28 ``uint8`` image as a NumPy array.

    Pixel values are drawn from a fixed-seed generator so any failure can be
    reproduced run-to-run. The upper bound passed to ``integers`` is
    exclusive, so 256 is used to make the full 0..255 range possible. The two
    extremes are also pinned explicitly so the normalisation bounds checks in
    the preprocessing tests always see both 0 and 255.
    """
    rng = np.random.default_rng(0)
    image = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)
    # Guarantee the boundary pixel values are present in every run.
    image[0, 0] = 0
    image[-1, -1] = 255
    return image
def test_drawing_classifier_init():
    """A freshly constructed classifier has no model and a non-empty class list.

    The test is skipped when ``DrawingClassifier()`` raises
    ``FileNotFoundError`` (reported as a missing ``classes.json``).
    """
    # Guard only construction: a FileNotFoundError raised anywhere else
    # should fail the test rather than be hidden as a skip.
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        pytest.skip("classes.json not found")

    assert classifier.model is None
    assert isinstance(classifier.classes, list)
    assert len(classifier.classes) > 0
def test_preprocess_image_pil(sample_image):  # pylint: disable=redefined-outer-name
    """``preprocess_image`` turns a PIL image into a (1, 28, 28, 1) float32 array in [0, 1].

    Skipped when ``DrawingClassifier()`` raises ``FileNotFoundError``.
    """
    # Only construction may skip; errors from preprocess_image must surface.
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        pytest.skip("classes.json not found")

    processed = classifier.preprocess_image(sample_image)
    assert processed.shape == (1, 28, 28, 1)
    assert processed.dtype == np.float32
    # Separate asserts so a failure pinpoints which bound was violated.
    assert np.all(processed >= 0)
    assert np.all(processed <= 1)
def test_preprocess_image_numpy(sample_numpy_image):  # pylint: disable=redefined-outer-name
    """``preprocess_image`` turns a NumPy image into a (1, 28, 28, 1) float32 array in [0, 1].

    Skipped when ``DrawingClassifier()`` raises ``FileNotFoundError``.
    """
    # Only construction may skip; errors from preprocess_image must surface.
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        pytest.skip("classes.json not found")

    processed = classifier.preprocess_image(sample_numpy_image)
    assert processed.shape == (1, 28, 28, 1)
    assert processed.dtype == np.float32
    # Separate asserts so a failure pinpoints which bound was violated.
    assert np.all(processed >= 0)
    assert np.all(processed <= 1)
def test_create_simple_model():
    """``create_simple_model`` builds a model mapping 28x28x1 inputs to one output per class.

    Skipped when ``DrawingClassifier()`` raises ``FileNotFoundError``.
    """
    # Only construction may skip; errors while building the model must surface.
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        pytest.skip("classes.json not found")

    model = classifier.create_simple_model()
    assert model is not None
    assert len(model.layers) > 0
    # Leading None is the (unspecified) batch dimension.
    assert model.input_shape == (None, 28, 28, 1)
    assert model.output_shape == (None, len(classifier.classes))
def test_predict_without_model(sample_image):  # pylint: disable=redefined-outer-name
    """With no model set, ``predict`` returns a single-style error entry instead of raising.

    The first prediction must be ``{'class': 'error', 'confidence': 0.0}``.
    Skipped when ``DrawingClassifier()`` raises ``FileNotFoundError``.
    """
    # Only construction may skip; a FileNotFoundError from predict must fail.
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        pytest.skip("classes.json not found")

    classifier.model = None  # explicit, so the test does not rely on __init__ defaults
    predictions = classifier.predict(sample_image)
    assert isinstance(predictions, list)
    assert len(predictions) > 0
    assert predictions[0]['class'] == 'error'
    assert predictions[0]['confidence'] == 0.0
def test_predict_with_simple_model(sample_image):  # pylint: disable=redefined-outer-name
    """With a freshly built model, ``predict`` returns 1-3 well-formed predictions.

    Each prediction must carry a ``'class'`` and a numeric ``'confidence'`` in
    [0, 100]. The ``'error'`` sentinel (seen in the no-model test) is rejected,
    since it would otherwise satisfy every shape check with confidence 0.0.
    Skipped when ``DrawingClassifier()`` raises ``FileNotFoundError``.
    """
    # Only construction may skip; failures in model building or prediction
    # must surface as test failures.
    try:
        classifier = DrawingClassifier()
    except FileNotFoundError:
        pytest.skip("classes.json not found")

    classifier.model = classifier.create_simple_model()
    predictions = classifier.predict(sample_image)
    assert isinstance(predictions, list)
    # An empty result would make the per-item checks below vacuous.
    assert 0 < len(predictions) <= 3
    for pred in predictions:
        assert 'class' in pred
        assert 'confidence' in pred
        assert pred['class'] != 'error'
        assert isinstance(pred['confidence'], (int, float))
        assert 0 <= pred['confidence'] <= 100
if __name__ == '__main__':
    # Propagate pytest's exit code so running this file directly reports
    # failures to the shell/CI instead of always exiting with status 0.
    sys.exit(pytest.main([__file__]))