# scratch_chat/tests/integration/test_chat_api.py
# Uploaded by WebashalarForML ("Upload 178 files"), commit 330b6e4 (verified).
"""Integration tests for Chat API endpoints."""
import json
import pytest
from datetime import datetime
from unittest.mock import patch, MagicMock
from app import create_app, db
from chat_agent.models.chat_session import ChatSession
from chat_agent.models.message import Message
from chat_agent.models.language_context import LanguageContext
@pytest.fixture
def app():
    """Create the Flask app under the ``testing`` config with a fresh schema.

    All tables are created before the test and dropped afterwards. The whole
    lifetime of the fixture runs inside a pushed application context.

    Yields:
        The configured Flask application.
    """
    app = create_app('testing')
    with app.app_context():
        db.create_all()
        yield app
        # Release the scoped session (and any transaction it still holds)
        # before dropping tables; an open transaction can block drop_all on
        # some database backends.
        db.session.remove()
        db.drop_all()
@pytest.fixture
def client(app):
    """Return a Flask test client bound to the ``app`` fixture."""
    test_client = app.test_client()
    return test_client
@pytest.fixture
def auth_headers():
    """Headers that identify the caller as ``test-user-123`` and send JSON."""
    headers = {'X-User-ID': 'test-user-123'}
    headers['Content-Type'] = 'application/json'
    return headers
@pytest.fixture
def sample_session(app):
    """Create a Python-language ChatSession owned by ``test-user-123``.

    A LanguageContext for ``'python'`` is created for the same session id.
    The session is yielded while this fixture's own application context is
    still pushed.
    """
    with app.app_context():
        chat_session = ChatSession.create_session(
            user_id='test-user-123',
            language='python',
            session_metadata={'test': True},
        )
        LanguageContext.create_context(chat_session.id, 'python')
        yield chat_session
@pytest.fixture
def sample_messages(app, sample_session):
    """Add a user question and an assistant reply to ``sample_session``.

    Both messages are added to ``db.session`` and committed together.

    Yields:
        ``[user_message, assistant_message]`` in creation order.
    """
    with app.app_context():
        question = Message.create_user_message(
            session_id=sample_session.id,
            content="Hello, can you help me with Python?",
            language='python',
        )
        db.session.add(question)

        answer = Message.create_assistant_message(
            session_id=sample_session.id,
            content="Of course! I'd be happy to help you with Python programming.",
            language='python',
            message_metadata={'tokens': 15},
        )
        db.session.add(answer)

        # NOTE(review): history tests expect the user message first; if both
        # rows end up with identical timestamps the API's ordering may not be
        # deterministic — confirm how history is sorted.
        db.session.commit()
        yield [question, answer]
