from io import BytesIO
from unittest.mock import AsyncMock, Mock, patch

import pytest
from fastapi.testclient import TestClient

from constants import (
    STATUS_CODE_BAD_REQUEST,
    STATUS_CODE_CONTENT_TOO_LARGE,
    STATUS_CODE_EXCEED_SIZE_LIMIT,
    STATUS_CODE_INTERNAL_SERVER_ERROR,
    STATUS_CODE_LENGTH_REQUIRED,
    STATUS_CODE_UNSUPPORTED_MEDIA_TYPE,
)
from exceptions import (
    FileExtractionError,
    FileExtractionException,
    FileValidationError,
    FileValidationException,
)
from main import app

client = TestClient(app)

# (FileValidationError, expected HTTP status) pairs: each validation error
# raised by main.validate_file and the status the endpoint must respond with.
FILE_VALIDATION_TEST_CASES = [
    (FileValidationError.MISSING_SIZE, STATUS_CODE_LENGTH_REQUIRED),
    (FileValidationError.FILE_TOO_LARGE, STATUS_CODE_CONTENT_TOO_LARGE),
    (FileValidationError.MISSING_FILE_NAME, STATUS_CODE_BAD_REQUEST),
    (FileValidationError.FILE_NAME_TOO_LARGE, STATUS_CODE_BAD_REQUEST),
    (FileValidationError.INVALID_FILE_NAME, STATUS_CODE_BAD_REQUEST),
    (FileValidationError.INVALID_MIME_TYPE, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    (FileValidationError.UNSUPPORTED_EXTENSION, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    (FileValidationError.EMPTY_FILE, STATUS_CODE_BAD_REQUEST),
]

# (FileExtractionError, expected HTTP status) pairs: each extraction error
# raised by main.extract_text_from_file and the status the endpoint must return.
FILE_EXTRACTION_TEST_CASES = [
    (FileExtractionError.INVALID_MIME_TYPE, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    (FileExtractionError.NO_TEXT, STATUS_CODE_BAD_REQUEST),
    (FileExtractionError.TEXT_EXTRACTION_TIMEOUT, STATUS_CODE_INTERNAL_SERVER_ERROR),
    (FileExtractionError.UNSAFE_ZIP, STATUS_CODE_INTERNAL_SERVER_ERROR),
    (FileExtractionError.FILE_TOO_LARGE, STATUS_CODE_CONTENT_TOO_LARGE),
    (FileExtractionError.MALFORMED_FILE, STATUS_CODE_BAD_REQUEST),
]


class TestUploadFileEndpoint:
    """Test the PUT /file endpoint"""

    @pytest.fixture
    def valid_session_id(self):
        """A session_id that passes the endpoint's format validation."""
        return "test-session-123"

    @pytest.fixture
    def valid_file(self):
        """Create a valid file upload"""
        return ("test.txt", BytesIO(b"Test file content"), "text/plain")

    @pytest.fixture
    def mock_dependencies(self):
        """Mock all external dependencies

        Patches validate_file, extract_text_from_file, PIIFilter,
        session_document_store and session_tracker in ``main``. The defaults
        simulate a fully successful upload; individual tests override
        ``return_value`` / ``side_effect`` to exercise failure paths.
        """
        with (
            patch("main.validate_file") as mock_validate,
            patch("main.extract_text_from_file") as mock_extract,
            patch("main.PIIFilter") as mock_pii,
            patch("main.session_document_store") as mock_store,
            patch("main.session_tracker") as mock_tracker,
        ):
            # Setup default successful behavior.
            # The validated file is a plain data holder, not an awaitable,
            # so it is a regular Mock (an AsyncMock would turn every unset
            # attribute into an async callable).
            mock_validate.return_value = Mock(
                content=b"content", filename="test.txt", mime_type="text/plain"
            )
            mock_extract.return_value = "extracted text content"
            mock_pii.return_value.sanitize.return_value = "sanitized text"
            mock_store.create_document.return_value = True
            yield {
                "validate": mock_validate,
                "extract": mock_extract,
                "pii": mock_pii,
                "store": mock_store,
                "tracker": mock_tracker,
            }

    def test_successful_file_upload(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test successful file upload with all valid inputs"""
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )
        assert response.status_code == 200

        # Verify the workflow was executed
        assert mock_dependencies["validate"].called
        assert mock_dependencies["extract"].called
        assert mock_dependencies["pii"].return_value.sanitize.called
        assert mock_dependencies["store"].create_document.called
        assert mock_dependencies["tracker"].update_session.called

    @pytest.mark.parametrize(
        "validation_error,expected_status", FILE_VALIDATION_TEST_CASES
    )
    def test_file_validation_failure(
        self,
        valid_session_id,
        valid_file,
        mock_dependencies,
        validation_error,
        expected_status,
    ):
        """Test that validation errors return correct status codes"""
        # Mock validation to raise an exception
        mock_dependencies["validate"].side_effect = FileValidationException(
            validation_error
        )

        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )

        # Should return the appropriate error status code
        assert response.status_code == expected_status

        # Verify later steps weren't called
        assert not mock_dependencies["extract"].called
        assert not mock_dependencies["store"].create_document.called

    @pytest.mark.parametrize(
        "extraction_error,expected_status", FILE_EXTRACTION_TEST_CASES
    )
    def test_file_extraction_failure(
        self,
        valid_session_id,
        valid_file,
        mock_dependencies,
        extraction_error,
        expected_status,
    ):
        """Test that extraction errors return correct status codes"""
        mock_dependencies["extract"].side_effect = FileExtractionException(
            extraction_error
        )

        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )

        assert response.status_code == expected_status
        assert not mock_dependencies["store"].create_document.called

    def test_unexpected_extraction_exception(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test that unexpected exceptions return 500"""
        mock_dependencies["extract"].side_effect = Exception("Unexpected error")

        response = client.put(
            "/file",
            files={"file": valid_file},
            data={"session_id": valid_session_id},
        )

        assert response.status_code == STATUS_CODE_INTERNAL_SERVER_ERROR

    def test_exceed_size_limit(self, valid_session_id, valid_file, mock_dependencies):
        """Test that size limit exceeded returns correct status"""
        mock_dependencies["store"].create_document.return_value = False

        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )

        # Should return size limit status code
        assert response.status_code == STATUS_CODE_EXCEED_SIZE_LIMIT

        # Session tracker should NOT be called when storage fails
        assert not mock_dependencies["tracker"].update_session.called

    @pytest.mark.parametrize(
        "invalid_id",
        [
            "",  # Too short
            "a" * 51,  # Too long
            "invalid@chars!",  # Invalid characters
            "has spaces",  # Spaces not allowed
        ],
    )
    def test_invalid_session_id_format(
        self, valid_file, mock_dependencies, invalid_id
    ):
        """Test that invalid session_id format is rejected

        Parametrized so each rejected format is reported on its own, and run
        with mocked dependencies so a regression that wrongly accepts the ID
        never reaches the real validation/extraction/storage pipeline.
        """
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": invalid_id}
        )
        assert response.status_code == 422  # Validation error

    def test_empty(self):
        """Test that a request with neither file nor session_id is rejected."""
        response = client.put("/file")
        assert response.status_code == 422

    def test_missing_file(self, valid_session_id):
        """Test that missing file returns validation error"""
        response = client.put("/file", data={"session_id": valid_session_id})
        assert response.status_code == 422

    def test_missing_session_id(self, valid_file):
        """Test that missing session_id returns validation error"""
        response = client.put("/file", files={"file": valid_file})
        assert response.status_code == 422

    @pytest.mark.enable_rate_limit
    def test_rate_limiting(self, valid_session_id, mock_dependencies):
        """Test that rate limiting works (12 requests per minute)"""
        # Make 13 rapid requests
        responses = []
        for _ in range(13):
            response = client.put(
                "/file",
                # A fresh buffer per request so every upload carries content.
                files={"file": ("test.txt", BytesIO(b"content"), "text/plain")},
                data={"session_id": valid_session_id},
            )
            responses.append(response)

        # First 12 should succeed (or have normal errors), but none of them
        # may be rate limited -- otherwise a stricter limit would go unnoticed.
        assert all(r.status_code != 429 for r in responses[:-1])
        # 13th should be rate limited
        assert responses[-1].status_code == 429

    def test_pii_sanitization_called(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test that PII filter is applied to extracted text"""
        extracted_text = "This is sensitive data"
        sanitized_text = "This is [REDACTED] data"

        mock_dependencies["extract"].return_value = extracted_text
        mock_dependencies["pii"].return_value.sanitize.return_value = sanitized_text

        client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )

        # Verify sanitize was called with extracted text
        mock_dependencies["pii"].return_value.sanitize.assert_called_once_with(
            extracted_text
        )

        # Verify store was called with sanitized text; "test.txt" is the
        # filename of the validated-file mock from the fixture.
        mock_dependencies["store"].create_document.assert_called_once_with(
            valid_session_id, sanitized_text, "test.txt"
        )

    @pytest.mark.parametrize(
        "filename,content,mime_type",
        [
            ("test.txt", b"text content", "text/plain"),
            ("test.pdf", b"%PDF-1.4 content", "application/pdf"),
            (
                "test.docx",
                b"docx content",
                "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            ),
        ],
    )
    def test_different_file_types(
        self, valid_session_id, mock_dependencies, filename, content, mime_type
    ):
        """Test endpoint with different MIME types"""
        response = client.put(
            "/file",
            files={"file": (filename, BytesIO(content), mime_type)},
            data={"session_id": valid_session_id},
        )

        # Should process all valid types
        assert response.status_code == 200