# ai-engineering-project/tests/test_chat_endpoint.py
# GitHub Action
# Clean deployment without binary files
# f884e6e
import json
# These tests are integration-like and exercise chat endpoints.
import os
from unittest.mock import MagicMock, patch
import pytest
from app import app as flask_app
@pytest.fixture
def app():
    """Yield the Flask application object imported from ``app`` for use in tests."""
    # Nothing follows the yield, so no teardown runs after the test finishes.
    yield flask_app
@pytest.fixture
def client(app):
    """Return a test client created from the ``app`` fixture."""
    test_client = app.test_client()
    return test_client
class TestChatEndpoint:
    """Test cases for the /chat endpoint"""

    @staticmethod
    def _post_chat(client, payload):
        """POST ``payload`` to /chat as a JSON body and return the response."""
        return client.post("/chat", data=json.dumps(payload), content_type="application/json")

    @staticmethod
    def _wire_mocks(
        mock_rag,
        mock_formatter,
        mock_llm,
        rag_result,
        formatted,
        formatter_method="format_api_response",
    ):
        """Configure the patched RAG pipeline, response formatter and LLM service classes.

        Args:
            mock_rag: Patched ``RAGPipeline`` class; instances will return
                ``rag_result`` from ``generate_answer``.
            mock_formatter: Patched ``ResponseFormatter`` class; instances will
                return ``formatted`` from the method named by ``formatter_method``.
            mock_llm: Patched ``LLMService`` class; ``from_environment`` will
                return a fresh ``MagicMock``.
            rag_result: Value returned by the mocked ``generate_answer``.
            formatted: Value returned by the mocked formatter method.
            formatter_method: Name of the formatter method to stub.
        """
        pipeline = MagicMock()
        pipeline.generate_answer.return_value = rag_result
        mock_rag.return_value = pipeline

        formatter = MagicMock()
        getattr(formatter, formatter_method).return_value = formatted
        mock_formatter.return_value = formatter

        mock_llm.from_environment.return_value = MagicMock()

    # NOTE: stacked @patch decorators inject mocks bottom-up, so the decorator
    # closest to ``def`` supplies the first mock argument.
    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
    @patch("src.rag.rag_pipeline.RAGPipeline")
    @patch("src.rag.response_formatter.ResponseFormatter")
    @patch("src.llm.llm_service.LLMService")
    @patch("src.search.search_service.SearchService")
    @patch("src.vector_store.vector_db.VectorDatabase")
    @patch("src.embedding.embedding_service.EmbeddingService")
    def test_chat_endpoint_valid_request(
        self,
        mock_embedding,
        mock_vector,
        mock_search,
        mock_llm,
        mock_formatter,
        mock_rag,
        client,
    ):
        """Test chat endpoint with valid request"""
        mock_response = {
            "answer": ("Based on the remote work policy, employees can work " "remotely up to 3 days per week."),
            "confidence": 0.85,
            "sources": [{"chunk_id": "123", "content": "Remote work policy content..."}],
            "citations": ["remote_work_policy.md"],
            "processing_time_ms": 1500,
        }
        self._wire_mocks(
            mock_rag,
            mock_formatter,
            mock_llm,
            mock_response,
            {
                "status": "success",
                "answer": mock_response["answer"],
                "confidence": mock_response["confidence"],
                "sources": mock_response["sources"],
                "citations": mock_response["citations"],
            },
        )

        response = self._post_chat(
            client,
            {
                "message": "What is the remote work policy?",
                "include_sources": True,
            },
        )

        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"
        assert "answer" in data
        assert "confidence" in data
        assert "sources" in data
        assert "citations" in data

    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
    @patch("src.rag.rag_pipeline.RAGPipeline")
    @patch("src.rag.response_formatter.ResponseFormatter")
    @patch("src.llm.llm_service.LLMService")
    @patch("src.search.search_service.SearchService")
    @patch("src.vector_store.vector_db.VectorDatabase")
    @patch("src.embedding.embedding_service.EmbeddingService")
    def test_chat_endpoint_minimal_request(
        self,
        mock_embedding,
        mock_vector,
        mock_search,
        mock_llm,
        mock_formatter,
        mock_rag,
        client,
    ):
        """Test chat endpoint with minimal request (only message)"""
        mock_response = {
            "answer": ("Employee benefits include health insurance, " "retirement plans, and PTO."),
            "confidence": 0.78,
            "sources": [],
            "citations": ["employee_benefits_guide.md"],
            "processing_time_ms": 1200,
        }
        self._wire_mocks(
            mock_rag,
            mock_formatter,
            mock_llm,
            mock_response,
            {
                "status": "success",
                "answer": mock_response["answer"],
            },
        )

        response = self._post_chat(client, {"message": "What are the employee benefits?"})

        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"

    def test_chat_endpoint_missing_message(self, client):
        """Test chat endpoint with missing message parameter"""
        response = self._post_chat(client, {"include_sources": True})

        assert response.status_code == 400
        data = response.get_json()
        assert data["status"] == "error"
        assert "message parameter is required" in data["message"]

    def test_chat_endpoint_empty_message(self, client):
        """Test chat endpoint with empty message"""
        response = self._post_chat(client, {"message": ""})

        assert response.status_code == 400
        data = response.get_json()
        assert data["status"] == "error"
        assert "non-empty string" in data["message"]

    def test_chat_endpoint_non_string_message(self, client):
        """Test chat endpoint with non-string message"""
        response = self._post_chat(client, {"message": 123})

        assert response.status_code == 400
        data = response.get_json()
        assert data["status"] == "error"
        assert "non-empty string" in data["message"]

    def test_chat_endpoint_non_json_request(self, client):
        """Test chat endpoint with non-JSON request"""
        # Sent directly (not via _post_chat) because the body and content type are deliberately not JSON.
        response = client.post("/chat", data="not json", content_type="text/plain")

        assert response.status_code == 400
        data = response.get_json()
        assert data["status"] == "error"
        assert "application/json" in data["message"]

    def test_chat_endpoint_no_llm_config(self, client):
        """Test chat endpoint with no LLM configuration"""
        # The request runs with an emptied environment; it is restored when the block exits.
        with patch.dict(os.environ, {}, clear=True):
            response = self._post_chat(client, {"message": "What is the policy?"})

        assert response.status_code == 503
        data = response.get_json()
        assert data["status"] == "error"
        assert "LLM service configuration error" in data["message"]

    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
    @patch("src.rag.rag_pipeline.RAGPipeline")
    @patch("src.rag.response_formatter.ResponseFormatter")
    @patch("src.llm.llm_service.LLMService")
    @patch("src.search.search_service.SearchService")
    @patch("src.vector_store.vector_db.VectorDatabase")
    @patch("src.embedding.embedding_service.EmbeddingService")
    def test_chat_endpoint_with_conversation_id(
        self,
        mock_embedding,
        mock_vector,
        mock_search,
        mock_llm,
        mock_formatter,
        mock_rag,
        client,
    ):
        """Test chat endpoint with conversation_id parameter"""
        mock_response = {
            "answer": "The PTO policy allows 15 days of vacation annually.",
            "confidence": 0.9,
            "sources": [],
            "citations": ["pto_policy.md"],
            "processing_time_ms": 1100,
        }
        # This test stubs format_chat_response, unlike its siblings which stub format_api_response.
        self._wire_mocks(
            mock_rag,
            mock_formatter,
            mock_llm,
            mock_response,
            {
                "status": "success",
                "answer": mock_response["answer"],
            },
            formatter_method="format_chat_response",
        )

        response = self._post_chat(
            client,
            {
                "message": "What is the PTO policy?",
                "conversation_id": "conv_123",
                "include_sources": False,
            },
        )

        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"

    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
    @patch("src.rag.rag_pipeline.RAGPipeline")
    @patch("src.rag.response_formatter.ResponseFormatter")
    @patch("src.llm.llm_service.LLMService")
    @patch("src.search.search_service.SearchService")
    @patch("src.vector_store.vector_db.VectorDatabase")
    @patch("src.embedding.embedding_service.EmbeddingService")
    def test_chat_endpoint_with_debug(
        self,
        mock_embedding,
        mock_vector,
        mock_search,
        mock_llm,
        mock_formatter,
        mock_rag,
        client,
    ):
        """Test chat endpoint with debug information"""
        mock_response = {
            "answer": "The security policy requires 2FA authentication.",
            "confidence": 0.95,
            "sources": [{"chunk_id": "456", "content": "Security requirements..."}],
            "citations": ["information_security_policy.md"],
            "processing_time_ms": 1800,
            "search_results_count": 5,
            "context_length": 2048,
        }
        self._wire_mocks(
            mock_rag,
            mock_formatter,
            mock_llm,
            mock_response,
            {
                "status": "success",
                "answer": mock_response["answer"],
                "debug": {"processing_time": 1800},
            },
        )

        response = self._post_chat(
            client,
            {
                "message": "What are the security requirements?",
                "include_debug": True,
            },
        )

        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"

    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
    @patch("src.routes.main_routes.get_rag_pipeline")
    def test_chat_endpoint_refuses_off_corpus_questions(self, mock_get_rag_pipeline, client):
        """Test that /chat refuses to answer questions outside the corpus"""
        # Mock RAG pipeline to simulate off-corpus refusal
        mock_response = MagicMock()
        mock_response.answer = "I can only answer questions about our company policies and procedures."
        mock_response.confidence = 0.0
        mock_response.sources = []
        mock_pipeline = MagicMock()
        mock_pipeline.generate_answer.return_value = mock_response
        mock_get_rag_pipeline.return_value = mock_pipeline

        # Ask an off-corpus question
        response = self._post_chat(client, {"message": "What is the capital of France?"})

        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"
        # Should contain refusal language. Both keys are read with .get so a payload
        # that only carries "response" reaches the fallback instead of raising KeyError.
        answer_text = data.get("answer", "")
        assert "I can only answer" in answer_text or "only answer about" in data.get("response", "")

    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
    @patch("src.routes.main_routes.get_rag_pipeline")
    def test_chat_endpoint_limits_output_length(self, mock_get_rag_pipeline, client):
        """Test that /chat limits output length according to config"""
        # Build a response longer than the typical max_response_length (1000 chars),
        # then cut it to 997 chars plus an ellipsis to simulate pipeline truncation.
        # The 62-char sentence is repeated 17 times (1054 chars); 16 repeats would be
        # only 992 chars and the slice below would not truncate anything.
        long_response_base = "This is a very long policy response that should be truncated. " * 17
        assert len(long_response_base) > 1000  # guard the fixture itself
        truncated_response = long_response_base[:997] + "..."

        # Mock the RAG pipeline to return a truncated response as the actual pipeline would
        mock_response = MagicMock()
        mock_response.success = True
        mock_response.answer = truncated_response
        mock_response.confidence = 0.9
        mock_response.sources = [{"source": "policy.md", "content": "test"}]
        mock_response.processing_time = 0.5
        mock_response.context_length = 500
        mock_response.llm_provider = "test"
        mock_response.llm_model = "test"
        mock_response.search_results_count = 1
        mock_pipeline = MagicMock()
        mock_pipeline.generate_answer.return_value = mock_response
        mock_get_rag_pipeline.return_value = mock_pipeline

        response = self._post_chat(client, {"message": "Tell me about the vacation policy in great detail"})

        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"
        # Response should be truncated - check for ellipsis or reasonable length
        response_text = data.get("answer", data.get("response", ""))
        assert len(response_text) <= 1000 or response_text.endswith(
            "..."
        ), f"Response should be limited or truncated, got {len(response_text)} chars"

    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
    @patch("src.routes.main_routes.get_rag_pipeline")
    def test_chat_endpoint_always_includes_citations(self, mock_get_rag_pipeline, client):
        """Test that /chat always includes at least one source citation"""
        # Mock RAG pipeline to return response with sources
        mock_response = MagicMock()
        mock_response.answer = "According to our policy, employees get 15 days of vacation."
        mock_response.confidence = 0.9
        mock_response.sources = [{"source": "vacation_policy.md", "content": "Vacation policy details"}]
        mock_pipeline = MagicMock()
        mock_pipeline.generate_answer.return_value = mock_response
        mock_get_rag_pipeline.return_value = mock_pipeline

        response = self._post_chat(client, {"message": "How many vacation days do employees get?"})

        assert response.status_code == 200
        data = response.get_json()
        assert data["status"] == "success"
        # Should have citations in sources array. A former follow-up check
        # ("citation marker in text OR sources populated") was always true once
        # this passed, so it was removed.
        assert "sources" in data
        assert len(data["sources"]) > 0, "Response must include citations in the sources array"
