# chatbot/tests/test_conversations.py
# Author: jawadsaghir12 — commit 8c77cd6 ("new update")
"""Comprehensive integration tests for conversation and chat management."""
import asyncio
import uuid
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock

import pytest
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from src.db.database import AsyncSessionLocal
from src.db.models import User, Conversation, Message
from src.services.conversation_service import ConversationService
from src.models import ChatRequest, ChatResponse
# ==============================================================================
# FIXTURES
# ==============================================================================
@pytest.fixture(scope="session")
def event_loop():
    """Provide a single event loop shared by every async test in the session.

    The loop is created once, handed to pytest-asyncio, and closed when the
    test session tears down.
    """
    shared_loop = asyncio.new_event_loop()
    yield shared_loop
    shared_loop.close()
@pytest.fixture
async def test_user_1(session: AsyncSession):
    """Create and persist the primary test user for conversation tests.

    Every test that requests this fixture inserts a new row, and the
    ``session`` fixture never deletes or rolls back committed data, so the
    email carries a random suffix to keep each inserted address distinct.

    Returns:
        The committed and refreshed ``User`` instance.
    """
    # NOTE(review): assumes ``User.email`` may be unique — a fixed address
    # would then fail on the second insert. Confirm against the model.
    user = User(
        email=f"user1-{uuid.uuid4().hex}@example.com",
        name="Test User 1",
        password_hash="hashed_password_1",
        is_active=True,
    )
    session.add(user)
    await session.commit()
    # Reload server-generated fields (e.g. the primary key) after commit.
    await session.refresh(user)
    return user
@pytest.fixture
async def test_user_2(session: AsyncSession):
    """Create and persist a second user, used to verify per-user isolation.

    Like the primary test user, a new row is inserted for every requesting
    test and nothing removes it afterwards, so the email carries a random
    suffix to keep each inserted address distinct.

    Returns:
        The committed and refreshed ``User`` instance.
    """
    # NOTE(review): assumes ``User.email`` may be unique — a fixed address
    # would then fail on the second insert. Confirm against the model.
    user = User(
        email=f"user2-{uuid.uuid4().hex}@example.com",
        name="Test User 2",
        password_hash="hashed_password_2",
        is_active=True,
    )
    session.add(user)
    await session.commit()
    # Reload server-generated fields (e.g. the primary key) after commit.
    await session.refresh(user)
    return user
@pytest.fixture
async def session():
    """Yield a session from ``AsyncSessionLocal`` for the duration of one test.

    The ``async with`` block closes the session on teardown. Data committed
    during the test is not rolled back here.
    """
    async with AsyncSessionLocal() as db_session:
        yield db_session
# ==============================================================================
# CONVERSATION CREATION TESTS
# ==============================================================================
@pytest.mark.asyncio
async def test_create_conversation_success(test_user_1: User, session: AsyncSession):
    """Creating a titled conversation persists it with the expected defaults.

    The returned conversation must have a primary key, belong to the
    requesting user, keep the given title, not be soft-deleted, and have both
    timestamps populated.
    """
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
        title="Test Conversation",
    )
    assert conv.id is not None
    assert conv.user_id == test_user_1.id
    assert conv.title == "Test Conversation"
    # Identity comparison with the bool singleton instead of ``== False`` (E712).
    assert conv.is_deleted is False
    assert conv.created_at is not None
    assert conv.last_message_at is not None
@pytest.mark.asyncio
async def test_create_conversation_without_title(test_user_1: User, session: AsyncSession):
    """Creating a conversation without a title leaves ``title`` as ``None``."""
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    assert conv.id is not None
    assert conv.user_id == test_user_1.id
    assert conv.title is None
    # Identity comparison with the bool singleton instead of ``== False`` (E712).
    assert conv.is_deleted is False
# ==============================================================================
# MESSAGE SAVING TESTS
# ==============================================================================
@pytest.mark.asyncio
async def test_save_message_user_message(test_user_1: User, session: AsyncSession):
    """A message saved by the conversation owner round-trips all its fields."""
    author_id = test_user_1.id
    text = "Hello, this is my question"

    conversation = await ConversationService.create_conversation(
        user_id=author_id,
        session=session,
    )
    saved = await ConversationService.save_message(
        conversation_id=conversation.id,
        sender_id=author_id,
        content=text,
        session=session,
    )

    assert saved.id is not None
    assert saved.conversation_id == conversation.id
    assert saved.sender_id == author_id
    assert saved.content == text
    assert saved.created_at is not None
