doodleai / tests /test_app.py
alanoee's picture
Upload folder using huggingface_hub
346f830 verified
# os.environ must be set before importing app to override AI_API_KEY at module load time
# pylint: disable=wrong-import-position,wrong-import-order
import os
# Deliberately placed ahead of the remaining imports: it has to run before
# `from app import app` at the bottom of this preamble.
os.environ['AI_API_KEY'] = 'test-key'
import base64
import json
import sys
from io import BytesIO
import pytest
from PIL import Image
# Prepend the directory two levels above this file's absolute path (the parent
# of this file's own directory) to sys.path; presumably that is where the
# `app` module lives, so the import below can resolve it.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app import app
@pytest.fixture
def client():
    """Yield a test client for ``app`` configured for this test module.

    Sets ``TESTING`` to True and ``RATELIMIT_ENABLED`` to False in
    ``app.config``, then opens ``app.test_client()`` and adds
    ``HTTP_X_API_KEY = 'test-key'`` to the client's base WSGI environ, so the
    key accompanies every request made through the yielded client.
    """
    app.config.update(TESTING=True, RATELIMIT_ENABLED=False)
    with app.test_client() as http_client:
        # Same literal that this module writes to AI_API_KEY before importing app.
        http_client.environ_base['HTTP_X_API_KEY'] = 'test-key'
        yield http_client
@pytest.fixture
def sample_image_base64():
    """Return a ``data:image/png;base64,...`` URL for a small synthetic image.

    The image is 28x28, mode ``'L'``, filled with 0. Every pixel whose squared
    distance from (14, 14) is strictly below 64 is set to 255, which draws a
    filled disc in the middle of the canvas.
    """
    side = 28
    centre = 14
    radius_squared = 64
    canvas = Image.new('L', (side, side), 0)
    # Lazily enumerate the coordinates inside the disc, x-major like a nested loop.
    disc_points = (
        (col, row)
        for col in range(side)
        for row in range(side)
        if (col - centre) ** 2 + (row - centre) ** 2 < radius_squared
    )
    for point in disc_points:
        canvas.putpixel(point, 255)
    png_stream = BytesIO()
    canvas.save(png_stream, format='PNG')
    encoded = base64.b64encode(png_stream.getvalue()).decode()
    return "data:image/png;base64," + encoded
def test_index_endpoint(client):
    """GET / must answer 200 with a JSON body whose ``name`` is the API title."""
    expected_name = 'AI Drawing Classifier API'
    resp = client.get('/')
    assert resp.status_code == 200
    body = json.loads(resp.data)
    assert 'name' in body
    assert body['name'] == expected_name
def test_health_endpoint(client):
    """GET /health must answer 200 with a JSON body reporting ``status: healthy``."""
    resp = client.get('/health')
    assert resp.status_code == 200
    body = json.loads(resp.data)
    assert 'status' in body
    assert body['status'] == 'healthy'
def test_classes_endpoint(client):
    """GET /classes must answer 200 with a class list, or 503 with an error.

    On 200 the JSON body must contain ``classes`` (a list) and
    ``total_classes``; on 503 it must contain ``error``.
    """
    resp = client.get('/classes')
    assert resp.status_code in (200, 503)
    body = json.loads(resp.data)
    if resp.status_code != 200:
        # 503 branch: only an error field is required.
        assert 'error' in body
        return
    assert 'classes' in body
    assert 'total_classes' in body
    assert isinstance(body['classes'], list)
def test_get_random_word_endpoint(client):
    """GET /get_random_word must answer 200 with a string ``word``, or 503 with an error."""
    resp = client.get('/get_random_word')
    assert resp.status_code in (200, 503)
    body = json.loads(resp.data)
    if resp.status_code != 200:
        # 503 branch: only an error field is required.
        assert 'error' in body
        return
    assert 'word' in body
    assert isinstance(body['word'], str)
def test_predict_missing_data(client):
    """POST /predict with an empty JSON object must be rejected with 400.

    The error message in the JSON body must mention 'Missing image data'.
    """
    empty_body = json.dumps({})
    resp = client.post('/predict', data=empty_body, content_type='application/json')
    assert resp.status_code == 400
    body = json.loads(resp.data)
    assert 'error' in body
    assert 'Missing image data' in body['error']
def test_predict_invalid_image_format(client):
    """POST /predict with a non-data-URL ``image`` string must be rejected with 400.

    The error message in the JSON body must mention 'Invalid image format'.
    """
    bad_body = json.dumps({'image': 'invalid_format'})
    resp = client.post('/predict', data=bad_body, content_type='application/json')
    assert resp.status_code == 400
    body = json.loads(resp.data)
    assert 'error' in body
    assert 'Invalid image format' in body['error']
def test_predict_valid_image(client, sample_image_base64):
    """POST the sample image to /predict and check both accepted outcomes.

    A 200 reply must carry ``success`` equal to True and a ``predictions``
    list; a 503 reply must carry an ``error`` mentioning 'Model not loaded'.
    """
    request_body = json.dumps({'image': sample_image_base64})
    resp = client.post('/predict', data=request_body, content_type='application/json')
    assert resp.status_code in (200, 503)
    body = json.loads(resp.data)
    if resp.status_code != 200:
        # 503 branch: the reply must carry the "model not loaded" error.
        assert 'error' in body
        assert 'Model not loaded' in body['error']
        return
    assert 'predictions' in body
    assert 'success' in body
    assert body['success'] is True
    assert isinstance(body['predictions'], list)
def test_rate_limiting_predict(client, sample_image_base64):
    """POST the sample image to /predict three times in a row.

    Each reply must have status 200, 429 or 503.

    NOTE(review): the ``client`` fixture sets RATELIMIT_ENABLED to False and
    200 is accepted on every attempt, so this test passes without the limiter
    ever triggering — confirm whether a stricter check is intended.
    """
    request_body = json.dumps({'image': sample_image_base64})
    for _attempt in range(3):
        reply = client.post('/predict', data=request_body, content_type='application/json')
        assert reply.status_code in (200, 429, 503)
if __name__ == '__main__':
    # pytest.main() returns the exit code of the run. Pass it on as the process
    # exit status so that executing this file directly reports test failures
    # to the shell instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__]))