Spaces:
Sleeping
Sleeping
| """Comprehensive integration tests for conversation and chat management.""" | |
| import pytest | |
| import asyncio | |
| from datetime import datetime | |
| from unittest.mock import AsyncMock, patch, MagicMock | |
| from sqlmodel import select | |
| from sqlmodel.ext.asyncio.session import AsyncSession | |
| from src.db.database import AsyncSessionLocal | |
| from src.db.models import User, Conversation, Message | |
| from src.services.conversation_service import ConversationService | |
| from src.models import ChatRequest, ChatResponse | |
| # ============================================================================== | |
| # FIXTURES | |
| # ============================================================================== | |
def event_loop():
    """Yield a fresh asyncio event loop and close it when the consumer is done.

    NOTE(review): this generator has no ``@pytest.fixture`` decorator, so pytest
    does not register it and nothing in this module uses it. pytest-asyncio
    deprecated overriding its ``event_loop`` fixture (newer releases drop it);
    configure loop scope through pytest-asyncio instead of re-registering this.

    Yields:
        asyncio.AbstractEventLoop: a newly created, open event loop.
    """
    # asyncio.new_event_loop() delegates to the active policy, so it matches
    # get_event_loop_policy().new_event_loop() without calling the policy API
    # that Python 3.14 deprecates.
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        # Close the loop even if the generator is abandoned before exhaustion.
        loop.close()
@pytest.fixture
async def test_user_1(session: AsyncSession):
    """Create and persist the primary user that owns the conversations under test.

    The fixture decorator also keeps pytest from collecting this
    ``test_``-prefixed function as a test of its own.

    NOTE(review): a plain ``@pytest.fixture`` on an ``async def`` assumes
    pytest-asyncio runs in ``asyncio_mode = auto`` — confirm the project's
    pytest config; strict mode needs ``pytest_asyncio.fixture`` instead.

    Args:
        session: database session the user is written through.

    Returns:
        User: the committed user, refreshed so DB-populated fields (e.g. ``id``)
        are loaded.
    """
    import uuid

    # Nothing in this module deletes committed rows, so a fixed address would
    # be reused by every test and every run; a unique suffix avoids collisions.
    user = User(
        email=f"user1-{uuid.uuid4().hex}@example.com",
        name="Test User 1",
        password_hash="hashed_password_1",
        is_active=True,
    )
    session.add(user)
    await session.commit()
    await session.refresh(user)
    return user
@pytest.fixture
async def test_user_2(session: AsyncSession):
    """Create and persist a second, unrelated user for cross-user isolation tests.

    The fixture decorator also keeps pytest from collecting this
    ``test_``-prefixed function as a test of its own.

    NOTE(review): a plain ``@pytest.fixture`` on an ``async def`` assumes
    pytest-asyncio runs in ``asyncio_mode = auto`` — confirm the project's
    pytest config; strict mode needs ``pytest_asyncio.fixture`` instead.

    Args:
        session: database session the user is written through.

    Returns:
        User: the committed user, refreshed so DB-populated fields (e.g. ``id``)
        are loaded.
    """
    import uuid

    # Nothing in this module deletes committed rows, so a fixed address would
    # be reused by every test and every run; a unique suffix avoids collisions.
    user = User(
        email=f"user2-{uuid.uuid4().hex}@example.com",
        name="Test User 2",
        password_hash="hashed_password_2",
        is_active=True,
    )
    session.add(user)
    await session.commit()
    await session.refresh(user)
    return user
@pytest.fixture
async def session():
    """Yield an ``AsyncSession`` from ``AsyncSessionLocal`` for a single test.

    The ``async with`` block closes the session once the test finishes; rows
    the test committed are not removed.

    NOTE(review): a plain ``@pytest.fixture`` on an async generator assumes
    pytest-asyncio runs in ``asyncio_mode = auto`` — confirm the project's
    pytest config; strict mode needs ``pytest_asyncio.fixture`` instead.
    """
    async with AsyncSessionLocal() as db_session:
        yield db_session
| # ============================================================================== | |
| # CONVERSATION CREATION TESTS | |
| # ============================================================================== | |
async def test_create_conversation_success(test_user_1: User, session: AsyncSession):
    """Creating a titled conversation persists it with owner, title and timestamps."""
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
        title="Test Conversation",
    )

    assert conv.id is not None
    assert conv.user_id == test_user_1.id
    assert conv.title == "Test Conversation"
    # Identity check instead of ``== False`` (ruff E712): a fresh conversation
    # must carry the boolean False, not merely a falsy value.
    assert conv.is_deleted is False
    assert conv.created_at is not None
    assert conv.last_message_at is not None
