from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from app.auth.jwt import get_current_user
from fastapi import UploadFile
from httpx import ASGITransport, AsyncClient
from main import app


@pytest.mark.asyncio
async def test_predict():
    """A valid PNG upload returns the model's prediction and the file hash.

    The file hash, the model call and the filesystem helpers used by
    ``app.model.router`` are patched out, so the request neither touches the
    real model nor writes to disk.
    """
    image_bytes = b"fake-image-data"

    mock_user = MagicMock()
    mock_user.id = 1
    # Skip JWT validation: the route receives this stand-in user instead.
    app.dependency_overrides[get_current_user] = lambda: mock_user

    try:
        with patch(
            "app.model.router.utils.get_file_hash", return_value="fakehash123"
        ), patch(
            "app.model.router.model_predict", new_callable=AsyncMock
        ) as mock_model_predict, patch(
            "app.model.router.os.path.exists", return_value=False
        ), patch(
            # NOTE(review): patches open() process-wide for the duration of
            # the request; narrow to the router module if that proves too broad.
            "builtins.open",
            new_callable=MagicMock,
        ):
            mock_model_predict.return_value = ("cat", 0.95)
            # AsyncClient(app=...) was removed in httpx 0.28; route the
            # client through an explicit ASGI transport instead.
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as ac:
                response = await ac.post(
                    "/model/predict",
                    files={"file": ("test_image.png", image_bytes, "image/png")},
                    headers={"Authorization": "Bearer testtoken"},
                )
    finally:
        # dependency_overrides lives on the shared app object; remove ours so
        # it cannot leak into tests that run afterwards.
        app.dependency_overrides.pop(get_current_user, None)

    assert response.status_code == 200
    mock_model_predict.assert_awaited_once()

    response_data = response.json()
    assert response_data["success"] is True
    assert response_data["prediction"] == "cat"
    assert response_data["score"] == 0.95
    assert response_data["image_file_name"] == "fakehash123"


@pytest.mark.asyncio
async def test_predict_fails_bad_extension():
    """An upload with an unsupported extension is rejected with HTTP 400.

    The same collaborators as in the happy-path test are patched so that a
    misbehaving route could not reach the real model or disk. The test also
    checks that the model is never invoked for a rejected file.
    """
    image_bytes = b"fake-image-data"

    mock_user = MagicMock()
    mock_user.id = 1
    # Skip JWT validation: the route receives this stand-in user instead.
    app.dependency_overrides[get_current_user] = lambda: mock_user

    try:
        with patch(
            "app.model.router.utils.get_file_hash", return_value="fakehash123"
        ), patch(
            "app.model.router.model_predict", new_callable=AsyncMock
        ) as mock_model_predict, patch(
            "app.model.router.os.path.exists", return_value=False
        ), patch(
            "builtins.open",
            new_callable=MagicMock,
        ):
            mock_model_predict.return_value = ("cat", 0.95)
            # AsyncClient(app=...) was removed in httpx 0.28; route the
            # client through an explicit ASGI transport instead.
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as ac:
                response = await ac.post(
                    "/model/predict",
                    # Only the extension is wrong; the content type stays PNG so
                    # the test isolates filename-based validation.
                    files={"file": ("test_image.pdf", image_bytes, "image/png")},
                    headers={"Authorization": "Bearer testtoken"},
                )
    finally:
        # dependency_overrides lives on the shared app object; remove ours so
        # it cannot leak into tests that run afterwards.
        app.dependency_overrides.pop(get_current_user, None)

    assert response.status_code == 400
    assert response.json() == {"detail": "File type is not supported."}
    # Validation must short-circuit before any inference happens.
    mock_model_predict.assert_not_called()