@pytest.mark.asyncio
async def test_save_message_updates_last_message_at(test_user_1: User, session: AsyncSession):
    """Saving a message moves the conversation's ``last_message_at`` forward."""
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    original_time = conv.last_message_at

    # Small delay so the new timestamp is strictly later than the original.
    await asyncio.sleep(0.1)
    await ConversationService.save_message(
        conversation_id=conv.id,
        sender_id=test_user_1.id,
        content="Test message",
        session=session,
    )

    # ``conv`` is still held in this session's identity map. populate_existing
    # makes the SELECT overwrite its attributes with the database row, so the
    # assertion checks what was persisted, not a cached in-memory value.
    stmt = (
        select(Conversation)
        .where(Conversation.id == conv.id)
        .execution_options(populate_existing=True)
    )
    result = await session.execute(stmt)
    updated_conv = result.scalar_one()
    assert updated_conv.last_message_at > original_time
@pytest.mark.asyncio
async def test_save_multiple_messages(test_user_1: User, session: AsyncSession):
    """Several messages saved to one conversation are all stored, in order."""
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )

    expected_contents = [f"Message {i}" for i in range(3)]
    for content in expected_contents:
        await ConversationService.save_message(
            conversation_id=conv.id,
            sender_id=test_user_1.id,
            content=content,
            session=session,
        )

    # Messages saved back-to-back can share a ``created_at`` value, so order by
    # primary key as a tie-breaker (assumes ids increase with insertion —
    # confirm against the model) to keep the comparison deterministic.
    stmt = (
        select(Message)
        .where(Message.conversation_id == conv.id)
        .order_by(Message.created_at, Message.id)
    )
    result = await session.execute(stmt)
    saved_messages = result.scalars().all()

    # Checks both the count and the per-position content.
    assert [m.content for m in saved_messages] == expected_contents
# ==============================================================================
# CONVERSATION RETRIEVAL TESTS
# ==============================================================================
@pytest.mark.asyncio
async def test_get_user_conversations(test_user_1: User, session: AsyncSession):
    """Every conversation a user creates appears in that user's listing."""
    created_ids = []
    for n in range(3):
        created = await ConversationService.create_conversation(
            user_id=test_user_1.id,
            session=session,
            title=f"Conversation {n}",
        )
        created_ids.append(created.id)

    listed, total = await ConversationService.get_user_conversations(
        user_id=test_user_1.id,
        session=session,
    )

    assert len(listed) >= 3
    assert total >= 3
    listed_ids = {item.id for item in listed}
    for created_id in created_ids:
        assert created_id in listed_ids
@pytest.mark.asyncio
async def test_get_user_conversations_pagination(test_user_1: User, session: AsyncSession):
    """``skip``/``limit`` return full, non-overlapping pages of conversations.

    At least five conversations exist for the user, so with ``limit=2`` both
    the first page (``skip=0``) and the second page (``skip=2``) must be
    completely filled.
    """
    for _ in range(5):
        await ConversationService.create_conversation(
            user_id=test_user_1.id,
            session=session,
        )

    page1, total1 = await ConversationService.get_user_conversations(
        user_id=test_user_1.id,
        session=session,
        skip=0,
        limit=2,
    )
    page2, total2 = await ConversationService.get_user_conversations(
        user_id=test_user_1.id,
        session=session,
        skip=2,
        limit=2,
    )

    assert total1 >= 5
    # The total count must not depend on which page was requested.
    assert total1 == total2
    # With at least five rows, pages at offsets 0 and 2 cannot be short.
    assert len(page1) == 2
    assert len(page2) == 2
    # Pages must not overlap; this relies on the service ordering results
    # deterministically across calls.
    assert not ({c.id for c in page1} & {c.id for c in page2})
@pytest.mark.asyncio
async def test_get_user_conversations_ordered_by_recency(test_user_1: User, session: AsyncSession):
    """Conversations are listed most-recently-active first.

    ``conv1`` is created first, then ``conv2``, and finally ``conv1`` receives
    a new message. The expected order is therefore ``[conv1, conv2]``.
    """
    conv1 = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    await asyncio.sleep(0.1)
    conv2 = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )

    # Posting to conv1 makes it the most recently active conversation.
    await asyncio.sleep(0.1)
    await ConversationService.save_message(
        conversation_id=conv1.id,
        sender_id=test_user_1.id,
        content="Updated message",
        session=session,
    )

    conversations, _ = await ConversationService.get_user_conversations(
        user_id=test_user_1.id,
        session=session,
    )

    # conv1 (newest message) leads, followed by conv2 (created after conv1).
    assert [c.id for c in conversations[:2]] == [conv1.id, conv2.id]