class TestSessionManagement:
    """Session create/read/list/delete endpoints under ``/api/v1/chat/sessions``."""

    _SESSIONS_URL = '/api/v1/chat/sessions'

    @staticmethod
    def _payload(response):
        """Decode the JSON body of a test-client response."""
        return json.loads(response.data)

    @staticmethod
    def _intruder_headers():
        """Headers for a caller who does not own the fixture session."""
        return {
            'X-User-ID': 'different-user',
            'Content-Type': 'application/json'
        }

    def test_create_session_success(self, client, auth_headers):
        """A valid request yields 201 and echoes owner, language and metadata."""
        request_body = json.dumps({'language': 'python', 'metadata': {'source': 'test'}})
        response = client.post(self._SESSIONS_URL, data=request_body, headers=auth_headers)

        assert response.status_code == 201
        body = self._payload(response)
        assert 'session_id' in body
        assert body['user_id'] == 'test-user-123'
        assert body['language'] == 'python'
        assert body['message_count'] == 0
        assert body['metadata']['source'] == 'test'

    def test_create_session_invalid_language(self, client, auth_headers):
        """An unknown language is rejected with 400."""
        request_body = json.dumps({'language': 'invalid-language'})
        response = client.post(self._SESSIONS_URL, data=request_body, headers=auth_headers)

        assert response.status_code == 400
        assert 'Unsupported language' in self._payload(response)['error']

    def test_create_session_missing_auth(self, client):
        """Omitting the user header is rejected with 401."""
        request_body = json.dumps({'language': 'python'})
        response = client.post(
            self._SESSIONS_URL,
            data=request_body,
            headers={'Content-Type': 'application/json'},
        )

        assert response.status_code == 401
        assert 'Authentication required' in self._payload(response)['error']

    def test_create_session_missing_language(self, client, auth_headers):
        """A body without ``language`` is rejected with 400."""
        request_body = json.dumps({'metadata': {'test': True}})
        response = client.post(self._SESSIONS_URL, data=request_body, headers=auth_headers)

        assert response.status_code == 400
        assert 'Missing required fields' in self._payload(response)['error']

    def test_get_session_success(self, client, auth_headers, sample_session):
        """The owner can fetch their session."""
        response = client.get(f'{self._SESSIONS_URL}/{sample_session.id}', headers=auth_headers)

        assert response.status_code == 200
        body = self._payload(response)
        assert body['session_id'] == sample_session.id
        assert body['user_id'] == 'test-user-123'
        assert body['language'] == 'python'
        assert body['is_active'] is True

    def test_get_session_not_found(self, client, auth_headers):
        """An unknown session id yields 404."""
        response = client.get(f'{self._SESSIONS_URL}/non-existent-id', headers=auth_headers)

        assert response.status_code == 404
        assert 'Session not found' in self._payload(response)['error']

    def test_get_session_wrong_user(self, client, sample_session):
        """A non-owner is refused with 403."""
        response = client.get(
            f'{self._SESSIONS_URL}/{sample_session.id}',
            headers=self._intruder_headers(),
        )

        assert response.status_code == 403
        assert 'Access denied' in self._payload(response)['error']

    def test_list_user_sessions(self, client, auth_headers, sample_session):
        """Listing returns the caller's active sessions, including the fixture one."""
        response = client.get(self._SESSIONS_URL, headers=auth_headers)

        assert response.status_code == 200
        body = self._payload(response)
        assert 'sessions' in body
        assert body['total_count'] >= 1
        assert body['active_only'] is True
        listed_ids = {entry['session_id'] for entry in body['sessions']}
        assert sample_session.id in listed_ids

    def test_delete_session_success(self, client, auth_headers, sample_session):
        """The owner can delete their session."""
        response = client.delete(f'{self._SESSIONS_URL}/{sample_session.id}', headers=auth_headers)

        assert response.status_code == 200
        body = self._payload(response)
        assert 'Session deleted successfully' in body['message']
        assert body['session_id'] == sample_session.id

    def test_delete_session_wrong_user(self, client, sample_session):
        """A non-owner cannot delete the session (403)."""
        response = client.delete(
            f'{self._SESSIONS_URL}/{sample_session.id}',
            headers=self._intruder_headers(),
        )

        assert response.status_code == 403
        assert 'Access denied' in self._payload(response)['error']
