| | """ |
| | Test script for session management system. |
| | """ |
| | import asyncio |
| | import sys |
| | from pathlib import Path |
| |
|
| | |
# Prepend the ``src`` directory that sits next to this script to the import
# path, so top-level imports below (presumably ``config``) resolve when the
# script is run directly.
sys.path[0:0] = [str(Path(__file__).parent.joinpath("src"))]
| |
|
| | from config import get_settings |
| | from src.session import ( |
| | SessionManager, |
| | UserSessionState, |
| | generate_session_id, |
| | create_empty_session, |
| | get_session_manager |
| | ) |
| |
|
| |
|
async def test_session_state():
    """Exercise ``UserSessionState`` creation, mutation and dict round-tripping.

    Builds an empty session from a freshly generated ID and attaches a legal
    position. It then serializes the session with ``to_dict`` and rebuilds it
    with ``UserSessionState.from_dict``, checking that the session ID and the
    legal-position payload survive the round trip.
    """
    banner = "=" * 50
    print(banner)
    print("Testing UserSessionState")
    print(banner)

    # A brand-new session should carry the requested ID and hold no data yet.
    sid = generate_session_id()
    state = create_empty_session(sid)
    print(f"✅ Created session: {state}")
    assert state.session_id == sid
    assert not state.has_legal_position()
    assert not state.has_search_results()

    # Attach a legal position, then record activity on the session.
    position = dict(
        title="Test Position",
        text="Test content",
        proceeding="Civil",
        category="Contract Law",
    )
    state.legal_position_json = position
    state.update_activity()

    print(f"✅ Added legal position: {state.has_legal_position()}")
    assert state.has_legal_position()

    # Round-trip through the plain-dict representation.
    as_dict = state.to_dict()
    print(f"✅ Serialized to dict: {len(as_dict)} keys")

    rebuilt = UserSessionState.from_dict(as_dict)
    print(f"✅ Restored from dict: {rebuilt}")
    assert rebuilt.session_id == state.session_id
    assert rebuilt.legal_position_json == position

    print("✅ UserSessionState tests passed!")
| |
|
| |
|
async def test_session_manager():
    """Exercise an in-memory ``SessionManager`` end to end.

    Covers:
    - creating two sessions;
    - fetching one by ID;
    - persisting an update;
    - counting sessions;
    - deleting a session;
    - running expired-session cleanup.

    ``manager.shutdown()`` runs in a ``finally`` clause, so the manager is
    released even when an assertion or a manager call fails part-way through.
    """
    print("\n" + "=" * 50)
    print("Testing SessionManager")
    print("=" * 50)

    manager = SessionManager(storage_type="memory")
    print(f"✅ Created session manager: {manager}")

    try:
        # Each get_session() call without an ID is expected to yield a new,
        # distinct session.
        session1 = await manager.get_session()
        print(f"✅ Created session 1: {session1}")

        session2 = await manager.get_session()
        print(f"✅ Created session 2: {session2}")

        assert session1.session_id != session2.session_id

        # Passing an existing ID should return that same session.
        retrieved_session1 = await manager.get_session(session1.session_id)
        print(f"✅ Retrieved session 1: {retrieved_session1}")
        assert retrieved_session1.session_id == session1.session_id

        # Modify and persist, then re-fetch to confirm the change was stored.
        retrieved_session1.legal_position_json = {"test": "data"}
        await manager.update_session(retrieved_session1)

        updated_session = await manager.get_session(session1.session_id)
        print(f"✅ Updated session: {updated_session.has_legal_position()}")
        assert updated_session.has_legal_position()

        # At least the two sessions created above should be counted.
        session_count = await manager.get_session_count()
        print(f"✅ Session count: {session_count}")
        assert session_count >= 2

        # After deletion, asking for the old ID is expected to hand back a
        # different session rather than the deleted one.
        await manager.delete_session(session1.session_id)
        deleted_session = await manager.get_session(session1.session_id)
        print(f"✅ Session after deletion: {deleted_session}")
        assert deleted_session.session_id != session1.session_id

        cleaned_count = await manager.cleanup_expired_sessions()
        print(f"✅ Cleaned sessions: {cleaned_count}")
    finally:
        # Always release the manager, even if a check above failed.
        await manager.shutdown()
        print("✅ Session manager shutdown")

    print("✅ SessionManager tests passed!")
| |
|
| |
|
async def test_global_session_manager():
    """Check that ``get_session_manager`` hands out one shared instance.

    Two lookups must return the identical object. That object is then used
    to obtain a session, which is printed.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("Testing Global Session Manager")
    print(rule)

    first, second = get_session_manager(), get_session_manager()
    same_instance = first is second
    print(f"✅ Global manager instances: {same_instance}")
    assert same_instance

    # The shared manager should be usable like a locally constructed one.
    session = await first.get_session()
    print(f"✅ Global session: {session}")

    print("✅ Global session manager tests passed!")
| |
|
| |
|
async def main():
    """Run every session test in order and report the overall outcome.

    Loads settings with ``validate_api_keys=False``, then runs the state,
    manager and global-manager tests in sequence. The first exception stops
    the run; its type, message and traceback are printed.

    Returns:
        True if every step completed, False if any step raised.
    """
    import traceback

    print("🧪 Starting Session System Tests")

    try:
        # validate_api_keys=False presumably skips API-key checks, which the
        # session tests do not need.
        settings = get_settings(validate_api_keys=False)
        print(f"✅ Configuration loaded: {settings.app.environment}")

        await test_session_state()
        await test_session_manager()
        await test_global_session_manager()
    except Exception as e:
        # Top-level boundary: report the failure instead of crashing. Include
        # the exception type because a bare ``assert`` carries no message.
        print(f"\n❌ Test failed: {type(e).__name__}: {e}")
        traceback.print_exc()
        return False

    print("\n" + "=" * 60)
    print("🎉 All session tests passed!")
    print("=" * 60)
    return True
| |
|
| |
|
if __name__ == "__main__":
    # Map main()'s boolean result to a process exit status. Use sys.exit
    # rather than the bare exit() builtin, which the site module injects for
    # interactive use and which may be absent (e.g. under ``python -S``).
    success = asyncio.run(main())
    sys.exit(0 if success else 1)
| |
|