@pytest.mark.asyncio
async def test_get_conversation_with_messages_success(test_user_1: User, session: AsyncSession):
    """The owner can fetch a conversation together with its messages, in order."""
    owner_id = test_user_1.id
    conversation = await ConversationService.create_conversation(
        user_id=owner_id,
        session=session,
    )
    for index in range(3):
        await ConversationService.save_message(
            conversation_id=conversation.id,
            sender_id=owner_id,
            content=f"Message {index}",
            session=session,
        )

    fetched = await ConversationService.get_conversation_with_messages(
        conversation_id=conversation.id,
        user_id=owner_id,
        session=session,
    )

    assert fetched is not None
    fetched_conversation, fetched_messages = fetched
    assert fetched_conversation.id == conversation.id
    assert len(fetched_messages) == 3
    assert fetched_messages[0].content == "Message 0"
@pytest.mark.asyncio
async def test_get_conversation_not_found(test_user_1: User, session: AsyncSession):
    """Looking up a conversation id that does not exist yields ``None``."""
    missing_id = 99999
    lookup = await ConversationService.get_conversation_with_messages(
        conversation_id=missing_id,
        user_id=test_user_1.id,
        session=session,
    )
    assert lookup is None
@pytest.mark.asyncio
async def test_get_conversation_authorization_check(
    test_user_1: User,
    test_user_2: User,
    session: AsyncSession,
):
    """Fetching another user's conversation raises ``PermissionError``."""
    owned = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    intruder_id = test_user_2.id

    with pytest.raises(PermissionError):
        await ConversationService.get_conversation_with_messages(
            conversation_id=owned.id,
            user_id=intruder_id,
            session=session,
        )
# ==============================================================================
# CONVERSATION SEARCH TESTS
# ==============================================================================
@pytest.mark.asyncio
async def test_search_conversations_by_keyword(test_user_1: User, session: AsyncSession):
    """Searching by a keyword finds the conversation whose message contains it."""
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    await ConversationService.save_message(
        conversation_id=conv.id,
        sender_id=test_user_1.id,
        content="Tell me about machine learning",
        session=session,
    )

    conversations, total = await ConversationService.search_conversations(
        user_id=test_user_1.id,
        keyword="machine learning",
        session=session,
    )

    # The reported total must account for at least the match created above.
    assert total >= 1
    assert len(conversations) >= 1
    assert conv.id in [c.id for c in conversations]
@pytest.mark.asyncio
async def test_search_conversations_case_insensitive(test_user_1: User, session: AsyncSession):
    """A lower-case keyword matches upper-case message text."""
    owner_id = test_user_1.id
    conversation = await ConversationService.create_conversation(
        user_id=owner_id,
        session=session,
    )
    await ConversationService.save_message(
        conversation_id=conversation.id,
        sender_id=owner_id,
        content="MACHINE LEARNING",
        session=session,
    )

    hits, _ = await ConversationService.search_conversations(
        user_id=owner_id,
        keyword="machine learning",
        session=session,
    )

    assert len(hits) >= 1
    assert conversation.id in {hit.id for hit in hits}
@pytest.mark.asyncio
async def test_search_conversations_pagination(test_user_1: User, session: AsyncSession):
    """Search results honour ``skip``/``limit`` with full, non-overlapping pages.

    Five conversations each contain a matching message, so with ``limit=2``
    both the first page (``skip=0``) and the second page (``skip=2``) must be
    completely filled.
    """
    for _ in range(5):
        conv = await ConversationService.create_conversation(
            user_id=test_user_1.id,
            session=session,
        )
        await ConversationService.save_message(
            conversation_id=conv.id,
            sender_id=test_user_1.id,
            content="python programming",
            session=session,
        )

    page1, total1 = await ConversationService.search_conversations(
        user_id=test_user_1.id,
        keyword="python",
        session=session,
        skip=0,
        limit=2,
    )
    page2, total2 = await ConversationService.search_conversations(
        user_id=test_user_1.id,
        keyword="python",
        session=session,
        skip=2,
        limit=2,
    )

    assert total1 >= 5
    # The total count must not depend on which page was requested.
    assert total1 == total2
    # With at least five matches, pages at offsets 0 and 2 cannot be short.
    assert len(page1) == 2
    assert len(page2) == 2
    # Pages must not overlap; this relies on the service ordering results
    # deterministically across calls.
    assert not ({c.id for c in page1} & {c.id for c in page2})