class TestChatHistory:
    """Chat history listing, pagination and search endpoints."""

    @staticmethod
    def _history_url(chat_session, tail=''):
        """Build ``/api/v1/chat/sessions/<id>/history`` with an optional suffix."""
        return f'/api/v1/chat/sessions/{chat_session.id}/history{tail}'

    @staticmethod
    def _payload(response):
        """Decode the JSON body of a test-client response."""
        return json.loads(response.data)

    def test_get_chat_history_success(self, client, auth_headers, sample_session, sample_messages):
        """Full history returns both fixture messages, user first."""
        response = client.get(self._history_url(sample_session), headers=auth_headers)

        assert response.status_code == 200
        body = self._payload(response)
        assert 'messages' in body
        assert body['session_id'] == sample_session.id
        assert body['total_count'] == 2
        history = body['messages']
        assert len(history) == 2
        assert [entry['role'] for entry in history] == ['user', 'assistant']

    def test_get_chat_history_recent_only(self, client, auth_headers, sample_session, sample_messages):
        """``recent_only`` honours ``limit`` and omits pagination fields."""
        response = client.get(
            self._history_url(sample_session, '?recent_only=true&limit=1'),
            headers=auth_headers,
        )

        assert response.status_code == 200
        body = self._payload(response)
        assert len(body['messages']) == 1
        assert 'page' not in body  # recent-only mode is not paginated

    def test_get_chat_history_pagination(self, client, auth_headers, sample_session, sample_messages):
        """A page size of one splits the two messages across two pages."""
        response = client.get(
            self._history_url(sample_session, '?page=1&page_size=1'),
            headers=auth_headers,
        )

        assert response.status_code == 200
        body = self._payload(response)
        assert body['page'] == 1
        assert body['page_size'] == 1
        assert body['total_pages'] == 2
        assert len(body['messages']) == 1

    def test_get_chat_history_wrong_user(self, client, sample_session, sample_messages):
        """A non-owner cannot read the history (403)."""
        intruder = {
            'X-User-ID': 'different-user',
            'Content-Type': 'application/json'
        }
        response = client.get(self._history_url(sample_session), headers=intruder)

        assert response.status_code == 403
        assert 'Access denied' in self._payload(response)['error']

    def test_search_chat_history_success(self, client, auth_headers, sample_session, sample_messages):
        """Searching for a word present in the fixture messages finds matches."""
        response = client.get(
            self._history_url(sample_session, '/search?q=Python'),
            headers=auth_headers,
        )

        assert response.status_code == 200
        body = self._payload(response)
        assert 'messages' in body
        assert body['query'] == 'Python'
        assert body['result_count'] >= 1

    def test_search_chat_history_empty_query(self, client, auth_headers, sample_session):
        """An empty search query is rejected with 400."""
        response = client.get(
            self._history_url(sample_session, '/search?q='),
            headers=auth_headers,
        )

        assert response.status_code == 400
        assert 'Search query is required' in self._payload(response)['error']

    def test_search_chat_history_short_query(self, client, auth_headers, sample_session):
        """A query shorter than three characters is rejected with 400."""
        response = client.get(
            self._history_url(sample_session, '/search?q=ab'),
            headers=auth_headers,
        )

        assert response.status_code == 400
        assert 'at least 3 characters' in self._payload(response)['error']
class TestLanguageContext:
    """Per-session language context endpoints and the supported-language list."""

    @staticmethod
    def _language_url(chat_session):
        """Build ``/api/v1/chat/sessions/<id>/language``."""
        return f'/api/v1/chat/sessions/{chat_session.id}/language'

    @staticmethod
    def _payload(response):
        """Decode the JSON body of a test-client response."""
        return json.loads(response.data)

    def test_get_language_context_success(self, client, auth_headers, sample_session):
        """The owner can read the session's language context."""
        response = client.get(self._language_url(sample_session), headers=auth_headers)

        assert response.status_code == 200
        body = self._payload(response)
        assert body['session_id'] == sample_session.id
        assert body['language'] == 'python'
        for expected_key in ('prompt_template', 'syntax_highlighting', 'language_info'):
            assert expected_key in body

    def test_update_language_context_success(self, client, auth_headers, sample_session):
        """Switching to JavaScript is reflected in the response."""
        response = client.put(
            self._language_url(sample_session),
            data=json.dumps({'language': 'javascript'}),
            headers=auth_headers,
        )

        assert response.status_code == 200
        body = self._payload(response)
        assert body['language'] == 'javascript'
        assert 'JavaScript' in body['language_info']['name']

    def test_update_language_context_invalid_language(self, client, auth_headers, sample_session):
        """An unknown language is rejected with 400."""
        response = client.put(
            self._language_url(sample_session),
            data=json.dumps({'language': 'invalid-language'}),
            headers=auth_headers,
        )

        assert response.status_code == 400
        assert 'Unsupported language' in self._payload(response)['error']

    def test_get_supported_languages(self, client):
        """The public language list defaults to Python and contains it."""
        response = client.get('/api/v1/chat/languages')

        assert response.status_code == 200
        body = self._payload(response)
        assert 'languages' in body
        assert body['default_language'] == 'python'
        assert body['total_count'] > 0
        codes = [entry['code'] for entry in body['languages']]
        assert 'python' in codes
