from unittest.mock import Mock, call, patch

import pytest
from fastapi.testclient import TestClient

from main import app

client = TestClient(app)


class TestChatEndpoint:
    """Test the POST /chat endpoint"""

    @pytest.fixture
    def base_required_fields(self):
        """Base fields required by IdentifierBase and ProfileBase"""
        return {
            "user_id": "test-user-123",
            "participant_id": "participant-456",
            "session_id": "test-session-123",
            "consent": True,
            "age_group": "25-34",
            "gender": "M",
            "roles": ["patient"],
        }

    @pytest.fixture
    def valid_payload(self, base_required_fields):
        """A complete /chat request body for the "champ" model in English."""
        return {
            **base_required_fields,
            "conversation_id": "conversation-abc",
            "model_type": "champ",
            "lang": "en",
            "human_message": "What should I do about a fever?",
        }

    @pytest.fixture
    def mock_dependencies(self):
        """Patch main's external collaborators for the duration of one test.

        Yields a dict of the mocks keyed by role ("tracker", "pii_filter",
        "conv_store", "doc_store", "call_llm", "log_event"). Defaults:
        PIIFilter().sanitize returns "sanitized message", add_human_message
        returns a one-message history, no documents are stored, and call_llm
        returns the non-streaming tuple ("AI response", {}, []).
        """
        with (
            patch("main.session_tracker") as mock_tracker,
            patch("main.PIIFilter") as mock_pii_class,
            patch("main.session_conversation_store") as mock_conv_store,
            patch("main.session_document_store") as mock_doc_store,
            patch("main.call_llm") as mock_call_llm,
            patch("main.log_event") as mock_log_event,
        ):
            # PIIFilter is instantiated by main; hand back a controllable instance.
            mock_pii = Mock()
            mock_pii.sanitize.return_value = "sanitized message"
            mock_pii_class.return_value = mock_pii

            # Conversation store returns the history passed on to call_llm.
            mock_conv_store.add_human_message.return_value = [
                Mock(role="user", content="sanitized message")
            ]

            # No uploaded documents by default.
            mock_doc_store.get_document_contents.return_value = None

            # Non-streaming LLM result by default.
            mock_call_llm.return_value = ("AI response", {}, [])

            yield {
                "tracker": mock_tracker,
                "pii_filter": mock_pii,
                "conv_store": mock_conv_store,
                "doc_store": mock_doc_store,
                "call_llm": mock_call_llm,
                "log_event": mock_log_event,
            }

    # ==================== Successful Chat Tests ====================

    def test_chat_success_non_streaming(self, valid_payload, mock_dependencies):
        """Test successful non-streaming chat response"""
        response = client.post("/chat", json=valid_payload)

        assert response.status_code == 200
        assert response.json() == {"reply": "AI response"}

    def test_chat_updates_session_tracker(self, valid_payload, mock_dependencies):
        """Test that session tracker is updated"""
        client.post("/chat", json=valid_payload)

        mock_dependencies["tracker"].update_session.assert_called_once_with(
            "test-session-123"
        )

    def test_chat_sanitizes_message(self, valid_payload, mock_dependencies):
        """Test that PII filter is applied to message"""
        client.post("/chat", json=valid_payload)

        mock_dependencies["pii_filter"].sanitize.assert_called_once_with(
            "What should I do about a fever?"
        )

    def test_chat_adds_human_message_to_store(self, valid_payload, mock_dependencies):
        """Test that sanitized message is added to conversation store"""
        client.post("/chat", json=valid_payload)

        mock_dependencies["conv_store"].add_human_message.assert_called_once_with(
            "test-session-123", "conversation-abc", "sanitized message"
        )

    def test_chat_retrieves_documents(self, valid_payload, mock_dependencies):
        """Test that documents are retrieved from document store"""
        client.post("/chat", json=valid_payload)

        mock_dependencies["doc_store"].get_document_contents.assert_called_once_with(
            "test-session-123"
        )

    def test_chat_calls_llm_with_correct_params(self, valid_payload, mock_dependencies):
        """Test that call_llm is invoked with correct parameters"""
        mock_conversation = [Mock()]
        mock_dependencies["conv_store"].add_human_message.return_value = (
            mock_conversation
        )
        mock_dependencies["doc_store"].get_document_contents.return_value = ["doc1"]

        client.post("/chat", json=valid_payload)

        # call_llm may run inside an executor in main; TestClient.post only
        # returns once the response is complete, so the call has happened here.
        mock_dependencies["call_llm"].assert_called_once_with(
            "champ", "en", mock_conversation, ["doc1"]
        )

    def test_chat_adds_assistant_reply_to_store(self, valid_payload, mock_dependencies):
        """Test that assistant reply is added to conversation store"""
        client.post("/chat", json=valid_payload)

        mock_dependencies["conv_store"].add_assistant_reply.assert_called_once_with(
            "test-session-123", "conversation-abc", "AI response"
        )

    # ==================== Streaming Response Tests ====================

    def test_chat_streaming_response(self, valid_payload, mock_dependencies):
        """Test streaming response from OpenAI"""

        async def mock_stream():
            yield "Hello "
            yield "world"

        mock_dependencies["call_llm"].return_value = mock_stream()

        response = client.post("/chat", json=valid_payload)

        assert response.status_code == 200
        # StreamingResponse returns chunks; TestClient joins them into content.
        content = response.content.decode()
        assert "Hello world" in content

    # ==================== Different Model Types Tests ====================

    def test_chat_openai_model(self, base_required_fields, mock_dependencies):
        """Test chat with OpenAI model"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "openai",
            "lang": "en",
            "human_message": "Hello",
        }

        # OpenAI returns AsyncGenerator
        async def mock_stream():
            yield "response"

        mock_dependencies["call_llm"].return_value = mock_stream()

        response = client.post("/chat", json=payload)

        assert response.status_code == 200

    def test_chat_google_conservative_model(
        self, base_required_fields, mock_dependencies
    ):
        """Test chat with Google conservative model"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "google-conservative",
            "lang": "en",
            "human_message": "Hello",
        }
        mock_dependencies["call_llm"].return_value = ("Response", {}, [])

        response = client.post("/chat", json=payload)

        assert response.status_code == 200
        assert response.json() == {"reply": "Response"}
        # call_llm's leading positional args are (model_type, lang).
        assert mock_dependencies["call_llm"].call_args.args[:2] == (
            "google-conservative",
            "en",
        )

    def test_chat_google_creative_model(self, base_required_fields, mock_dependencies):
        """Test chat with Google creative model"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "google-creative",
            "lang": "fr",
            "human_message": "Bonjour",
        }
        mock_dependencies["call_llm"].return_value = ("Réponse", {}, [])

        response = client.post("/chat", json=payload)

        assert response.status_code == 200
        assert response.json() == {"reply": "Réponse"}
        assert mock_dependencies["call_llm"].call_args.args[:2] == (
            "google-creative",
            "fr",
        )

    # ==================== Language Tests ====================

    def test_chat_french_language(self, base_required_fields, mock_dependencies):
        """Test chat with French language"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "fr",
            "human_message": "Comment allez-vous?",
        }

        response = client.post("/chat", json=payload)

        assert response.status_code == 200
        # The requested language must reach the LLM call.
        assert mock_dependencies["call_llm"].call_args.args[:2] == ("champ", "fr")

    # ==================== Request Validation Tests ====================

    def test_chat_missing_human_message(self, base_required_fields, mock_dependencies):
        """Test that missing human_message returns 422"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
        }

        response = client.post("/chat", json=payload)

        assert response.status_code == 422

    def test_chat_empty_human_message(self, base_required_fields, mock_dependencies):
        """Test that empty human_message is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "",
        }

        response = client.post("/chat", json=payload)

        assert response.status_code == 422

    def test_chat_invalid_model_type(self, base_required_fields, mock_dependencies):
        """Test that invalid model_type is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "invalid-model",
            "lang": "en",
            "human_message": "Hello",
        }

        response = client.post("/chat", json=payload)

        assert response.status_code == 422

    def test_chat_invalid_language(self, base_required_fields, mock_dependencies):
        """Test that invalid language is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "es",  # Not in Literal["en", "fr"]
            "human_message": "Hello",
        }

        response = client.post("/chat", json=payload)

        assert response.status_code == 422

    def test_chat_message_too_long(self, base_required_fields, mock_dependencies):
        """Test that message exceeding MAX_MESSAGE_LENGTH is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            # NOTE(review): assumes MAX_MESSAGE_LENGTH < 100_000 — confirm in main.
            "human_message": "x" * 100_000,
        }

        response = client.post("/chat", json=payload)

        assert response.status_code == 422

    def test_chat_sanitizes_html_in_message(
        self, base_required_fields, mock_dependencies
    ):
        """Test that HTML tags are removed from human_message"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            # The message must actually contain markup for this test to mean anything.
            "human_message": "<b>Hello</b>",
        }

        response = client.post("/chat", json=payload)

        # Should succeed with sanitized message
        assert response.status_code == 200
        # NOTE(review): assumes tags are stripped before main passes the text to
        # PIIFilter.sanitize — confirm against the request model's validator.
        (sanitizer_input,) = mock_dependencies["pii_filter"].sanitize.call_args.args
        assert "<b>" not in sanitizer_input
        assert "</b>" not in sanitizer_input
        assert "Hello" in sanitizer_input

    def test_chat_invalid_conversation_id(
        self, base_required_fields, mock_dependencies
    ):
        """Test that invalid conversation_id is rejected"""
        payload = {
            **base_required_fields,
            "conversation_id": "invalid@id!",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Hello",
        }

        response = client.post("/chat", json=payload)

        assert response.status_code == 422

    # ==================== Rate Limiting Tests ====================

    @pytest.mark.enable_rate_limit
    def test_chat_rate_limiting(self, valid_payload, mock_dependencies):
        """Test that rate limiting works (20 requests per minute)"""
        rate_limit_client = TestClient(app)

        # Make 21 rapid requests
        responses = [
            rate_limit_client.post("/chat", json=valid_payload) for _ in range(21)
        ]

        # The first 20 fit within the limit; the 21st should be rate limited.
        assert all(r.status_code != 429 for r in responses[:20])
        assert responses[-1].status_code == 429

    # ==================== Integration Tests ====================

    def test_chat_full_workflow(self, valid_payload, mock_dependencies):
        """Test complete chat workflow"""
        mock_conversation = [Mock(role="user", content="sanitized message")]
        mock_dependencies["conv_store"].add_human_message.return_value = (
            mock_conversation
        )
        mock_dependencies["doc_store"].get_document_contents.return_value = ["doc1"]
        mock_dependencies["call_llm"].return_value = (
            "Full response",
            {"key": "value"},
            ["ctx"],
        )

        response = client.post("/chat", json=valid_payload)

        assert response.status_code == 200
        assert response.json() == {"reply": "Full response"}

        # Each pipeline step ran exactly once (call order is not checked here).
        mock_dependencies["tracker"].update_session.assert_called_once()
        mock_dependencies["pii_filter"].sanitize.assert_called_once()
        mock_dependencies["conv_store"].add_human_message.assert_called_once()
        mock_dependencies["doc_store"].get_document_contents.assert_called_once()
        mock_dependencies["call_llm"].assert_called_once()
        mock_dependencies["conv_store"].add_assistant_reply.assert_called_once()

    def test_chat_with_documents(self, valid_payload, mock_dependencies):
        """Test that uploaded documents are forwarded to call_llm"""
        documents = ["Document content 1", "Document content 2"]
        mock_dependencies["doc_store"].get_document_contents.return_value = documents

        response = client.post("/chat", json=valid_payload)

        assert response.status_code == 200
        # call_llm receives (model_type, lang, conversation, documents).
        mock_dependencies["call_llm"].assert_called_once_with(
            "champ",
            "en",
            mock_dependencies["conv_store"].add_human_message.return_value,
            documents,
        )

    def test_chat_multiple_messages_same_conversation(
        self, base_required_fields, mock_dependencies
    ):
        """Test multiple messages in same conversation"""
        payload1 = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "First message",
        }
        payload2 = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Second message",
        }

        response1 = client.post("/chat", json=payload1)
        response2 = client.post("/chat", json=payload2)

        assert response1.status_code == 200
        assert response2.status_code == 200
        # Both messages are filtered, in order, and stored in the same conversation.
        assert mock_dependencies["pii_filter"].sanitize.call_args_list == [
            call("First message"),
            call("Second message"),
        ]
        assert mock_dependencies["conv_store"].add_human_message.call_args_list == [
            call("test-session-123", "conv-1", "sanitized message"),
            call("test-session-123", "conv-1", "sanitized message"),
        ]

    def test_chat_different_conversations_same_session(
        self, base_required_fields, mock_dependencies
    ):
        """Test different conversations in same session"""
        payload1 = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Message in conv 1",
        }
        payload2 = {
            **base_required_fields,
            "conversation_id": "conv-2",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Message in conv 2",
        }

        response1 = client.post("/chat", json=payload1)
        response2 = client.post("/chat", json=payload2)

        assert response1.status_code == 200
        assert response2.status_code == 200
        # Each message lands in its own conversation under the shared session.
        assert mock_dependencies["conv_store"].add_human_message.call_args_list == [
            call("test-session-123", "conv-1", "sanitized message"),
            call("test-session-123", "conv-2", "sanitized message"),
        ]

    # ==================== Edge Cases ====================

    def test_chat_special_characters_in_message(
        self, base_required_fields, mock_dependencies
    ):
        """Test message with special characters"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Hello! 你好 🎉 @#$%",
        }

        response = client.post("/chat", json=payload)

        assert response.status_code == 200

    def test_chat_multiline_message(self, base_required_fields, mock_dependencies):
        """Test message with newlines"""
        payload = {
            **base_required_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Line 1\nLine 2\nLine 3",
        }

        response = client.post("/chat", json=payload)

        assert response.status_code == 200

    def test_chat_empty_reply_from_llm(self, valid_payload, mock_dependencies):
        """Test handling of empty reply from LLM"""
        mock_dependencies["call_llm"].return_value = ("", {}, [])

        response = client.post("/chat", json=valid_payload)

        assert response.status_code == 200
        assert response.json() == {"reply": ""}