"""Integration tests for message persistence.

Each test writes ``Conversation`` / ``Message`` rows through the ``session``
fixture and then reads them back through a *second* ``Session`` bound to the
same engine, to check that the data was really flushed to the database and is
not just being served from the first session's identity map.
"""

from datetime import datetime, timedelta, timezone

import pytest
from sqlmodel import Session, SQLModel, create_engine, select

from backend.models.conversation import Conversation
from backend.models.message import Message


def _utc_now_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()``, which is deprecated since
    Python 3.12. The value is kept naive (no ``tzinfo``) so that it is the
    same kind of value the tests stored before, and so that the equality
    check against the value read back from the database keeps working.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


@pytest.fixture
def test_engine():
    """Yield an engine bound to an in-memory SQLite database.

    The engine is disposed on teardown so pooled connections are released
    after every test.
    """
    # NOTE(review): opening several sessions against ":memory:" relies on the
    # pool handing back the same connection within this thread (SQLAlchemy's
    # documented default for in-memory SQLite) - confirm if tests ever go
    # multi-threaded.
    engine = create_engine("sqlite:///:memory:", echo=False)
    yield engine
    engine.dispose()


@pytest.fixture
def session(test_engine):
    """Yield a session on a freshly created schema; drop the schema afterwards."""
    SQLModel.metadata.create_all(test_engine)
    with Session(test_engine) as session:
        yield session
    SQLModel.metadata.drop_all(test_engine)


def test_message_persistence_across_sessions(session, test_engine):
    """Messages written in one session are visible from another session."""
    # Create conversation and messages in session 1
    conversation = Conversation(user_id="user_abc123")
    session.add(conversation)
    session.commit()
    session.refresh(conversation)
    conv_id = conversation.id

    for i in range(3):
        session.add(
            Message(
                conversation_id=conv_id,
                user_id="user_abc123",
                role="user",
                content=f"Message {i+1}",
            )
        )
    session.commit()

    # Start a new session and verify the messages exist
    with Session(test_engine) as new_session:
        messages = new_session.exec(
            select(Message).where(Message.conversation_id == conv_id)
        ).all()

        assert len(messages) == 3
        # The query has no ORDER BY, so row order is unspecified: compare the
        # contents order-independently instead of by position.
        assert sorted(m.content for m in messages) == [
            "Message 1",
            "Message 2",
            "Message 3",
        ]


def test_cross_session_retrieval_chronological(session, test_engine):
    """Messages come back in ``created_at`` order from a new session."""
    # Setup: create conversation and messages
    conversation = Conversation(user_id="user_abc123")
    session.add(conversation)
    session.commit()
    session.refresh(conversation)
    conv_id = conversation.id

    base_time = _utc_now_naive()

    # Add messages with explicit timestamps spaced 10 seconds apart
    for i in range(3):
        message = Message(
            conversation_id=conv_id,
            user_id="user_abc123",
            role="user",
            content=f"Message {i+1}",
        )
        message.created_at = base_time + timedelta(seconds=i * 10)
        session.add(message)
    session.commit()

    # New session: retrieve and verify order
    with Session(test_engine) as new_session:
        messages = new_session.exec(
            select(Message)
            .where(Message.conversation_id == conv_id)
            .order_by(Message.created_at)
        ).all()

        assert len(messages) == 3

        # Verify content and timestamps in order
        for i, message in enumerate(messages):
            assert message.content == f"Message {i+1}"
            assert message.created_at == base_time + timedelta(seconds=i * 10)


def test_user_isolation_in_messages(session, test_engine):
    """Filtering by ``user_id`` returns only that user's messages."""
    # Create conversations for two different users
    user1_conv = Conversation(user_id="user_abc123")
    user2_conv = Conversation(user_id="user_xyz456")
    session.add(user1_conv)
    session.add(user2_conv)
    session.commit()
    session.refresh(user1_conv)
    session.refresh(user2_conv)

    # One message per user
    session.add(
        Message(
            conversation_id=user1_conv.id,
            user_id="user_abc123",
            role="user",
            content="User1 message",
        )
    )
    session.add(
        Message(
            conversation_id=user2_conv.id,
            user_id="user_xyz456",
            role="user",
            content="User2 message",
        )
    )
    session.commit()

    # Verify isolation
    with Session(test_engine) as new_session:
        user1_messages = new_session.exec(
            select(Message).where(Message.user_id == "user_abc123")
        ).all()
        user2_messages = new_session.exec(
            select(Message).where(Message.user_id == "user_xyz456")
        ).all()

        assert len(user1_messages) == 1
        assert len(user2_messages) == 1
        assert user1_messages[0].content == "User1 message"
        assert user2_messages[0].content == "User2 message"