class TestHealthCheck:
    """Health check endpoint (``GET /api/v1/chat/health``).

    ``redis.from_url`` is patched so each test controls the Redis ping result.
    NOTE(review): the patch only takes effect if the application looks up
    ``redis.from_url`` at call time (e.g. ``import redis; redis.from_url(...)``);
    a module-level ``from redis import from_url`` would bypass it — confirm.
    """

    @patch('redis.from_url')
    def test_health_check_success(self, mock_redis, client):
        """A successful Redis ping yields 200 with database and Redis connected."""
        # The ``app`` fixture is already pulled in through ``client``; it was
        # previously requested here as an unused argument.
        mock_redis_client = MagicMock()
        mock_redis_client.ping.return_value = True
        mock_redis.return_value = mock_redis_client

        response = client.get('/api/v1/chat/health')

        assert response.status_code == 200
        response_data = json.loads(response.data)
        assert response_data['status'] == 'healthy'
        assert 'timestamp' in response_data
        assert response_data['services']['database'] == 'connected'
        assert response_data['services']['redis'] == 'connected'

    @patch('redis.from_url')
    def test_health_check_redis_failure(self, mock_redis, client):
        """A failing Redis ping yields 503 with an ``error`` field."""
        mock_redis_client = MagicMock()
        mock_redis_client.ping.side_effect = Exception("Redis connection failed")
        mock_redis.return_value = mock_redis_client

        response = client.get('/api/v1/chat/health')

        assert response.status_code == 503
        response_data = json.loads(response.data)
        assert response_data['status'] == 'unhealthy'
        assert 'error' in response_data
class TestRateLimiting:
    """Rate limiting on the session-creation endpoint."""

    # Burst size chosen to exceed the endpoint's limit (10 per minute,
    # according to the original test's comment; confirm against the app).
    BURST_SIZE = 15

    def test_rate_limiting_session_creation(self, client, auth_headers):
        """Bursting past the limit produces at least one 429 with the expected error."""
        # The request body is identical for every call, so encode it once.
        payload = json.dumps({'language': 'python'})
        responses = [
            client.post('/api/v1/chat/sessions', data=payload, headers=auth_headers)
            for _ in range(self.BURST_SIZE)
        ]

        rate_limited = [r for r in responses if r.status_code == 429]
        assert rate_limited, 'expected at least one request to be rate limited (HTTP 429)'

        # Guaranteed non-empty by the assert above, so no extra guard is needed.
        response_data = json.loads(rate_limited[0].data)
        assert 'Rate limit exceeded' in response_data['error']
class TestErrorHandling:
    """Malformed requests and unknown routes produce JSON error responses."""

    @staticmethod
    def _error_message(response):
        """Return the ``error`` field from a JSON response body."""
        return json.loads(response.data)['error']

    def test_invalid_json_request(self, client, auth_headers):
        """A non-JSON body sent as ``application/json`` is rejected with 400."""
        response = client.post('/api/v1/chat/sessions', data='invalid json', headers=auth_headers)

        assert response.status_code == 400
        assert 'Request must be JSON' in self._error_message(response)

    def test_empty_request_body(self, client, auth_headers):
        """An empty JSON object is rejected for missing required fields."""
        response = client.post('/api/v1/chat/sessions', data='{}', headers=auth_headers)

        assert response.status_code == 400
        assert 'Missing required fields' in self._error_message(response)

    def test_non_existent_endpoint(self, client, auth_headers):
        """An unknown route under the chat API yields a JSON 404."""
        response = client.get('/api/v1/chat/non-existent', headers=auth_headers)

        assert response.status_code == 404
        assert 'Not found' in self._error_message(response)
if __name__ == '__main__':
    # Exit with pytest's status code so a direct run reports failures instead
    # of always exiting 0.
    raise SystemExit(pytest.main([__file__]))