| import pytest
|
| from fastapi.testclient import TestClient
|
| from unittest.mock import Mock, patch
|
| from main import app
|
|
|
# Shared TestClient for the FastAPI ``app`` defined in ``main``; most tests in
# this module send their /chat requests through it.
client = TestClient(app)
|
|
|
|
|
class TestChatEndpoint:
    """Tests for the POST /chat endpoint.

    Every test that sends a request also uses the ``mock_dependencies``
    fixture, which patches the collaborators looked up in ``main`` (session
    tracker, PII filter, conversation and document stores, ``call_llm`` and
    ``log_event``) for the duration of the test.
    """

    @staticmethod
    def _make_payload(base_fields, **overrides):
        """Return a /chat request body built on top of *base_fields*.

        Defaults are conversation ``conv-1``, model ``champ``, language
        ``en`` and message ``"Hello"``; any keyword argument replaces (or
        adds) the field of the same name.
        """
        return {
            **base_fields,
            "conversation_id": "conv-1",
            "model_type": "champ",
            "lang": "en",
            "human_message": "Hello",
            **overrides,
        }

    @pytest.fixture
    def base_required_fields(self):
        """Base fields required by IdentifierBase and ProfileBase."""
        return {
            "user_id": "test-user-123",
            "participant_id": "participant-456",
            "session_id": "test-session-123",
            "consent": True,
            "age_group": "25-34",
            "gender": "M",
            "roles": ["patient"],
        }

    @pytest.fixture
    def valid_payload(self, base_required_fields):
        """A complete /chat request used by the success-path tests."""
        return self._make_payload(
            base_required_fields,
            conversation_id="conversation-abc",
            human_message="What should I do about a fever?",
        )

    @pytest.fixture
    def mock_dependencies(self):
        """Patch all external dependencies of ``main`` for one test.

        Yields a dict of the mocks keyed by short name. Defaults: the PII
        filter returns ``"sanitized message"``, ``add_human_message`` returns
        a one-message history, no documents are stored, and ``call_llm``
        returns the tuple ``("AI response", {}, [])``.
        """
        with (
            patch("main.session_tracker") as mock_tracker,
            patch("main.PIIFilter") as mock_pii_class,
            patch("main.session_conversation_store") as mock_conv_store,
            patch("main.session_document_store") as mock_doc_store,
            patch("main.call_llm") as mock_call_llm,
            patch("main.log_event") as mock_log_event,
        ):
            # main appears to instantiate PIIFilter itself (tests assert on the
            # instance's ``sanitize``), so hand back a pre-configured instance.
            mock_pii = Mock()
            mock_pii.sanitize.return_value = "sanitized message"
            mock_pii_class.return_value = mock_pii

            mock_conv_store.add_human_message.return_value = [
                Mock(role="user", content="sanitized message")
            ]
            mock_doc_store.get_document_contents.return_value = None
            mock_call_llm.return_value = ("AI response", {}, [])

            yield {
                "tracker": mock_tracker,
                "pii_filter": mock_pii,
                "conv_store": mock_conv_store,
                "doc_store": mock_doc_store,
                "call_llm": mock_call_llm,
                "log_event": mock_log_event,
            }

    # ------------------------------------------------------------------
    # Success path and collaborator wiring
    # ------------------------------------------------------------------

    def test_chat_success_non_streaming(self, valid_payload, mock_dependencies):
        """Test successful non-streaming chat response."""
        response = client.post("/chat", json=valid_payload)

        assert response.status_code == 200
        assert response.json() == {"reply": "AI response"}

    def test_chat_updates_session_tracker(self, valid_payload, mock_dependencies):
        """Test that session tracker is updated."""
        client.post("/chat", json=valid_payload)

        mock_dependencies["tracker"].update_session.assert_called_once_with(
            "test-session-123"
        )

    def test_chat_sanitizes_message(self, valid_payload, mock_dependencies):
        """Test that PII filter is applied to message."""
        client.post("/chat", json=valid_payload)

        mock_dependencies["pii_filter"].sanitize.assert_called_once_with(
            "What should I do about a fever?"
        )

    def test_chat_adds_human_message_to_store(self, valid_payload, mock_dependencies):
        """Test that sanitized message is added to conversation store."""
        client.post("/chat", json=valid_payload)

        mock_dependencies["conv_store"].add_human_message.assert_called_once_with(
            "test-session-123", "conversation-abc", "sanitized message"
        )

    def test_chat_retrieves_documents(self, valid_payload, mock_dependencies):
        """Test that documents are retrieved from document store."""
        client.post("/chat", json=valid_payload)

        mock_dependencies["doc_store"].get_document_contents.assert_called_once_with(
            "test-session-123"
        )

    def test_chat_calls_llm_with_correct_params(self, valid_payload, mock_dependencies):
        """Test that call_llm is invoked with correct parameters."""
        mock_conversation = [Mock()]
        conv_store = mock_dependencies["conv_store"]
        conv_store.add_human_message.return_value = mock_conversation
        mock_dependencies["doc_store"].get_document_contents.return_value = ["doc1"]

        client.post("/chat", json=valid_payload)

        mock_dependencies["call_llm"].assert_called_once_with(
            "champ", "en", mock_conversation, ["doc1"]
        )

    def test_chat_adds_assistant_reply_to_store(self, valid_payload, mock_dependencies):
        """Test that assistant reply is added to conversation store."""
        client.post("/chat", json=valid_payload)

        mock_dependencies["conv_store"].add_assistant_reply.assert_called_once_with(
            "test-session-123", "conversation-abc", "AI response"
        )

    # ------------------------------------------------------------------
    # Streaming and model / language selection
    # ------------------------------------------------------------------

    def test_chat_streaming_response(self, valid_payload, mock_dependencies):
        """Test that an async-generator result from call_llm is streamed.

        The chunks yielded by the generator must all appear in the body.
        """

        async def mock_stream():
            yield "Hello "
            yield "world"

        mock_dependencies["call_llm"].return_value = mock_stream()

        response = client.post("/chat", json=valid_payload)

        assert response.status_code == 200
        assert "Hello world" in response.content.decode()

    def test_chat_openai_model(self, base_required_fields, mock_dependencies):
        """Test chat with OpenAI model (streamed reply)."""
        payload = self._make_payload(base_required_fields, model_type="openai")

        async def mock_stream():
            yield "response"

        mock_dependencies["call_llm"].return_value = mock_stream()

        response = client.post("/chat", json=payload)

        assert response.status_code == 200
        assert "response" in response.content.decode()
        assert mock_dependencies["call_llm"].call_args.args[:2] == ("openai", "en")

    def test_chat_google_conservative_model(
        self, base_required_fields, mock_dependencies
    ):
        """Test chat with Google conservative model."""
        payload = self._make_payload(
            base_required_fields, model_type="google-conservative"
        )
        mock_dependencies["call_llm"].return_value = ("Response", {}, [])

        response = client.post("/chat", json=payload)

        assert response.status_code == 200
        assert response.json() == {"reply": "Response"}
        assert mock_dependencies["call_llm"].call_args.args[:2] == (
            "google-conservative",
            "en",
        )

    def test_chat_google_creative_model(self, base_required_fields, mock_dependencies):
        """Test chat with Google creative model."""
        payload = self._make_payload(
            base_required_fields,
            model_type="google-creative",
            lang="fr",
            human_message="Bonjour",
        )
        mock_dependencies["call_llm"].return_value = ("Réponse", {}, [])

        response = client.post("/chat", json=payload)

        assert response.status_code == 200
        assert response.json() == {"reply": "Réponse"}
        assert mock_dependencies["call_llm"].call_args.args[:2] == (
            "google-creative",
            "fr",
        )

    def test_chat_french_language(self, base_required_fields, mock_dependencies):
        """Test chat with French language."""
        payload = self._make_payload(
            base_required_fields, lang="fr", human_message="Comment allez-vous?"
        )

        response = client.post("/chat", json=payload)

        assert response.status_code == 200
        assert mock_dependencies["call_llm"].call_args.args[:2] == ("champ", "fr")

    # ------------------------------------------------------------------
    # Request validation
    # ------------------------------------------------------------------

    def test_chat_missing_human_message(self, base_required_fields, mock_dependencies):
        """Test that missing human_message returns 422."""
        payload = self._make_payload(base_required_fields)
        del payload["human_message"]

        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_empty_human_message(self, base_required_fields, mock_dependencies):
        """Test that empty human_message is rejected."""
        payload = self._make_payload(base_required_fields, human_message="")

        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_invalid_model_type(self, base_required_fields, mock_dependencies):
        """Test that invalid model_type is rejected."""
        payload = self._make_payload(base_required_fields, model_type="invalid-model")

        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_invalid_language(self, base_required_fields, mock_dependencies):
        """Test that invalid language is rejected."""
        payload = self._make_payload(base_required_fields, lang="es")

        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_message_too_long(self, base_required_fields, mock_dependencies):
        """Test that message exceeding MAX_MESSAGE_LENGTH is rejected."""
        payload = self._make_payload(base_required_fields, human_message="x" * 100000)

        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    def test_chat_sanitizes_html_in_message(
        self, base_required_fields, mock_dependencies
    ):
        """Test that HTML tags are removed from human_message.

        The request still succeeds, and the text that reaches the PII
        filter no longer contains the ``<script>`` tag.
        """
        payload = self._make_payload(
            base_required_fields,
            human_message="<script>alert('xss')</script>Hello",
        )

        response = client.post("/chat", json=payload)

        assert response.status_code == 200
        sanitize = mock_dependencies["pii_filter"].sanitize
        sanitize.assert_called_once()
        # NOTE(review): assumes HTML stripping happens during request
        # validation, i.e. before main hands the message to the PII filter.
        assert "<script>" not in sanitize.call_args.args[0]

    def test_chat_invalid_conversation_id(
        self, base_required_fields, mock_dependencies
    ):
        """Test that invalid conversation_id is rejected."""
        payload = self._make_payload(
            base_required_fields, conversation_id="invalid@id!"
        )

        response = client.post("/chat", json=payload)
        assert response.status_code == 422

    # ------------------------------------------------------------------
    # Rate limiting
    # ------------------------------------------------------------------

    @pytest.mark.enable_rate_limit
    def test_chat_rate_limiting(self, valid_payload, mock_dependencies):
        """Test that rate limiting works (20 requests per minute).

        Sends 21 requests in a burst and expects the last one to be
        rejected with HTTP 429.
        """
        # NOTE(review): this client wraps the same ``app`` as the module-level
        # ``client``; confirm whether a separate instance is actually needed.
        rate_limit_client = TestClient(app)

        responses = [
            rate_limit_client.post("/chat", json=valid_payload) for _ in range(21)
        ]

        assert responses[-1].status_code == 429

    # ------------------------------------------------------------------
    # End-to-end flows and edge cases
    # ------------------------------------------------------------------

    def test_chat_full_workflow(self, valid_payload, mock_dependencies):
        """Test complete chat workflow."""
        mock_conversation = [Mock(role="user", content="sanitized message")]
        conv_store = mock_dependencies["conv_store"]
        conv_store.add_human_message.return_value = mock_conversation
        mock_dependencies["doc_store"].get_document_contents.return_value = ["doc1"]
        mock_dependencies["call_llm"].return_value = (
            "Full response",
            {"key": "value"},
            ["ctx"],
        )

        response = client.post("/chat", json=valid_payload)

        assert response.status_code == 200
        assert response.json() == {"reply": "Full response"}

        mock_dependencies["tracker"].update_session.assert_called_once()
        mock_dependencies["pii_filter"].sanitize.assert_called_once()
        conv_store.add_human_message.assert_called_once()
        mock_dependencies["doc_store"].get_document_contents.assert_called_once()
        mock_dependencies["call_llm"].assert_called_once()
        conv_store.add_assistant_reply.assert_called_once()

    def test_chat_with_documents(self, valid_payload, mock_dependencies):
        """Test that uploaded documents are forwarded to call_llm."""
        documents = ["Document content 1", "Document content 2"]
        mock_dependencies["doc_store"].get_document_contents.return_value = documents

        response = client.post("/chat", json=valid_payload)

        assert response.status_code == 200
        conversation = mock_dependencies["conv_store"].add_human_message.return_value
        mock_dependencies["call_llm"].assert_called_once_with(
            "champ", "en", conversation, documents
        )

    def test_chat_multiple_messages_same_conversation(
        self, base_required_fields, mock_dependencies
    ):
        """Test multiple messages in same conversation."""
        payload1 = self._make_payload(base_required_fields, human_message="First message")
        payload2 = self._make_payload(base_required_fields, human_message="Second message")

        response1 = client.post("/chat", json=payload1)
        response2 = client.post("/chat", json=payload2)

        assert response1.status_code == 200
        assert response2.status_code == 200

        sanitize = mock_dependencies["pii_filter"].sanitize
        sanitize.assert_any_call("First message")
        sanitize.assert_any_call("Second message")
        add_human_message = mock_dependencies["conv_store"].add_human_message
        assert add_human_message.call_count == 2
        add_human_message.assert_any_call(
            "test-session-123", "conv-1", "sanitized message"
        )

    def test_chat_different_conversations_same_session(
        self, base_required_fields, mock_dependencies
    ):
        """Test different conversations in same session."""
        payload1 = self._make_payload(
            base_required_fields,
            conversation_id="conv-1",
            human_message="Message in conv 1",
        )
        payload2 = self._make_payload(
            base_required_fields,
            conversation_id="conv-2",
            human_message="Message in conv 2",
        )

        response1 = client.post("/chat", json=payload1)
        response2 = client.post("/chat", json=payload2)

        assert response1.status_code == 200
        assert response2.status_code == 200

        add_human_message = mock_dependencies["conv_store"].add_human_message
        add_human_message.assert_any_call(
            "test-session-123", "conv-1", "sanitized message"
        )
        add_human_message.assert_any_call(
            "test-session-123", "conv-2", "sanitized message"
        )

    def test_chat_special_characters_in_message(
        self, base_required_fields, mock_dependencies
    ):
        """Test message with special characters."""
        payload = self._make_payload(
            base_required_fields, human_message="Hello! 你好 🎉 @#$%"
        )

        response = client.post("/chat", json=payload)
        assert response.status_code == 200

    def test_chat_multiline_message(self, base_required_fields, mock_dependencies):
        """Test message with newlines."""
        payload = self._make_payload(
            base_required_fields, human_message="Line 1\nLine 2\nLine 3"
        )

        response = client.post("/chat", json=payload)
        assert response.status_code == 200

    def test_chat_empty_reply_from_llm(self, valid_payload, mock_dependencies):
        """Test handling of empty reply from LLM."""
        mock_dependencies["call_llm"].return_value = ("", {}, [])

        response = client.post("/chat", json=valid_payload)
        assert response.status_code == 200
        assert response.json() == {"reply": ""}
|
|
|