class TestChatHealthEndpoint:
    """Test cases for the /chat/health endpoint"""

    @pytest.fixture(autouse=True)
    def _clear_app_config(self, app):
        """Set the cached service entries on ``app.config`` to ``None`` before each test."""
        # Clear any mock state that might persist between tests
        for cache_key in ("RAG_PIPELINE", "INGESTION_PIPELINE", "SEARCH_SERVICE"):
            app.config[cache_key] = None

    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
    @patch("src.llm.llm_service.LLMService.from_environment")
    @patch("src.rag.rag_pipeline.RAGPipeline.health_check")
    def test_chat_health_healthy(self, mock_health_check, mock_llm_service, client):
        """Test chat health endpoint when all services are healthy"""
        component_states = {
            "search_service": {"status": "healthy"},
            "llm_service": {"status": "healthy"},
            "vector_db": {"status": "healthy"},
        }
        mock_health_check.return_value = {"pipeline": "healthy", "components": component_states}
        # Return a simple object instead of MagicMock to avoid serialization issues
        mock_llm_service.return_value = object()

        health_response = client.get("/chat/health")

        assert health_response.status_code == 200
        body = health_response.get_json()
        assert body["status"] == "success"

    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
    @patch("src.llm.llm_service.LLMService.from_environment")
    @patch("src.rag.rag_pipeline.RAGPipeline.health_check")
    def test_chat_health_degraded(self, mock_health_check, mock_llm_service, client):
        """Test chat health endpoint when services are degraded"""
        component_states = {
            "search_service": {"status": "healthy"},
            "llm_service": {"status": "degraded", "warning": "High latency"},
            "vector_db": {"status": "healthy"},
        }
        mock_health_check.return_value = {"pipeline": "degraded", "components": component_states}
        # Return a simple object instead of MagicMock to avoid serialization issues
        mock_llm_service.return_value = object()

        health_response = client.get("/chat/health")

        assert health_response.status_code == 200
        body = health_response.get_json()
        assert body["status"] == "success"

    def test_chat_health_no_llm_config(self, client):
        """Test chat health endpoint with no LLM configuration"""
        # The request is issued while the environment is emptied.
        with patch.dict(os.environ, {}, clear=True):
            health_response = client.get("/chat/health")

        assert health_response.status_code == 503
        body = health_response.get_json()
        assert body["status"] == "error"
        error_text = body["message"]
        assert "LLM" in error_text and "configuration error" in error_text

    @patch.dict(os.environ, {"OPENROUTER_API_KEY": "test_key"})
    @patch("src.llm.llm_service.LLMService.from_environment")
    @patch("src.rag.rag_pipeline.RAGPipeline.health_check")
    @patch("src.routes.main_routes.get_rag_pipeline")
    def test_chat_health_unhealthy(self, mock_get_rag_pipeline, mock_health_check, mock_llm_service, client):
        """Test chat health endpoint when services are unhealthy"""
        unhealthy_report = {
            "pipeline": "unhealthy",
            "components": {
                "search_service": {
                    "status": "unhealthy",
                    "error": "Database connection failed",
                },
                "llm_service": {"status": "unhealthy", "error": "API unreachable"},
                "vector_db": {"status": "unhealthy"},
            },
        }
        mock_health_check.return_value = unhealthy_report
        # Return a simple object instead of MagicMock to avoid serialization issues
        mock_llm_service.return_value = object()

        # Create a mock pipeline that has the health_check method
        stub_pipeline = MagicMock()
        stub_pipeline.health_check.return_value = unhealthy_report
        mock_get_rag_pipeline.return_value = stub_pipeline

        health_response = client.get("/chat/health")

        assert health_response.status_code == 503
        body = health_response.get_json()
        assert body["status"] == "success"  # Still returns success, but 503 status code