"""
Test script for session management system.
"""
import asyncio
import sys
from pathlib import Path
# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent / "src"))
from config import get_settings
from src.session import (
SessionManager,
UserSessionState,
generate_session_id,
create_empty_session,
get_session_manager
)
async def test_session_state():
    """Check UserSessionState creation, mutation and dict round-tripping.

    Creates an empty session, attaches a legal position to it, then
    serializes it with ``to_dict`` and rebuilds it with ``from_dict``,
    asserting that the session id and the legal position survive.
    """
    banner = "=" * 50
    print(banner)
    print("Testing UserSessionState")
    print(banner)

    # A brand-new session keeps the id it was given and holds no data yet.
    new_id = generate_session_id()
    state = create_empty_session(new_id)
    print(f"✅ Created session: {state}")
    assert state.session_id == new_id
    assert not state.has_legal_position()
    assert not state.has_search_results()

    # Attach a legal position and record the activity.
    position_payload = {
        "title": "Test Position",
        "text": "Test content",
        "proceeding": "Civil",
        "category": "Contract Law",
    }
    state.legal_position_json = position_payload
    state.update_activity()
    print(f"✅ Added legal position: {state.has_legal_position()}")
    assert state.has_legal_position()

    # Round-trip through the dict representation.
    as_dict = state.to_dict()
    print(f"✅ Serialized to dict: {len(as_dict)} keys")

    rebuilt = UserSessionState.from_dict(as_dict)
    print(f"✅ Restored from dict: {rebuilt}")
    assert rebuilt.session_id == state.session_id
    assert rebuilt.legal_position_json == position_payload

    print("✅ UserSessionState tests passed!")
async def test_session_manager():
    """Exercise the SessionManager lifecycle with ``storage_type="memory"``.

    Covers session creation, retrieval by id, update, counting, deletion
    and expired-session cleanup. The manager is shut down in a ``finally``
    clause so that a failing assertion does not skip ``shutdown()``.

    Raises:
        AssertionError: If any of the checked expectations does not hold.
    """
    print("\n" + "=" * 50)
    print("Testing SessionManager")
    print("=" * 50)

    # Create session manager
    manager = SessionManager(storage_type="memory")
    print(f"✅ Created session manager: {manager}")

    try:
        # Two calls without an id must yield two distinct sessions.
        session1 = await manager.get_session()
        print(f"✅ Created session 1: {session1}")
        session2 = await manager.get_session()
        print(f"✅ Created session 2: {session2}")
        assert session1.session_id != session2.session_id

        # Asking for a known id must return that same session.
        retrieved_session1 = await manager.get_session(session1.session_id)
        print(f"✅ Retrieved session 1: {retrieved_session1}")
        assert retrieved_session1.session_id == session1.session_id

        # Persist a change through the manager ...
        retrieved_session1.legal_position_json = {"test": "data"}
        await manager.update_session(retrieved_session1)

        # ... and verify it is visible on the next lookup.
        updated_session = await manager.get_session(session1.session_id)
        print(f"✅ Updated session: {updated_session.has_legal_position()}")
        assert updated_session.has_legal_position()

        # Both sessions created above should be counted.
        session_count = await manager.get_session_count()
        print(f"✅ Session count: {session_count}")
        assert session_count >= 2

        # After deletion, looking up the old id must hand back a session
        # with a different id (i.e. a new one was created).
        await manager.delete_session(session1.session_id)
        deleted_session = await manager.get_session(session1.session_id)
        print(f"✅ Session after deletion: {deleted_session}")
        assert deleted_session.session_id != session1.session_id

        # Cleanup is only smoke-tested; the returned count is not asserted.
        cleaned_count = await manager.cleanup_expired_sessions()
        print(f"✅ Cleaned sessions: {cleaned_count}")
    finally:
        # Always shut the manager down, even when an assertion above fails.
        await manager.shutdown()

    print("✅ Session manager shutdown")
    print("✅ SessionManager tests passed!")
async def test_global_session_manager():
    """Check that ``get_session_manager()`` returns one shared instance.

    Calls the accessor twice, asserts both results are the same object,
    then requests a session from it as a smoke test.
    """
    divider = "=" * 50
    print(f"\n{divider}")
    print("Testing Global Session Manager")
    print(divider)

    # Two lookups must resolve to the identical manager object.
    first = get_session_manager()
    second = get_session_manager()
    print(f"✅ Global manager instances: {first is second}")
    assert first is second

    # The shared manager should be able to hand out a session.
    global_session = await first.get_session()
    print(f"✅ Global session: {global_session}")

    print("✅ Global session manager tests passed!")
async def main():
    """Run every session test in order and report the overall outcome.

    Returns:
        True when configuration loading and all tests complete without
        raising; False when any exception (including a failed assertion)
        is caught, after printing the message and traceback.
    """
    print("🧪 Starting Session System Tests")

    try:
        # API key validation is switched off for this test run.
        settings = get_settings(validate_api_keys=False)
        print(f"✅ Configuration loaded: {settings.app.environment}")

        # Execute the suites sequentially, in the declared order.
        for suite in (
            test_session_state,
            test_session_manager,
            test_global_session_manager,
        ):
            await suite()

        rule = "=" * 60
        print(f"\n{rule}")
        print("🎉 All session tests passed!")
        print(rule)
    except Exception as exc:
        print(f"\n❌ Test failed: {exc}")
        import traceback

        traceback.print_exc()
        return False
    else:
        return True
if __name__ == "__main__":
    # Run the async test driver and map its boolean result to a process
    # exit status (0 = all tests passed, 1 = failure).
    success = asyncio.run(main())
    # sys.exit is used instead of the builtin exit(): the latter is injected
    # by the site module for interactive use and may be unavailable
    # (e.g. under `python -S`).
    sys.exit(0 if success else 1)
|