def test_message_survives_server_restart_simulation(session, test_engine):
    """Messages survive a simulated restart (session closed, new one opened)."""
    # Create conversation and messages
    conversation = Conversation(user_id="user_abc123")
    session.add(conversation)
    session.commit()
    session.refresh(conversation)
    conv_id = conversation.id

    # Explicit, strictly increasing timestamps: auto-generated created_at
    # values may tie for rows created back-to-back, which would make the
    # ORDER BY below (and the positional asserts) nondeterministic.
    base_time = _utc_now_naive()
    for i in range(5):
        message = Message(
            conversation_id=conv_id,
            user_id="user_abc123",
            role="user" if i % 2 == 0 else "assistant",
            content=f"Message {i+1}",
        )
        message.created_at = base_time + timedelta(seconds=i)
        session.add(message)
    session.commit()

    # Simulate server restart: close the original session
    session.close()

    # New session
    with Session(test_engine) as new_session:
        messages = new_session.exec(
            select(Message)
            .where(Message.conversation_id == conv_id)
            .order_by(Message.created_at)
        ).all()

        # Verify all messages survived, in order
        assert len(messages) == 5
        for i, message in enumerate(messages):
            assert message.content == f"Message {i+1}"


def test_message_tool_calls_persistence(session, test_engine):
    """The nested ``tool_calls`` payload round-trips through the database."""
    # Create conversation
    conversation = Conversation(user_id="user_abc123")
    session.add(conversation)
    session.commit()
    session.refresh(conversation)

    # Create message with tool calls
    tool_calls = {
        "name": "add_task",
        "parameters": {"title": "Buy groceries", "description": "Milk, eggs"},
        "result": {"task_id": 42, "status": "created"},
    }
    message = Message(
        conversation_id=conversation.id,
        user_id="user_abc123",
        role="assistant",
        content="I've created that task for you.",
        tool_calls=tool_calls,
    )
    session.add(message)
    session.commit()
    session.refresh(message)
    # Capture the key now so the lookup below does not have to lazily reload
    # an expired instance that belongs to the first session.
    message_id = message.id

    # Retrieve and verify tool_calls
    with Session(test_engine) as new_session:
        retrieved = new_session.get(Message, message_id)

        assert retrieved is not None
        assert retrieved.tool_calls is not None
        assert retrieved.tool_calls["name"] == "add_task"
        assert retrieved.tool_calls["parameters"]["title"] == "Buy groceries"
        assert retrieved.tool_calls["result"]["task_id"] == 42


def test_empty_conversation_start(session):
    """A newly created conversation has no messages."""
    # Create conversation
    conversation = Conversation(user_id="user_abc123")
    session.add(conversation)
    session.commit()
    session.refresh(conversation)

    # Verify no messages yet
    messages = session.exec(
        select(Message).where(Message.conversation_id == conversation.id)
    ).all()

    assert len(messages) == 0


def test_message_roles_persist(session, test_engine):
    """The ``role`` of each message is stored and read back unchanged."""
    # Create conversation
    conversation = Conversation(user_id="user_abc123")
    session.add(conversation)
    session.commit()
    session.refresh(conversation)
    conv_id = conversation.id

    # Add messages with different roles
    user_msg = Message(
        conversation_id=conv_id,
        user_id="user_abc123",
        role="user",
        content="User message",
    )
    assistant_msg = Message(
        conversation_id=conv_id,
        user_id="user_abc123",
        role="assistant",
        content="Assistant response",
    )
    # Explicit, distinct timestamps so the ORDER BY below is deterministic
    # even if auto-generated created_at values would have tied.
    base_time = _utc_now_naive()
    user_msg.created_at = base_time
    assistant_msg.created_at = base_time + timedelta(seconds=1)

    session.add(user_msg)
    session.add(assistant_msg)
    session.commit()

    # Retrieve and verify roles
    with Session(test_engine) as new_session:
        messages = new_session.exec(
            select(Message)
            .where(Message.conversation_id == conv_id)
            .order_by(Message.created_at)
        ).all()

        assert len(messages) == 2
        assert messages[0].role == "user"
        assert messages[1].role == "assistant"