async def test_create_conversation_without_title(test_user_1: User, session: AsyncSession):
    """Creating a conversation without a title leaves ``title`` as ``None``."""
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )

    assert conv.id is not None
    assert conv.user_id == test_user_1.id
    assert conv.title is None
    # Identity check instead of ``== False`` (ruff E712).
    assert conv.is_deleted is False
| # ============================================================================== | |
| # MESSAGE SAVING TESTS | |
| # ============================================================================== | |
async def test_save_message_user_message(test_user_1: User, session: AsyncSession):
    """A message from the conversation owner is stored against that conversation."""
    question = "Hello, this is my question"

    conversation = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    stored = await ConversationService.save_message(
        conversation_id=conversation.id,
        sender_id=test_user_1.id,
        content=question,
        session=session,
    )

    # The returned message is persisted and linked to both conversation and sender.
    assert stored.id is not None
    assert stored.conversation_id == conversation.id
    assert stored.sender_id == test_user_1.id
    assert stored.content == question
    assert stored.created_at is not None
async def test_save_message_updates_last_message_at(test_user_1: User, session: AsyncSession):
    """Saving a message moves the conversation's ``last_message_at`` forward."""
    conversation = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    timestamp_before = conversation.last_message_at

    # Pause so the next timestamp is strictly later.
    # NOTE(review): 0.1 s assumes the DB stores sub-second precision — confirm.
    await asyncio.sleep(0.1)
    await ConversationService.save_message(
        conversation_id=conversation.id,
        sender_id=test_user_1.id,
        content="Test message",
        session=session,
    )

    # Look the conversation up again by id and compare timestamps.
    lookup = select(Conversation).where(Conversation.id == conversation.id)
    reloaded = (await session.execute(lookup)).scalar_one()
    assert reloaded.last_message_at > timestamp_before
async def test_save_multiple_messages(test_user_1: User, session: AsyncSession):
    """Several messages saved to one conversation are all persisted, in save order."""
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )

    returned = []
    for i in range(3):
        msg = await ConversationService.save_message(
            conversation_id=conv.id,
            sender_id=test_user_1.id,
            content=f"Message {i}",
            session=session,
        )
        returned.append(msg)

    # Messages saved back-to-back may share a ``created_at`` value, so ``id``
    # breaks ties to keep the expected order deterministic.
    # NOTE(review): assumes ids are assigned in insertion order — confirm.
    stmt = (
        select(Message)
        .where(Message.conversation_id == conv.id)
        .order_by(Message.created_at, Message.id)
    )
    result = await session.execute(stmt)
    saved_messages = result.scalars().all()

    assert [m.content for m in saved_messages] == [f"Message {i}" for i in range(3)]
    # The rows read back are exactly the ones save_message returned.
    assert [m.id for m in saved_messages] == [m.id for m in returned]
| # ============================================================================== | |
| # CONVERSATION RETRIEVAL TESTS | |
| # ============================================================================== | |
async def test_get_user_conversations(test_user_1: User, session: AsyncSession):
    """Listing a user's conversations includes every conversation they created."""
    # Async comprehension: the three creations still run one after another.
    created_ids = [
        (
            await ConversationService.create_conversation(
                user_id=test_user_1.id,
                session=session,
                title=f"Conversation {n}",
            )
        ).id
        for n in range(3)
    ]

    listed, total = await ConversationService.get_user_conversations(
        user_id=test_user_1.id,
        session=session,
    )

    assert len(listed) >= 3
    assert total >= 3
    listed_ids = {c.id for c in listed}
    assert set(created_ids).issubset(listed_ids)
async def test_get_user_conversations_pagination(test_user_1: User, session: AsyncSession):
    """``skip``/``limit`` split the listing into full, non-overlapping pages."""
    for _ in range(5):
        await ConversationService.create_conversation(
            user_id=test_user_1.id,
            session=session,
        )

    page1, total1 = await ConversationService.get_user_conversations(
        user_id=test_user_1.id,
        session=session,
        skip=0,
        limit=2,
    )
    page2, total2 = await ConversationService.get_user_conversations(
        user_id=test_user_1.id,
        session=session,
        skip=2,
        limit=2,
    )

    # At least five conversations exist, so both pages must be full; a bare
    # ``<= 2`` check would also accept empty pages.
    assert len(page1) == 2
    assert len(page2) == 2
    # Consecutive pages must not repeat a conversation.
    assert {c.id for c in page1}.isdisjoint(c.id for c in page2)
    # ``total`` counts every conversation regardless of the page requested.
    assert total1 >= 5
    assert total1 == total2
