# champ-chatbot/tests/api/test_file_put.py
# Page residue kept for provenance: "qyle's picture", commit "deployment",
# revision 8b9e569 (verified).
import pytest
from fastapi.testclient import TestClient
from unittest.mock import Mock, patch, AsyncMock
from io import BytesIO
from constants import (
STATUS_CODE_BAD_REQUEST,
STATUS_CODE_CONTENT_TOO_LARGE,
STATUS_CODE_EXCEED_SIZE_LIMIT,
STATUS_CODE_INTERNAL_SERVER_ERROR,
STATUS_CODE_LENGTH_REQUIRED,
STATUS_CODE_UNSUPPORTED_MEDIA_TYPE,
)
from exceptions import (
FileExtractionError,
FileExtractionException,
FileValidationError,
FileValidationException,
)
from main import app
# Module-level test client wrapping the application imported from `main`;
# every test in this file issues its requests through it.
client = TestClient(app)
# (validation error, expected HTTP status) pairs fed to
# `test_file_validation_failure` through `pytest.mark.parametrize`.
FILE_VALIDATION_TEST_CASES = [
    (FileValidationError.MISSING_SIZE, STATUS_CODE_LENGTH_REQUIRED),
    (FileValidationError.FILE_TOO_LARGE, STATUS_CODE_CONTENT_TOO_LARGE),
    (FileValidationError.MISSING_FILE_NAME, STATUS_CODE_BAD_REQUEST),
    (FileValidationError.FILE_NAME_TOO_LARGE, STATUS_CODE_BAD_REQUEST),
    (FileValidationError.INVALID_FILE_NAME, STATUS_CODE_BAD_REQUEST),
    (FileValidationError.INVALID_MIME_TYPE, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    (FileValidationError.UNSUPPORTED_EXTENSION, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    (FileValidationError.EMPTY_FILE, STATUS_CODE_BAD_REQUEST),
]
# (extraction error, expected HTTP status) pairs fed to
# `test_file_extraction_failure` through `pytest.mark.parametrize`.
FILE_EXTRACTION_TEST_CASES = [
    (FileExtractionError.INVALID_MIME_TYPE, STATUS_CODE_UNSUPPORTED_MEDIA_TYPE),
    (FileExtractionError.NO_TEXT, STATUS_CODE_BAD_REQUEST),
    (FileExtractionError.TEXT_EXTRACTION_TIMEOUT, STATUS_CODE_INTERNAL_SERVER_ERROR),
    (FileExtractionError.UNSAFE_ZIP, STATUS_CODE_INTERNAL_SERVER_ERROR),
    (FileExtractionError.FILE_TOO_LARGE, STATUS_CODE_CONTENT_TOO_LARGE),
    (FileExtractionError.MALFORMED_FILE, STATUS_CODE_BAD_REQUEST),
]
class TestUploadFileEndpoint:
    """Test the PUT /file endpoint."""

    @pytest.fixture
    def valid_session_id(self):
        """Return the session ID used by the tests that expect a valid one."""
        return "test-session-123"

    @pytest.fixture
    def valid_file(self):
        """Create a valid file upload as a (filename, stream, MIME type) tuple."""
        return ("test.txt", BytesIO(b"Test file content"), "text/plain")

    @pytest.fixture
    def mock_dependencies(self):
        """Mock all external dependencies of the endpoint.

        Patches ``validate_file``, ``extract_text_from_file``, ``PIIFilter``,
        ``session_document_store`` and ``session_tracker`` in ``main`` and
        configures them for the successful path. Yields a dict of the mocks
        keyed by ``validate``, ``extract``, ``pii``, ``store`` and ``tracker``
        so individual tests can override behavior or inspect calls.
        """
        with (
            patch("main.validate_file") as mock_validate,
            patch("main.extract_text_from_file") as mock_extract,
            patch("main.PIIFilter") as mock_pii,
            patch("main.session_document_store") as mock_store,
            patch("main.session_tracker") as mock_tracker,
        ):
            # Setup default successful behavior
            mock_validate.return_value = AsyncMock(
                content=b"content", filename="test.txt", mime_type="text/plain"
            )
            mock_extract.return_value = "extracted text content"
            mock_pii.return_value.sanitize.return_value = "sanitized text"
            mock_store.create_document.return_value = True
            yield {
                "validate": mock_validate,
                "extract": mock_extract,
                "pii": mock_pii,
                "store": mock_store,
                "tracker": mock_tracker,
            }

    def test_successful_file_upload(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test successful file upload with all valid inputs."""
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )
        assert response.status_code == 200
        # Verify the workflow was executed
        assert mock_dependencies["validate"].called
        assert mock_dependencies["extract"].called
        assert mock_dependencies["pii"].return_value.sanitize.called
        assert mock_dependencies["store"].create_document.called
        assert mock_dependencies["tracker"].update_session.called

    @pytest.mark.parametrize(
        "validation_error,expected_status", FILE_VALIDATION_TEST_CASES
    )
    def test_file_validation_failure(
        self,
        valid_session_id,
        valid_file,
        mock_dependencies,
        validation_error,
        expected_status,
    ):
        """Test that validation errors return correct status codes."""
        # Mock validation to raise an exception
        mock_dependencies["validate"].side_effect = FileValidationException(
            validation_error
        )
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )
        # Should return the appropriate error status code
        assert response.status_code == expected_status
        # Verify later steps weren't called
        assert not mock_dependencies["extract"].called
        assert not mock_dependencies["store"].create_document.called

    @pytest.mark.parametrize(
        "extraction_error,expected_status", FILE_EXTRACTION_TEST_CASES
    )
    def test_file_extraction_failure(
        self,
        valid_session_id,
        valid_file,
        mock_dependencies,
        extraction_error,
        expected_status,
    ):
        """Test that extraction errors return correct status codes."""
        mock_dependencies["extract"].side_effect = FileExtractionException(
            extraction_error
        )
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )
        assert response.status_code == expected_status
        # Nothing may be stored when extraction fails
        assert not mock_dependencies["store"].create_document.called

    def test_unexpected_extraction_exception(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test that unexpected exceptions return 500."""
        mock_dependencies["extract"].side_effect = Exception("Unexpected error")
        response = client.put(
            "/file",
            files={"file": valid_file},
            data={"session_id": valid_session_id},
        )
        assert response.status_code == STATUS_CODE_INTERNAL_SERVER_ERROR

    def test_exceed_size_limit(self, valid_session_id, valid_file, mock_dependencies):
        """Test that size limit exceeded returns correct status."""
        mock_dependencies["store"].create_document.return_value = False
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )
        # Should return size limit status code
        assert response.status_code == STATUS_CODE_EXCEED_SIZE_LIMIT
        # Session tracker should NOT be called when storage fails
        assert not mock_dependencies["tracker"].update_session.called

    # Parametrized (rather than looped) so each invalid ID is reported as its
    # own test case and every case uploads a fresh, unread file stream.
    @pytest.mark.parametrize(
        "invalid_id",
        [
            "",  # Too short
            "a" * 51,  # Too long
            "invalid@chars!",  # Invalid characters
            "has spaces",  # Spaces not allowed
        ],
        ids=["empty", "too_long", "invalid_chars", "has_spaces"],
    )
    def test_invalid_session_id_format(self, valid_file, invalid_id):
        """Test that invalid session_id format is rejected."""
        response = client.put(
            "/file", files={"file": valid_file}, data={"session_id": invalid_id}
        )
        assert response.status_code == 422  # Validation error

    def test_empty(self):
        """Test that a request with neither file nor session_id returns 422."""
        response = client.put("/file")
        assert response.status_code == 422

    def test_missing_file(self, valid_session_id):
        """Test that missing file returns validation error."""
        response = client.put("/file", data={"session_id": valid_session_id})
        assert response.status_code == 422

    def test_missing_session_id(self, valid_file):
        """Test that missing session_id returns validation error."""
        response = client.put("/file", files={"file": valid_file})
        assert response.status_code == 422

    @pytest.mark.enable_rate_limit
    def test_rate_limiting(self, valid_session_id, valid_file, mock_dependencies):
        """Test that rate limiting works (12 requests per minute)."""
        requests_per_minute = 12
        # Make one request more than the limit, each with its own file stream
        responses = [
            client.put(
                "/file",
                files={"file": ("test.txt", BytesIO(b"content"), "text/plain")},
                data={"session_id": valid_session_id},
            )
            for _ in range(requests_per_minute + 1)
        ]
        # First 12 should succeed (or have normal errors)
        # 13th should be rate limited
        assert responses[-1].status_code == 429

    def test_pii_sanitization_called(
        self, valid_session_id, valid_file, mock_dependencies
    ):
        """Test that PII filter is applied to extracted text."""
        extracted_text = "This is sensitive data"
        sanitized_text = "This is [REDACTED] data"
        mock_dependencies["extract"].return_value = extracted_text
        mock_dependencies["pii"].return_value.sanitize.return_value = sanitized_text
        client.put(
            "/file", files={"file": valid_file}, data={"session_id": valid_session_id}
        )
        # Verify sanitize was called with extracted text
        mock_dependencies["pii"].return_value.sanitize.assert_called_once_with(
            extracted_text
        )
        # Verify store was called with sanitized text
        mock_dependencies["store"].create_document.assert_called_once_with(
            valid_session_id, sanitized_text, "test.txt"
        )

    # Parametrized (rather than looped) so a failure names the file type that
    # caused it instead of stopping at the first failing case.
    @pytest.mark.parametrize(
        "filename,content,mime_type",
        [
            ("test.txt", b"text content", "text/plain"),
            ("test.pdf", b"%PDF-1.4 content", "application/pdf"),
            (
                "test.docx",
                b"docx content",
                "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            ),
        ],
        ids=["txt", "pdf", "docx"],
    )
    def test_different_file_types(
        self, valid_session_id, mock_dependencies, filename, content, mime_type
    ):
        """Test endpoint with different MIME types.

        Validation and extraction are mocked by ``mock_dependencies``, so this
        only checks that uploads declared with these types yield a 200.
        """
        response = client.put(
            "/file",
            files={"file": (filename, BytesIO(content), mime_type)},
            data={"session_id": valid_session_id},
        )
        # Should process all valid types
        assert response.status_code == 200