@pytest.mark.asyncio
async def test_search_user_isolation(
    test_user_1: User,
    test_user_2: User,
    session: AsyncSession,
):
    """A user's search never returns conversations owned by someone else."""
    foreign = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    await ConversationService.save_message(
        conversation_id=foreign.id,
        sender_id=test_user_1.id,
        content="unique keyword xyz123",
        session=session,
    )

    results, _ = await ConversationService.search_conversations(
        user_id=test_user_2.id,
        keyword="xyz123",
        session=session,
    )

    # Every hit (if any) must belong to the searching user, never to User 1.
    assert all(found.user_id == test_user_2.id for found in results)
# ==============================================================================
# CONVERSATION DELETION TESTS
# ==============================================================================
@pytest.mark.asyncio
async def test_delete_conversation_success(test_user_1: User, session: AsyncSession):
    """Deleting a conversation soft-deletes it and reports success."""
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )

    deleted = await ConversationService.delete_conversation(
        conversation_id=conv.id,
        user_id=test_user_1.id,
        session=session,
    )
    # Identity comparison with the bool singleton instead of ``== True`` (E712).
    assert deleted is True

    # Re-read the row, overwriting the identity-mapped ``conv`` with database
    # values so the flag reflects what was persisted.
    stmt = (
        select(Conversation)
        .where(Conversation.id == conv.id)
        .execution_options(populate_existing=True)
    )
    rows = await session.execute(stmt)
    conv_check = rows.scalar_one()
    assert conv_check.is_deleted is True
@pytest.mark.asyncio
async def test_delete_conversation_not_found(test_user_1: User, session: AsyncSession):
    """Deleting a conversation id that does not exist raises ``ValueError``."""
    nonexistent_id = 99999
    with pytest.raises(ValueError):
        await ConversationService.delete_conversation(
            conversation_id=nonexistent_id,
            user_id=test_user_1.id,
            session=session,
        )
@pytest.mark.asyncio
async def test_delete_conversation_authorization_check(
    test_user_1: User,
    test_user_2: User,
    session: AsyncSession,
):
    """Deleting another user's conversation raises ``PermissionError``."""
    owned = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    intruder_id = test_user_2.id

    with pytest.raises(PermissionError):
        await ConversationService.delete_conversation(
            conversation_id=owned.id,
            user_id=intruder_id,
            session=session,
        )
@pytest.mark.asyncio
async def test_soft_delete_preserves_data(test_user_1: User, session: AsyncSession):
    """Soft-deleting a conversation leaves its message rows in the database."""
    owner_id = test_user_1.id
    conversation = await ConversationService.create_conversation(
        user_id=owner_id,
        session=session,
    )
    await ConversationService.save_message(
        conversation_id=conversation.id,
        sender_id=owner_id,
        content="Important data",
        session=session,
    )
    await ConversationService.delete_conversation(
        conversation_id=conversation.id,
        user_id=owner_id,
        session=session,
    )

    query = select(Message).where(Message.conversation_id == conversation.id)
    leftover = (await session.execute(query)).scalars().all()

    assert len(leftover) == 1
    assert leftover[0].content == "Important data"
# ==============================================================================
# MESSAGE COUNT TESTS
# ==============================================================================
@pytest.mark.asyncio
async def test_get_message_count_zero(test_user_1: User, session: AsyncSession):
    """A freshly created conversation reports a message count of zero."""
    empty_conversation = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    message_count = await ConversationService.get_message_count(
        conversation_id=empty_conversation.id,
        session=session,
    )
    assert message_count == 0
@pytest.mark.asyncio
async def test_get_message_count_accurate(test_user_1: User, session: AsyncSession):
    """The message count matches the number of messages saved."""
    message_total = 5
    conversation = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    for n in range(message_total):
        await ConversationService.save_message(
            conversation_id=conversation.id,
            sender_id=test_user_1.id,
            content=f"Message {n}",
            session=session,
        )

    message_count = await ConversationService.get_message_count(
        conversation_id=conversation.id,
        session=session,
    )
    assert message_count == message_total