async def test_get_user_conversations_ordered_by_recency(test_user_1: User, session: AsyncSession):
    """The conversation with the most recent activity is listed first."""
    older = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    await asyncio.sleep(0.1)
    # Created later, so this one is newest until ``older`` receives a message.
    await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )

    # A fresh message on ``older`` should move it back to the top.
    await asyncio.sleep(0.1)
    await ConversationService.save_message(
        conversation_id=older.id,
        sender_id=test_user_1.id,
        content="Updated message",
        session=session,
    )

    listing, _total = await ConversationService.get_user_conversations(
        user_id=test_user_1.id,
        session=session,
    )
    assert listing[0].id == older.id
async def test_get_conversation_with_messages_success(test_user_1: User, session: AsyncSession):
    """Fetching a conversation returns it together with all of its messages."""
    bodies = [f"Message {idx}" for idx in range(3)]

    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    for body in bodies:
        await ConversationService.save_message(
            conversation_id=conv.id,
            sender_id=test_user_1.id,
            content=body,
            session=session,
        )

    fetched = await ConversationService.get_conversation_with_messages(
        conversation_id=conv.id,
        user_id=test_user_1.id,
        session=session,
    )

    assert fetched is not None
    fetched_conv, fetched_messages = fetched
    assert fetched_conv.id == conv.id
    assert len(fetched_messages) == 3
    assert fetched_messages[0].content == bodies[0]
async def test_get_conversation_not_found(test_user_1: User, session: AsyncSession):
    """Looking up a conversation id that does not exist yields ``None``."""
    # NOTE(review): 99999 is assumed never to be allocated; rows are not
    # cleaned up between runs, so confirm this stays true.
    missing_id = 99999

    outcome = await ConversationService.get_conversation_with_messages(
        conversation_id=missing_id,
        user_id=test_user_1.id,
        session=session,
    )
    assert outcome is None
async def test_get_conversation_authorization_check(
    test_user_1: User,
    test_user_2: User,
    session: AsyncSession,
):
    """Reading another user's conversation raises ``PermissionError``."""
    owner, intruder = test_user_1, test_user_2

    private_conv = await ConversationService.create_conversation(
        user_id=owner.id,
        session=session,
    )

    with pytest.raises(PermissionError):
        await ConversationService.get_conversation_with_messages(
            conversation_id=private_conv.id,
            user_id=intruder.id,
            session=session,
        )
| # ============================================================================== | |
| # CONVERSATION SEARCH TESTS | |
| # ============================================================================== | |
async def test_search_conversations_by_keyword(test_user_1: User, session: AsyncSession):
    """A phrase contained in a message finds that message's conversation."""
    target = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    await ConversationService.save_message(
        conversation_id=target.id,
        sender_id=test_user_1.id,
        content="Tell me about machine learning",
        session=session,
    )

    hits, _total = await ConversationService.search_conversations(
        user_id=test_user_1.id,
        keyword="machine learning",
        session=session,
    )

    hit_ids = [c.id for c in hits]
    assert len(hit_ids) >= 1
    assert target.id in hit_ids
async def test_search_conversations_case_insensitive(test_user_1: User, session: AsyncSession):
    """A lowercase keyword still matches message text stored in uppercase."""
    target = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    await ConversationService.save_message(
        conversation_id=target.id,
        sender_id=test_user_1.id,
        content="MACHINE LEARNING",
        session=session,
    )

    # Same words as the stored text, but in the opposite case.
    hits, _total = await ConversationService.search_conversations(
        user_id=test_user_1.id,
        keyword="machine learning",
        session=session,
    )

    hit_ids = [c.id for c in hits]
    assert len(hit_ids) >= 1
    assert target.id in hit_ids
async def test_search_conversations_pagination(test_user_1: User, session: AsyncSession):
    """``skip``/``limit`` split search results into full, non-overlapping pages."""
    # Five conversations, each with one message matching the keyword.
    for _ in range(5):
        conv = await ConversationService.create_conversation(
            user_id=test_user_1.id,
            session=session,
        )
        await ConversationService.save_message(
            conversation_id=conv.id,
            sender_id=test_user_1.id,
            content="python programming",
            session=session,
        )

    page1, total1 = await ConversationService.search_conversations(
        user_id=test_user_1.id,
        keyword="python",
        session=session,
        skip=0,
        limit=2,
    )
    page2, total2 = await ConversationService.search_conversations(
        user_id=test_user_1.id,
        keyword="python",
        session=session,
        skip=2,
        limit=2,
    )

    # At least five matches exist, so both pages must be full; a bare
    # ``<= 2`` check would also accept empty pages.
    assert len(page1) == 2
    assert len(page2) == 2
    # Consecutive pages must not repeat a conversation.
    assert {c.id for c in page1}.isdisjoint(c.id for c in page2)
    # ``total`` counts every match regardless of the page requested.
    assert total1 >= 5
    assert total1 == total2
