| import pytest
|
| from fastapi.testclient import TestClient
|
| from unittest.mock import Mock, patch, AsyncMock
|
| from io import BytesIO
|
| from constants import (
|
| STATUS_CODE_BAD_REQUEST,
|
| STATUS_CODE_CONTENT_TOO_LARGE,
|
| STATUS_CODE_EXCEED_SIZE_LIMIT,
|
| STATUS_CODE_INTERNAL_SERVER_ERROR,
|
| STATUS_CODE_LENGTH_REQUIRED,
|
| STATUS_CODE_UNSUPPORTED_MEDIA_TYPE,
|
| )
|
| from exceptions import (
|
| FileExtractionError,
|
| FileExtractionException,
|
| FileValidationError,
|
| FileValidationException,
|
| )
|
| from main import app
|
|
|
# One shared client for the whole module; every test below sends its request
# through it.
client = TestClient(app)


# (validation error, HTTP status the endpoint is expected to answer with).
# Consumed by ``pytest.mark.parametrize`` in ``test_file_validation_failure``.
FILE_VALIDATION_TEST_CASES = [
    # Size problems
    (FileValidationError.MISSING_SIZE, STATUS_CODE_LENGTH_REQUIRED),
    (FileValidationError.FILE_TOO_LARGE, STATUS_CODE_CONTENT_TOO_LARGE),
    # File-name problems
    (FileValidationError.MISSING_FILE_NAME, STATUS_CODE_BAD_REQUEST),
    (FileValidationError.FILE_NAME_TOO_LARGE, STATUS_CODE_BAD_REQUEST),
    (FileValidationError.INVALID_FILE_NAME, STATUS_CODE_BAD_REQUEST),
    # Type problems
    (FileValidationError.INVALID_MIME_TYPE, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    (FileValidationError.UNSUPPORTED_EXTENSION, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    # Content problems
    (FileValidationError.EMPTY_FILE, STATUS_CODE_BAD_REQUEST),
]


# (extraction error, HTTP status the endpoint is expected to answer with).
# Consumed by ``pytest.mark.parametrize`` in ``test_file_extraction_failure``.
FILE_EXTRACTION_TEST_CASES = [
    (FileExtractionError.INVALID_MIME_TYPE, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    (FileExtractionError.NO_TEXT, STATUS_CODE_BAD_REQUEST),
    (FileExtractionError.TEXT_EXTRACTION_TIMEOUT, STATUS_CODE_INTERNAL_SERVER_ERROR),
    (FileExtractionError.UNSAFE_ZIP, STATUS_CODE_INTERNAL_SERVER_ERROR),
    (FileExtractionError.FILE_TOO_LARGE, STATUS_CODE_CONTENT_TOO_LARGE),
    (FileExtractionError.MALFORMED_FILE, STATUS_CODE_BAD_REQUEST),
]
|
|
|
|
|
class TestUploadFileEndpoint:
    """Test the PUT /file endpoint

    All requests are sent through the module-level ``client``. Tests that
    need the request to get past form validation take the
    ``mock_dependencies`` fixture, which patches the names the endpoint
    looks up in ``main``.
    """

    @pytest.fixture
    def valid_session_id(self):
        """Session id that the endpoint is expected to accept."""
        return "test-session-123"

    @pytest.fixture
    def valid_file(self):
        """Create a valid file upload

        Returns a ``(filename, file object, content type)`` tuple suitable
        for the ``files=`` argument of ``client.put``.
        """
        return ("test.txt", BytesIO(b"Test file content"), "text/plain")

    @pytest.fixture
    def mock_dependencies(self):
        """Mock all external dependencies

        Patches ``validate_file``, ``extract_text_from_file``, ``PIIFilter``,
        ``session_document_store`` and ``session_tracker`` in ``main`` and
        configures them for a successful upload. Yields the mocks in a dict
        keyed by ``"validate"``, ``"extract"``, ``"pii"``, ``"store"`` and
        ``"tracker"`` so individual tests can override return values or
        side effects.
        """
        with (
            patch("main.validate_file") as mock_validate,
            patch("main.extract_text_from_file") as mock_extract,
            patch("main.PIIFilter") as mock_pii,
            patch("main.session_document_store") as mock_store,
            patch("main.session_tracker") as mock_tracker,
        ):
            # Happy-path defaults; tests override what they need.
            mock_validate.return_value = AsyncMock(
                content=b"content", filename="test.txt", mime_type="text/plain"
            )
            mock_extract.return_value = "extracted text content"
            mock_pii.return_value.sanitize.return_value = "sanitized text"
            mock_store.create_document.return_value = True

            # Yield inside the ``with`` so the patches stay active for the
            # duration of the test.
            yield {
                "validate": mock_validate,
                "extract": mock_extract,
                "pii": mock_pii,
                "store": mock_store,
                "tracker": mock_tracker,
            }

    def test_successful_file_upload(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test successful file upload with all valid inputs"""
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )

        assert response.status_code == 200

        # Every stage of the pipeline must have been reached.
        assert mock_dependencies["validate"].called
        assert mock_dependencies["extract"].called
        assert mock_dependencies["pii"].return_value.sanitize.called
        assert mock_dependencies["store"].create_document.called
        assert mock_dependencies["tracker"].update_session.called

    @pytest.mark.parametrize(
        "validation_error,expected_status", FILE_VALIDATION_TEST_CASES
    )
    def test_file_validation_failure(
        self,
        valid_session_id,
        valid_file,
        mock_dependencies,
        validation_error,
        expected_status,
    ):
        """Test that validation errors return correct status codes"""
        mock_dependencies["validate"].side_effect = FileValidationException(
            validation_error
        )

        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )

        assert response.status_code == expected_status

        # A rejected file must not be extracted or stored.
        assert not mock_dependencies["extract"].called
        assert not mock_dependencies["store"].create_document.called

    @pytest.mark.parametrize(
        "extraction_error,expected_status", FILE_EXTRACTION_TEST_CASES
    )
    def test_file_extraction_failure(
        self,
        valid_session_id,
        valid_file,
        mock_dependencies,
        extraction_error,
        expected_status,
    ):
        """Test that extraction errors return correct status codes"""
        mock_dependencies["extract"].side_effect = FileExtractionException(
            extraction_error
        )

        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )

        assert response.status_code == expected_status
        # Nothing may be stored when extraction fails.
        assert not mock_dependencies["store"].create_document.called

    def test_unexpected_extraction_exception(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test that unexpected exceptions return 500"""
        mock_dependencies["extract"].side_effect = Exception("Unexpected error")

        response = client.put(
            "/file",
            files={"file": valid_file},
            data={"session_id": valid_session_id},
        )

        assert response.status_code == STATUS_CODE_INTERNAL_SERVER_ERROR

    def test_exceed_size_limit(self, valid_session_id, valid_file, mock_dependencies):
        """Test that size limit exceeded returns correct status"""
        # The store signals the exceeded limit by returning False.
        mock_dependencies["store"].create_document.return_value = False

        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )

        assert response.status_code == STATUS_CODE_EXCEED_SIZE_LIMIT

        # The session must not be touched when the document was not stored.
        assert not mock_dependencies["tracker"].update_session.called

    def test_invalid_session_id_format(self, valid_file):
        """Test that invalid session_id format is rejected"""
        invalid_session_ids = [
            "",  # empty
            "a" * 51,  # 51 characters; presumably one over the allowed length
            "invalid@chars!",
            "has spaces",
        ]
        filename, stream, mime_type = valid_file

        for invalid_id in invalid_session_ids:
            # The same stream is uploaded on every iteration; rewind it so
            # each request sends the full content no matter how the previous
            # request left the read position.
            stream.seek(0)
            response = client.put(
                "/file",
                files={"file": (filename, stream, mime_type)},
                data={"session_id": invalid_id},
            )
            # Name the offending id so a failure is attributable to one case.
            assert response.status_code == 422, (
                f"session_id {invalid_id!r} returned {response.status_code}"
            )

    def test_empty(self):
        """Test that a request with neither file nor session_id is rejected"""
        response = client.put("/file")
        assert response.status_code == 422

    def test_missing_file(self, valid_session_id):
        """Test that missing file returns validation error"""
        response = client.put("/file", data={"session_id": valid_session_id})
        assert response.status_code == 422

    def test_missing_session_id(self, valid_file):
        """Test that missing session_id returns validation error"""
        response = client.put("/file", files={"file": valid_file})
        assert response.status_code == 422

    @pytest.mark.enable_rate_limit
    def test_rate_limiting(self, valid_session_id, valid_file, mock_dependencies):
        """Test that rate limiting works (12 requests per minute)"""
        # One request more than the documented limit of 12.
        request_count = 13

        responses = [
            client.put(
                "/file",
                files={"file": ("test.txt", BytesIO(b"content"), "text/plain")},
                data={"session_id": valid_session_id},
            )
            for _ in range(request_count)
        ]

        # Requests within the limit must not be throttled; otherwise the test
        # would also pass with a limit far stricter than the documented one.
        throttled_early = [
            position
            for position, early in enumerate(responses[:-1], start=1)
            if early.status_code == 429
        ]
        assert not throttled_early, (
            f"requests {throttled_early} were rate limited before the limit"
        )

        # The request over the limit must be rejected.
        assert responses[-1].status_code == 429

    def test_pii_sanitization_called(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test that PII filter is applied to extracted text"""
        extracted_text = "This is sensitive data"
        sanitized_text = "This is [REDACTED] data"

        mock_dependencies["extract"].return_value = extracted_text
        mock_dependencies["pii"].return_value.sanitize.return_value = sanitized_text

        client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )

        # The filter receives the raw extracted text ...
        mock_dependencies["pii"].return_value.sanitize.assert_called_once_with(
            extracted_text
        )

        # ... and only the sanitized text is handed to the store.
        mock_dependencies["store"].create_document.assert_called_once_with(
            valid_session_id, sanitized_text, "test.txt"
        )

    def test_different_file_types(self, valid_session_id, mock_dependencies):
        """Test endpoint with different MIME types"""
        file_types = [
            ("test.txt", b"text content", "text/plain"),
            ("test.pdf", b"%PDF-1.4 content", "application/pdf"),
            (
                "test.docx",
                b"docx content",
                "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            ),
        ]

        for filename, content, mime_type in file_types:
            response = client.put(
                "/file",
                files={"file": (filename, BytesIO(content), mime_type)},
                data={"session_id": valid_session_id},
            )

            # Name the upload so a failure is attributable to one file type.
            assert response.status_code == 200, (
                f"{filename} ({mime_type}) returned {response.status_code}"
            )
|
|
|