Spaces:
Paused
Paused
| import pytest | |
| from fastapi.testclient import TestClient | |
| from unittest.mock import Mock, patch, AsyncMock | |
| from io import BytesIO | |
| from constants import ( | |
| STATUS_CODE_BAD_REQUEST, | |
| STATUS_CODE_CONTENT_TOO_LARGE, | |
| STATUS_CODE_EXCEED_SIZE_LIMIT, | |
| STATUS_CODE_INTERNAL_SERVER_ERROR, | |
| STATUS_CODE_LENGTH_REQUIRED, | |
| STATUS_CODE_UNSUPPORTED_MEDIA_TYPE, | |
| ) | |
| from exceptions import ( | |
| FileExtractionError, | |
| FileExtractionException, | |
| FileValidationError, | |
| FileValidationException, | |
| ) | |
| from main import app | |
# Shared TestClient used by every test in this module to send requests to
# the FastAPI ``app`` imported from ``main``.
client = TestClient(app)

# (FileValidationError, expected HTTP status) pairs, intended as the
# parametrize cases for the validation-failure test: when validation raises
# the given error, the endpoint should respond with the paired status code.
FILE_VALIDATION_TEST_CASES = [
    (FileValidationError.MISSING_SIZE, STATUS_CODE_LENGTH_REQUIRED),
    (FileValidationError.FILE_TOO_LARGE, STATUS_CODE_CONTENT_TOO_LARGE),
    (FileValidationError.MISSING_FILE_NAME, STATUS_CODE_BAD_REQUEST),
    (FileValidationError.FILE_NAME_TOO_LARGE, STATUS_CODE_BAD_REQUEST),
    (FileValidationError.INVALID_FILE_NAME, STATUS_CODE_BAD_REQUEST),
    (FileValidationError.INVALID_MIME_TYPE, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    (FileValidationError.UNSUPPORTED_EXTENSION, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    (FileValidationError.EMPTY_FILE, STATUS_CODE_BAD_REQUEST),
]

# (FileExtractionError, expected HTTP status) pairs, intended as the
# parametrize cases for the extraction-failure test: when text extraction
# raises the given error, the endpoint should respond with the paired status.
FILE_EXTRACTION_TEST_CASES = [
    (FileExtractionError.INVALID_MIME_TYPE, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    (FileExtractionError.NO_TEXT, STATUS_CODE_BAD_REQUEST),
    (FileExtractionError.TEXT_EXTRACTION_TIMEOUT, STATUS_CODE_INTERNAL_SERVER_ERROR),
    (FileExtractionError.UNSAFE_ZIP, STATUS_CODE_INTERNAL_SERVER_ERROR),
    (FileExtractionError.FILE_TOO_LARGE, STATUS_CODE_CONTENT_TOO_LARGE),
    (FileExtractionError.MALFORMED_FILE, STATUS_CODE_BAD_REQUEST),
]
class TestUploadFileEndpoint:
    """Test the PUT /file endpoint.

    The ``mock_dependencies`` fixture patches every collaborator ``main``
    uses while handling an upload: validation, text extraction, the PII
    filter, the session document store and the session tracker. These tests
    therefore exercise only the endpoint's control flow and its mapping of
    errors to HTTP status codes.
    """

    @pytest.fixture
    def valid_session_id(self):
        """Return a session id the endpoint accepts."""
        return "test-session-123"

    @pytest.fixture
    def valid_file(self):
        """Create a valid file upload as a ``(filename, stream, content_type)`` tuple."""
        return ("test.txt", BytesIO(b"Test file content"), "text/plain")

    @pytest.fixture
    def mock_dependencies(self):
        """Mock all external dependencies of the endpoint for one test.

        Defaults model the happy path:
        - validation returns a file object;
        - extraction returns text;
        - the PII filter returns sanitized text;
        - the store accepts the document.

        Tests override behaviour through the yielded dict, keyed by
        ``validate``, ``extract``, ``pii``, ``store`` and ``tracker``. All
        patches are undone when the fixture is torn down.
        """
        with (
            patch("main.validate_file") as mock_validate,
            patch("main.extract_text_from_file") as mock_extract,
            patch("main.PIIFilter") as mock_pii,
            patch("main.session_document_store") as mock_store,
            patch("main.session_tracker") as mock_tracker,
        ):
            # Setup default successful behavior.
            # The validated file is presumably a plain value object (only its
            # attributes are configured here), so use Mock rather than
            # AsyncMock. AsyncMock's child attributes are AsyncMocks, and
            # calling one would yield an un-awaited coroutine, not a value.
            mock_validate.return_value = Mock(
                content=b"content", filename="test.txt", mime_type="text/plain"
            )
            mock_extract.return_value = "extracted text content"
            mock_pii.return_value.sanitize.return_value = "sanitized text"
            mock_store.create_document.return_value = True
            yield {
                "validate": mock_validate,
                "extract": mock_extract,
                "pii": mock_pii,
                "store": mock_store,
                "tracker": mock_tracker,
            }

    def test_successful_file_upload(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test successful file upload with all valid inputs."""
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )
        assert response.status_code == 200
        # Verify every stage of the workflow was executed.
        mock_dependencies["validate"].assert_called()
        mock_dependencies["extract"].assert_called()
        mock_dependencies["pii"].return_value.sanitize.assert_called()
        mock_dependencies["store"].create_document.assert_called()
        mock_dependencies["tracker"].update_session.assert_called()

    @pytest.mark.parametrize(
        "validation_error, expected_status", FILE_VALIDATION_TEST_CASES
    )
    def test_file_validation_failure(
        self,
        valid_session_id,
        valid_file,
        mock_dependencies,
        validation_error,
        expected_status,
    ):
        """Test that validation errors return correct status codes."""
        # Mock validation to raise an exception.
        mock_dependencies["validate"].side_effect = FileValidationException(
            validation_error
        )
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )
        # Should return the appropriate error status code.
        assert response.status_code == expected_status
        # The pipeline must stop at validation.
        mock_dependencies["extract"].assert_not_called()
        mock_dependencies["store"].create_document.assert_not_called()

    @pytest.mark.parametrize(
        "extraction_error, expected_status", FILE_EXTRACTION_TEST_CASES
    )
    def test_file_extraction_failure(
        self,
        valid_session_id,
        valid_file,
        mock_dependencies,
        extraction_error,
        expected_status,
    ):
        """Test that extraction errors return correct status codes."""
        mock_dependencies["extract"].side_effect = FileExtractionException(
            extraction_error
        )
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )
        assert response.status_code == expected_status
        mock_dependencies["store"].create_document.assert_not_called()

    def test_unexpected_extraction_exception(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test that unexpected exceptions return 500."""
        mock_dependencies["extract"].side_effect = Exception("Unexpected error")
        response = client.put(
            "/file",
            files={"file": valid_file},
            data={"session_id": valid_session_id},
        )
        assert response.status_code == STATUS_CODE_INTERNAL_SERVER_ERROR
        # Nothing may be stored when extraction blows up.
        mock_dependencies["store"].create_document.assert_not_called()

    def test_exceed_size_limit(self, valid_session_id, valid_file, mock_dependencies):
        """Test that size limit exceeded returns correct status."""
        # A falsy result from create_document signals the storage limit.
        mock_dependencies["store"].create_document.return_value = False
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )
        # Should return size limit status code.
        assert response.status_code == STATUS_CODE_EXCEED_SIZE_LIMIT
        # Session tracker should NOT be called when storage fails.
        mock_dependencies["tracker"].update_session.assert_not_called()

    def test_invalid_session_id_format(self, valid_file):
        """Test that invalid session_id format is rejected."""
        invalid_session_ids = [
            "",  # Too short
            "a" * 51,  # Too long
            "invalid@chars!",  # Invalid characters
            "has spaces",  # Spaces not allowed
        ]
        for invalid_id in invalid_session_ids:
            response = client.put(
                "/file", files={"file": valid_file}, data={"session_id": invalid_id}
            )
            # Include the offending id so a failure names the case.
            assert response.status_code == 422, invalid_id  # Validation error

    def test_empty(self):
        """Test that a request with neither file nor session_id is rejected."""
        response = client.put("/file")
        assert response.status_code == 422

    def test_missing_file(self, valid_session_id):
        """Test that missing file returns validation error."""
        response = client.put("/file", data={"session_id": valid_session_id})
        assert response.status_code == 422

    def test_missing_session_id(self, valid_file):
        """Test that missing session_id returns validation error."""
        response = client.put("/file", files={"file": valid_file})
        assert response.status_code == 422

    def test_rate_limiting(self, valid_session_id, valid_file, mock_dependencies):
        """Test that rate limiting works (12 requests per minute).

        NOTE(review): the limiter state presumably lives on the shared
        module-level ``app``. This test assumes the budget is fresh when it
        starts, as does every other test here that expects a non-429
        response. Confirm the limiter is reset or disabled between tests.
        """
        limit = 12  # Requests allowed per window.
        responses = []
        # Make limit + 1 rapid requests.
        for _ in range(limit + 1):
            response = client.put(
                "/file",
                files={"file": ("test.txt", BytesIO(b"content"), "text/plain")},
                data={"session_id": valid_session_id},
            )
            responses.append(response)
        # The first `limit` requests may succeed or fail normally, but must not
        # be rate limited. Otherwise the final 429 below proves nothing.
        within_budget = [r.status_code for r in responses[:limit]]
        assert 429 not in within_budget, within_budget
        # The first request over the budget is rate limited.
        assert responses[-1].status_code == 429

    def test_pii_sanitization_called(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test that PII filter is applied to extracted text."""
        extracted_text = "This is sensitive data"
        sanitized_text = "This is [REDACTED] data"
        mock_dependencies["extract"].return_value = extracted_text
        mock_dependencies["pii"].return_value.sanitize.return_value = sanitized_text
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )
        assert response.status_code == 200
        # Verify sanitize was called with extracted text.
        mock_dependencies["pii"].return_value.sanitize.assert_called_once_with(
            extracted_text
        )
        # Verify store was called with sanitized text, not the raw extraction.
        mock_dependencies["store"].create_document.assert_called_once_with(
            valid_session_id, sanitized_text, "test.txt"
        )

    def test_different_file_types(self, valid_session_id, mock_dependencies):
        """Test endpoint with different MIME types."""
        file_types = [
            ("test.txt", b"text content", "text/plain"),
            ("test.pdf", b"%PDF-1.4 content", "application/pdf"),
            (
                "test.docx",
                b"docx content",
                "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            ),
        ]
        for filename, content, mime_type in file_types:
            # Make the mocked validator echo this upload. Otherwise every type
            # would continue through the pipeline as "test.txt"/"text/plain".
            mock_dependencies["validate"].return_value = Mock(
                content=content, filename=filename, mime_type=mime_type
            )
            response = client.put(
                "/file",
                files={"file": (filename, BytesIO(content), mime_type)},
                data={"session_id": valid_session_id},
            )
            # Should process all valid types.
            assert response.status_code == 200, mime_type