async def test_search_user_isolation(
    test_user_1: User,
    test_user_2: User,
    session: AsyncSession,
):
    """Search only returns conversations owned by the searching user."""
    conv1 = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    await ConversationService.save_message(
        conversation_id=conv1.id,
        sender_id=test_user_1.id,
        content="unique keyword xyz123",
        session=session,
    )

    # Positive control: the owner can find the conversation. Without it, an
    # empty result for user 2 would pass even if search matched nothing at all.
    own_hits, _ = await ConversationService.search_conversations(
        user_id=test_user_1.id,
        keyword="xyz123",
        session=session,
    )
    assert conv1.id in [c.id for c in own_hits]

    other_hits, _ = await ConversationService.search_conversations(
        user_id=test_user_2.id,
        keyword="xyz123",
        session=session,
    )

    # User 1's conversation must not leak into user 2's results...
    assert conv1.id not in [c.id for c in other_hits]
    # ...and anything user 2 does get back must be their own.
    for conv in other_hits:
        assert conv.user_id == test_user_2.id
| # ============================================================================== | |
| # CONVERSATION DELETION TESTS | |
| # ============================================================================== | |
async def test_delete_conversation_success(test_user_1: User, session: AsyncSession):
    """Deleting a conversation soft-deletes it: the row remains, flagged deleted."""
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )

    deleted = await ConversationService.delete_conversation(
        conversation_id=conv.id,
        user_id=test_user_1.id,
        session=session,
    )
    # Identity check instead of ``== True`` (ruff E712).
    assert deleted is True

    # ``scalar_one`` also proves the row still exists after a soft delete.
    stmt = select(Conversation).where(Conversation.id == conv.id)
    conv_check = (await session.execute(stmt)).scalar_one()
    assert conv_check.is_deleted is True
async def test_delete_conversation_not_found(test_user_1: User, session: AsyncSession):
    """Deleting a conversation id that does not exist raises ``ValueError``."""
    # NOTE(review): 99999 is assumed never to be allocated — confirm.
    nonexistent_id = 99999

    with pytest.raises(ValueError):
        await ConversationService.delete_conversation(
            conversation_id=nonexistent_id,
            user_id=test_user_1.id,
            session=session,
        )
async def test_delete_conversation_authorization_check(
    test_user_1: User,
    test_user_2: User,
    session: AsyncSession,
):
    """Deleting another user's conversation raises ``PermissionError``."""
    owner, intruder = test_user_1, test_user_2

    victim_conv = await ConversationService.create_conversation(
        user_id=owner.id,
        session=session,
    )

    with pytest.raises(PermissionError):
        await ConversationService.delete_conversation(
            conversation_id=victim_conv.id,
            user_id=intruder.id,
            session=session,
        )
async def test_soft_delete_preserves_data(test_user_1: User, session: AsyncSession):
    """Soft-deleting a conversation leaves its messages in the database."""
    payload = "Important data"

    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    await ConversationService.save_message(
        conversation_id=conv.id,
        sender_id=test_user_1.id,
        content=payload,
        session=session,
    )
    await ConversationService.delete_conversation(
        conversation_id=conv.id,
        user_id=test_user_1.id,
        session=session,
    )

    # Query the Message table directly rather than through the service.
    by_conversation = select(Message).where(Message.conversation_id == conv.id)
    remaining = (await session.execute(by_conversation)).scalars().all()

    assert len(remaining) == 1
    assert remaining[0].content == payload
| # ============================================================================== | |
| # MESSAGE COUNT TESTS | |
| # ============================================================================== | |
async def test_get_message_count_zero(test_user_1: User, session: AsyncSession):
    """A conversation with no messages reports a count of zero."""
    empty_conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )

    n_messages = await ConversationService.get_message_count(
        conversation_id=empty_conv.id,
        session=session,
    )
    assert n_messages == 0
async def test_get_message_count_accurate(test_user_1: User, session: AsyncSession):
    """The message count equals the number of messages saved."""
    expected = 5

    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    for seq in range(expected):
        await ConversationService.save_message(
            conversation_id=conv.id,
            sender_id=test_user_1.id,
            content=f"Message {seq}",
            session=session,
        )

    n_messages = await ConversationService.get_message_count(
        conversation_id=conv.id,
        session=session,
    )
    assert n_messages == expected
