# NOTE: stray hosting-page artifacts ("Spaces: / Sleeping") removed from the top of this file.
import json

# These tests are integration-like and exercise chat endpoints.
import os
from unittest.mock import MagicMock, patch

import pytest

from app import app as flask_app


@pytest.fixture
def app():
    """Yield the Flask application under test.

    Restored decorator: this function is consumed as a pytest fixture by
    ``client`` below and by the test methods, so it must be registered
    with ``@pytest.fixture`` (it had been stripped from the file).
    """
    yield flask_app


@pytest.fixture
def client(app):
    """Return a Flask test client bound to the ``app`` fixture."""
    return app.test_client()
class TestChatEndpoint:
    """Test cases for the /chat endpoint"""

    # NOTE(review): every test that accepts mock_* arguments clearly lost its
    # @patch decorator stack when this file was mangled.  The targets below
    # ("app.<name>") are reconstructed from the parameter names and order
    # (bottom-most decorator supplies the first mock argument) -- TODO:
    # confirm each target against app.py before relying on these tests.

    @patch("app.RAGPipeline")  # -> mock_rag; TODO confirm target
    @patch("app.ResponseFormatter")  # -> mock_formatter; TODO confirm target
    @patch("app.LLMService")  # -> mock_llm; TODO confirm target
    @patch("app.SearchService")  # -> mock_search; TODO confirm target
    @patch("app.VectorStore")  # -> mock_vector; TODO confirm target
    @patch("app.EmbeddingService")  # -> mock_embedding; TODO confirm target
    def test_chat_endpoint_valid_request(
        self,
        mock_embedding,
        mock_vector,
        mock_search,
        mock_llm,
        mock_formatter,
        mock_rag,
        client,
    ):
        """Test chat endpoint with valid request"""
        # Mock the RAG pipeline response
        mock_response = {
            "answer": ("Based on the remote work policy, employees can work " "remotely up to 3 days per week."),
            "confidence": 0.85,
            "sources": [{"chunk_id": "123", "content": "Remote work policy content..."}],
            "citations": ["remote_work_policy.md"],
            "processing_time_ms": 1500,
        }
        # Setup mock instances
        mock_rag_instance = MagicMock()
        mock_rag_instance.generate_answer.return_value = mock_response
        mock_rag.return_value = mock_rag_instance
        mock_formatter_instance = MagicMock()
        mock_formatter_instance.format_api_response.return_value = {
            "status": "success",
            "answer": mock_response["answer"],
            "confidence": mock_response["confidence"],
            "sources": mock_response["sources"],
            "citations": mock_response["citations"],
        }
        mock_formatter.return_value = mock_formatter_instance
        # Mock LLMService.from_environment to return a mock instance
        mock_llm_instance = MagicMock()
        mock_llm.from_environment.return_value = mock_llm_instance
        request_data = {
            "message": "What is the remote work policy?",
            "include_sources": True,
        }
        response = client.post("/chat", data=json.dumps(request_data), content_type="application/json")
        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"
        assert "answer" in data
        assert "confidence" in data
        assert "sources" in data
        assert "citations" in data

    @patch("app.RAGPipeline")  # -> mock_rag; TODO confirm target
    @patch("app.ResponseFormatter")  # -> mock_formatter; TODO confirm target
    @patch("app.LLMService")  # -> mock_llm; TODO confirm target
    @patch("app.SearchService")  # -> mock_search; TODO confirm target
    @patch("app.VectorStore")  # -> mock_vector; TODO confirm target
    @patch("app.EmbeddingService")  # -> mock_embedding; TODO confirm target
    def test_chat_endpoint_minimal_request(
        self,
        mock_embedding,
        mock_vector,
        mock_search,
        mock_llm,
        mock_formatter,
        mock_rag,
        client,
    ):
        """Test chat endpoint with minimal request (only message)"""
        mock_response = {
            "answer": ("Employee benefits include health insurance, " "retirement plans, and PTO."),
            "confidence": 0.78,
            "sources": [],
            "citations": ["employee_benefits_guide.md"],
            "processing_time_ms": 1200,
        }
        # Setup mock instances
        mock_rag_instance = MagicMock()
        mock_rag_instance.generate_answer.return_value = mock_response
        mock_rag.return_value = mock_rag_instance
        mock_formatter_instance = MagicMock()
        mock_formatter_instance.format_api_response.return_value = {
            "status": "success",
            "answer": mock_response["answer"],
        }
        mock_formatter.return_value = mock_formatter_instance
        mock_llm.from_environment.return_value = MagicMock()
        request_data = {"message": "What are the employee benefits?"}
        response = client.post("/chat", data=json.dumps(request_data), content_type="application/json")
        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"

    def test_chat_endpoint_missing_message(self, client):
        """Test chat endpoint with missing message parameter"""
        request_data = {"include_sources": True}
        response = client.post("/chat", data=json.dumps(request_data), content_type="application/json")
        assert response.status_code == 400
        data = response.get_json()
        assert data["status"] == "error"
        assert "message parameter is required" in data["message"]

    def test_chat_endpoint_empty_message(self, client):
        """Test chat endpoint with empty message"""
        request_data = {"message": ""}
        response = client.post("/chat", data=json.dumps(request_data), content_type="application/json")
        assert response.status_code == 400
        data = response.get_json()
        assert data["status"] == "error"
        assert "non-empty string" in data["message"]

    def test_chat_endpoint_non_string_message(self, client):
        """Test chat endpoint with non-string message"""
        request_data = {"message": 123}
        response = client.post("/chat", data=json.dumps(request_data), content_type="application/json")
        assert response.status_code == 400
        data = response.get_json()
        assert data["status"] == "error"
        assert "non-empty string" in data["message"]

    def test_chat_endpoint_non_json_request(self, client):
        """Test chat endpoint with non-JSON request"""
        response = client.post("/chat", data="not json", content_type="text/plain")
        assert response.status_code == 400
        data = response.get_json()
        assert data["status"] == "error"
        assert "application/json" in data["message"]

    def test_chat_endpoint_no_llm_config(self, client):
        """Test chat endpoint with no LLM configuration"""
        # Clearing the environment removes any LLM provider settings, so the
        # endpoint should report a service-unavailable configuration error.
        with patch.dict(os.environ, {}, clear=True):
            request_data = {"message": "What is the policy?"}
            response = client.post("/chat", data=json.dumps(request_data), content_type="application/json")
            assert response.status_code == 503
            data = response.get_json()
            assert data["status"] == "error"
            assert "LLM service configuration error" in data["message"]

    @patch("app.RAGPipeline")  # -> mock_rag; TODO confirm target
    @patch("app.ResponseFormatter")  # -> mock_formatter; TODO confirm target
    @patch("app.LLMService")  # -> mock_llm; TODO confirm target
    @patch("app.SearchService")  # -> mock_search; TODO confirm target
    @patch("app.VectorStore")  # -> mock_vector; TODO confirm target
    @patch("app.EmbeddingService")  # -> mock_embedding; TODO confirm target
    def test_chat_endpoint_with_conversation_id(
        self,
        mock_embedding,
        mock_vector,
        mock_search,
        mock_llm,
        mock_formatter,
        mock_rag,
        client,
    ):
        """Test chat endpoint with conversation_id parameter"""
        mock_response = {
            "answer": "The PTO policy allows 15 days of vacation annually.",
            "confidence": 0.9,
            "sources": [],
            "citations": ["pto_policy.md"],
            "processing_time_ms": 1100,
        }
        # Setup mock instances
        mock_rag_instance = MagicMock()
        mock_rag_instance.generate_answer.return_value = mock_response
        mock_rag.return_value = mock_rag_instance
        mock_formatter_instance = MagicMock()
        mock_formatter_instance.format_chat_response.return_value = {
            "status": "success",
            "answer": mock_response["answer"],
        }
        mock_formatter.return_value = mock_formatter_instance
        mock_llm.from_environment.return_value = MagicMock()
        request_data = {
            "message": "What is the PTO policy?",
            "conversation_id": "conv_123",
            "include_sources": False,
        }
        response = client.post("/chat", data=json.dumps(request_data), content_type="application/json")
        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"

    @patch("app.RAGPipeline")  # -> mock_rag; TODO confirm target
    @patch("app.ResponseFormatter")  # -> mock_formatter; TODO confirm target
    @patch("app.LLMService")  # -> mock_llm; TODO confirm target
    @patch("app.SearchService")  # -> mock_search; TODO confirm target
    @patch("app.VectorStore")  # -> mock_vector; TODO confirm target
    @patch("app.EmbeddingService")  # -> mock_embedding; TODO confirm target
    def test_chat_endpoint_with_debug(
        self,
        mock_embedding,
        mock_vector,
        mock_search,
        mock_llm,
        mock_formatter,
        mock_rag,
        client,
    ):
        """Test chat endpoint with debug information"""
        mock_response = {
            "answer": "The security policy requires 2FA authentication.",
            "confidence": 0.95,
            "sources": [{"chunk_id": "456", "content": "Security requirements..."}],
            "citations": ["information_security_policy.md"],
            "processing_time_ms": 1800,
            "search_results_count": 5,
            "context_length": 2048,
        }
        # Setup mock instances
        mock_rag_instance = MagicMock()
        mock_rag_instance.generate_answer.return_value = mock_response
        mock_rag.return_value = mock_rag_instance
        mock_formatter_instance = MagicMock()
        mock_formatter_instance.format_api_response.return_value = {
            "status": "success",
            "answer": mock_response["answer"],
            "debug": {"processing_time": 1800},
        }
        mock_formatter.return_value = mock_formatter_instance
        mock_llm.from_environment.return_value = MagicMock()
        request_data = {
            "message": "What are the security requirements?",
            "include_debug": True,
        }
        response = client.post("/chat", data=json.dumps(request_data), content_type="application/json")
        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"

    @patch("app.get_rag_pipeline")  # -> mock_get_rag_pipeline; TODO confirm target
    def test_chat_endpoint_refuses_off_corpus_questions(self, mock_get_rag_pipeline, client):
        """Test that /chat refuses to answer questions outside the corpus"""
        # Mock RAG pipeline to simulate off-corpus refusal
        mock_pipeline = MagicMock()
        mock_response = MagicMock()
        mock_response.answer = "I can only answer questions about our company policies and procedures."
        mock_response.confidence = 0.0
        mock_response.sources = []
        mock_pipeline.generate_answer.return_value = mock_response
        mock_get_rag_pipeline.return_value = mock_pipeline
        # Ask an off-corpus question
        request_data = {"message": "What is the capital of France?"}
        response = client.post("/chat", data=json.dumps(request_data), content_type="application/json")
        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"
        # Should contain refusal language
        assert "I can only answer" in data["answer"] or "only answer about" in data.get("response", "")

    @patch("app.get_rag_pipeline")  # -> mock_get_rag_pipeline; TODO confirm target
    def test_chat_endpoint_limits_output_length(self, mock_get_rag_pipeline, client):
        """Test that /chat limits output length according to config"""
        # Create a response longer than typical max_response_length (1000 chars) with ellipsis to simulate truncation
        long_response_base = "This is a very long policy response that should be truncated. " * 16  # ~1024 chars
        truncated_response = long_response_base[:997] + "..."  # Simulate RAG pipeline truncation
        # Mock the RAG pipeline to return a truncated response as the actual pipeline would
        mock_pipeline = MagicMock()
        mock_response = MagicMock()
        mock_response.success = True
        mock_response.answer = truncated_response
        mock_response.confidence = 0.9
        mock_response.sources = [{"source": "policy.md", "content": "test"}]
        mock_response.processing_time = 0.5
        mock_response.context_length = 500
        mock_response.llm_provider = "test"
        mock_response.llm_model = "test"
        mock_response.search_results_count = 1
        mock_pipeline.generate_answer.return_value = mock_response
        mock_get_rag_pipeline.return_value = mock_pipeline
        request_data = {"message": "Tell me about the vacation policy in great detail"}
        response = client.post("/chat", data=json.dumps(request_data), content_type="application/json")
        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"
        # Response should be truncated - check for ellipsis or reasonable length
        response_text = data.get("answer", data.get("response", ""))
        assert len(response_text) <= 1000 or response_text.endswith(
            "..."
        ), f"Response should be limited or truncated, got {len(response_text)} chars"

    @patch("app.get_rag_pipeline")  # -> mock_get_rag_pipeline; TODO confirm target
    def test_chat_endpoint_always_includes_citations(self, mock_get_rag_pipeline, client):
        """Test that /chat always includes at least one source citation"""
        # Mock RAG pipeline to return response with sources
        mock_pipeline = MagicMock()
        mock_response = MagicMock()
        mock_response.answer = "According to our policy, employees get 15 days of vacation."
        mock_response.confidence = 0.9
        mock_response.sources = [{"source": "vacation_policy.md", "content": "Vacation policy details"}]
        mock_pipeline.generate_answer.return_value = mock_response
        mock_get_rag_pipeline.return_value = mock_pipeline
        request_data = {"message": "How many vacation days do employees get?"}
        response = client.post("/chat", data=json.dumps(request_data), content_type="application/json")
        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"
        # Should have citations in sources array
        assert "sources" in data
        assert len(data["sources"]) > 0
        # Response text should contain citation markers or sources should be populated
        response_text = data.get("answer", data.get("response", ""))
        has_citation_in_text = "[Source:" in response_text or "[source:" in response_text
        has_sources_array = len(data.get("sources", [])) > 0
        assert (
            has_citation_in_text or has_sources_array
        ), "Response must include citations either in text or sources array"