# ==============================================================================
# AUTO-TITLE GENERATION TESTS
# ==============================================================================
@pytest.mark.asyncio
async def test_auto_generate_title(test_user_1: User, session: AsyncSession):
    """A title is generated from the first message and persisted on the conversation."""
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    await ConversationService.save_message(
        conversation_id=conv.id,
        sender_id=test_user_1.id,
        content="How do I learn Python programming effectively?",
        session=session,
    )

    title = await ConversationService.auto_generate_title(
        conversation_id=conv.id,
        session=session,
    )
    assert title is not None
    assert "Python" in title

    # Re-read the row, overwriting the identity-mapped ``conv`` with database
    # values, and check that the stored title is exactly the one returned.
    stmt = (
        select(Conversation)
        .where(Conversation.id == conv.id)
        .execution_options(populate_existing=True)
    )
    result = await session.execute(stmt)
    conv_check = result.scalar_one()
    assert conv_check.title == title
@pytest.mark.asyncio
async def test_auto_generate_title_empty_conversation(test_user_1: User, session: AsyncSession):
    """With no messages, title generation returns the default "Conversation"."""
    blank = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    generated = await ConversationService.auto_generate_title(
        conversation_id=blank.id,
        session=session,
    )
    assert generated == "Conversation"
# ==============================================================================
# DELETED CONVERSATIONS EXCLUDED TESTS
# ==============================================================================
@pytest.mark.asyncio
async def test_list_conversations_excludes_deleted(test_user_1: User, session: AsyncSession):
    """Soft-deleted conversations do not appear in the user's listing."""
    kept = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    removed = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    await ConversationService.delete_conversation(
        conversation_id=removed.id,
        user_id=test_user_1.id,
        session=session,
    )

    listed, _total = await ConversationService.get_user_conversations(
        user_id=test_user_1.id,
        session=session,
    )

    listed_ids = {item.id for item in listed}
    assert kept.id in listed_ids
    assert removed.id not in listed_ids
@pytest.mark.asyncio
async def test_search_excludes_deleted(test_user_1: User, session: AsyncSession):
    """Soft-deleted conversations do not appear in search results."""
    owner_id = test_user_1.id

    # Two conversations with identical searchable content; the second is deleted.
    created = []
    for _ in range(2):
        conversation = await ConversationService.create_conversation(
            user_id=owner_id,
            session=session,
        )
        await ConversationService.save_message(
            conversation_id=conversation.id,
            sender_id=owner_id,
            content="python programming",
            session=session,
        )
        created.append(conversation)
    kept, removed = created

    await ConversationService.delete_conversation(
        conversation_id=removed.id,
        user_id=owner_id,
        session=session,
    )

    hits, _ = await ConversationService.search_conversations(
        user_id=owner_id,
        keyword="python",
        session=session,
    )

    hit_ids = {hit.id for hit in hits}
    assert kept.id in hit_ids
    assert removed.id not in hit_ids
# ==============================================================================
# DATABASE ERROR HANDLING TESTS
# ==============================================================================
@pytest.mark.asyncio
async def test_save_message_database_error_rollback(test_user_1: User, session: AsyncSession):
    """A failing database commit inside ``save_message`` surfaces as an exception.

    Closing the session beforehand would not simulate an error: an SQLAlchemy
    session may be used again after ``close()``. Instead, the session's
    ``commit`` is replaced with a mock that raises. The test then checks that
    the mock was actually reached, so the failure comes from the simulated
    database error rather than from an unrelated problem.
    """
    conv = await ConversationService.create_conversation(
        user_id=test_user_1.id,
        session=session,
    )
    # Capture ids up front so attribute access cannot be what raises below.
    conversation_id = conv.id
    sender_id = test_user_1.id

    # NOTE(review): assumes save_message persists via ``session.commit()`` —
    # confirm against ConversationService; assert_awaited() below fails
    # loudly if that assumption does not hold.
    failing_commit = AsyncMock(side_effect=RuntimeError("simulated database failure"))
    with patch.object(session, "commit", failing_commit):
        # The service may wrap or translate the error, so any exception is accepted.
        with pytest.raises(Exception):
            await ConversationService.save_message(
                conversation_id=conversation_id,
                sender_id=sender_id,
                content="This should fail",
                session=session,
            )
    failing_commit.assert_awaited()
# ==============================================================================
# PYTEST CONFIGURATION
# ==============================================================================
if __name__ == "__main__":
    # Propagate pytest's exit status so `python test_conversations.py` fails
    # (non-zero) when any test fails; discarding it always exits 0.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))