| # ============================================================================== | |
| # AUTO-TITLE GENERATION TESTS | |
| # ============================================================================== | |
async def test_auto_generate_title(test_user_1: User, session: AsyncSession):
    """A generated title reflects the first message and is saved on the row."""
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    await ConversationService.save_message(
        conversation_id=conv.id,
        sender_id=test_user_1.id,
        content="How do I learn Python programming effectively?",
        session=session,
    )

    generated = await ConversationService.auto_generate_title(
        conversation_id=conv.id,
        session=session,
    )
    assert generated is not None
    assert "Python" in generated

    # The title must also have been written back to the conversation.
    lookup = select(Conversation).where(Conversation.id == conv.id)
    stored_conv = (await session.execute(lookup)).scalar_one()
    assert stored_conv.title is not None
async def test_auto_generate_title_empty_conversation(test_user_1: User, session: AsyncSession):
    """A conversation with no messages falls back to the default title."""
    empty_conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )

    generated = await ConversationService.auto_generate_title(
        conversation_id=empty_conv.id,
        session=session,
    )
    # With nothing to derive from, the service returns its default title.
    assert generated == "Conversation"
| # ============================================================================== | |
| # DELETED CONVERSATIONS EXCLUDED TESTS | |
| # ============================================================================== | |
async def test_list_conversations_excludes_deleted(test_user_1: User, session: AsyncSession):
    """Soft-deleted conversations do not appear in the user's listing."""
    kept = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    removed = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )

    await ConversationService.delete_conversation(
        conversation_id=removed.id,
        user_id=test_user_1.id,
        session=session,
    )

    visible, _total = await ConversationService.get_user_conversations(
        user_id=test_user_1.id,
        session=session,
    )

    visible_ids = [c.id for c in visible]
    assert kept.id in visible_ids
    assert removed.id not in visible_ids
async def test_search_excludes_deleted(test_user_1: User, session: AsyncSession):
    """Soft-deleted conversations are left out of keyword search results."""

    async def _conversation_mentioning_python():
        # One conversation holding a single message that matches the search.
        conv = await ConversationService.create_conversation(
            user_id=test_user_1.id,
            session=session,
        )
        await ConversationService.save_message(
            conversation_id=conv.id,
            sender_id=test_user_1.id,
            content="python programming",
            session=session,
        )
        return conv

    kept = await _conversation_mentioning_python()
    removed = await _conversation_mentioning_python()

    await ConversationService.delete_conversation(
        conversation_id=removed.id,
        user_id=test_user_1.id,
        session=session,
    )

    hits, _total = await ConversationService.search_conversations(
        user_id=test_user_1.id,
        keyword="python",
        session=session,
    )

    hit_ids = [c.id for c in hits]
    assert kept.id in hit_ids
    assert removed.id not in hit_ids
| # ============================================================================== | |
| # DATABASE ERROR HANDLING TESTS | |
| # ============================================================================== | |
async def test_save_message_database_error_rollback(test_user_1: User, session: AsyncSession):
    """A failing database write surfaces as an exception from ``save_message``.

    The failure is injected by patching the session's write methods rather
    than by closing the session. SQLAlchemy sessions stay usable after
    ``close()``. Closing also detaches ``conv`` and ``test_user_1``, so reading
    their ids inside ``pytest.raises`` could raise on its own and mask whether
    ``save_message`` failed at all.

    NOTE(review): rollback behaviour itself is not asserted here.
    """
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    # Resolve ids up front so only save_message can raise inside pytest.raises.
    conversation_id = conv.id
    sender_id = test_user_1.id

    # NOTE(review): assumes save_message writes via ``session.flush()`` and/or
    # ``session.commit()`` — confirm against the service implementation.
    failing_write = AsyncMock(side_effect=RuntimeError("simulated database failure"))
    with patch.object(session, "flush", failing_write), patch.object(
        session, "commit", failing_write
    ):
        # Broad on purpose: the service may wrap the underlying error.
        with pytest.raises(Exception):
            await ConversationService.save_message(
                conversation_id=conversation_id,
                sender_id=sender_id,
                content="This should fail",
                session=session,
            )
| # ============================================================================== | |
| # PYTEST CONFIGURATION | |
| # ============================================================================== | |
if __name__ == "__main__":
    # Propagate pytest's exit status so a direct run (e.g. in CI) fails when
    # tests fail, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))