class TestChatHealthEndpoint:
    """Test cases for the /chat/health endpoint"""

    # NOTE(review): as in TestChatEndpoint, the mock_* parameters imply
    # @patch decorator stacks that were lost from this file.  The targets
    # below are reconstructed guesses -- TODO confirm against app.py.

    def _clear_app_config(self, app):
        """Reset cached service objects on the app so tests start clean."""
        # Clear any mock state that might persist between tests
        # Clear app cache to ensure clean state
        app.config["RAG_PIPELINE"] = None
        app.config["INGESTION_PIPELINE"] = None
        app.config["SEARCH_SERVICE"] = None

    @patch("app.LLMService.from_environment")  # -> mock_llm_service; TODO confirm target
    @patch("app.RAGPipeline.health_check")  # -> mock_health_check; TODO confirm target
    def test_chat_health_healthy(self, mock_health_check, mock_llm_service, client):
        """Test chat health endpoint when all services are healthy"""
        mock_health_data = {
            "pipeline": "healthy",
            "components": {
                "search_service": {"status": "healthy"},
                "llm_service": {"status": "healthy"},
                "vector_db": {"status": "healthy"},
            },
        }
        mock_health_check.return_value = mock_health_data
        # Return a simple object instead of MagicMock to avoid serialization issues
        mock_llm_service.return_value = object()
        response = client.get("/chat/health")
        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"

    @patch("app.LLMService.from_environment")  # -> mock_llm_service; TODO confirm target
    @patch("app.RAGPipeline.health_check")  # -> mock_health_check; TODO confirm target
    def test_chat_health_degraded(self, mock_health_check, mock_llm_service, client):
        """Test chat health endpoint when services are degraded"""
        mock_health_data = {
            "pipeline": "degraded",
            "components": {
                "search_service": {"status": "healthy"},
                "llm_service": {"status": "degraded", "warning": "High latency"},
                "vector_db": {"status": "healthy"},
            },
        }
        mock_health_check.return_value = mock_health_data
        # Return a simple object instead of MagicMock to avoid serialization issues
        mock_llm_service.return_value = object()
        response = client.get("/chat/health")
        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"

    def test_chat_health_no_llm_config(self, client):
        """Test chat health endpoint with no LLM configuration"""
        # With a scrubbed environment the endpoint should surface a 503
        # configuration error rather than a healthy report.
        with patch.dict(os.environ, {}, clear=True):
            response = client.get("/chat/health")
            assert response.status_code == 503
            data = response.get_json()
            assert data["status"] == "error"
            assert "LLM" in data["message"] and "configuration error" in data["message"]

    @patch("app.LLMService.from_environment")  # -> mock_llm_service; TODO confirm target
    @patch("app.RAGPipeline.health_check")  # -> mock_health_check; TODO confirm target
    @patch("app.get_rag_pipeline")  # -> mock_get_rag_pipeline; TODO confirm target
    def test_chat_health_unhealthy(self, mock_get_rag_pipeline, mock_health_check, mock_llm_service, client):
        """Test chat health endpoint when services are unhealthy"""
        mock_health_data = {
            "pipeline": "unhealthy",
            "components": {
                "search_service": {
                    "status": "unhealthy",
                    "error": "Database connection failed",
                },
                "llm_service": {"status": "unhealthy", "error": "API unreachable"},
                "vector_db": {"status": "unhealthy"},
            },
        }
        mock_health_check.return_value = mock_health_data
        # Return a simple object instead of MagicMock to avoid serialization issues
        mock_llm_service.return_value = object()
        # Create a mock pipeline that has the health_check method
        mock_pipeline = MagicMock()
        mock_pipeline.health_check.return_value = mock_health_data
        mock_get_rag_pipeline.return_value = mock_pipeline
        response = client.get("/chat/health")
        assert response.status_code == 503
        data = response.get_json()
        assert data["status"] == "success"  # Still returns success, but 503 status code