Spaces:
Runtime error
Runtime error
| """ | |
| Session Management Tests & Validation | |
| Tests for multi-user session isolation, persistence, and state management. | |
| Can be run with: python session_tests.py | |
| """ | |
| import tempfile | |
| import shutil | |
| import json | |
| from pathlib import Path | |
| import uuid | |
| from datetime import datetime | |
# Scratch directory standing in for persistent HF storage while the tests run.
# A directory is created once at import time; setup_test_storage() rebinds
# this name to a fresh directory.
# NOTE(review): nothing in this file passes this path to SessionManager —
# confirm whether session_manager actually reads it.
TEST_STORAGE_DIR: Path = Path(tempfile.mkdtemp(prefix="session_test_"))
def setup_test_storage():
    """Set up a fresh, empty test storage directory.

    Any directory currently referenced by the module-level TEST_STORAGE_DIR
    (for example the one created at import time) is deleted first, so
    repeated setups do not leak temporary directories.

    Side effects:
        Removes the previous TEST_STORAGE_DIR tree if it still exists,
        creates a new ``session_test_*`` directory via ``tempfile.mkdtemp``,
        rebinds the module-level TEST_STORAGE_DIR to it, and prints the path.
    """
    global TEST_STORAGE_DIR
    previous = TEST_STORAGE_DIR
    if previous is not None and previous.exists():
        # mkdtemp directories are never auto-deleted; the caller must clean them up.
        # ignore_errors: a stale scratch dir that can't be removed shouldn't abort setup.
        shutil.rmtree(previous, ignore_errors=True)
    TEST_STORAGE_DIR = Path(tempfile.mkdtemp(prefix="session_test_"))
    print(f"Test storage: {TEST_STORAGE_DIR}")
def teardown_test_storage():
    """Remove the current test storage directory if it is still on disk.

    Does nothing (and prints nothing) when the directory is already gone,
    so calling it more than once is harmless.
    """
    storage = TEST_STORAGE_DIR
    if not storage.exists():
        return
    shutil.rmtree(storage)
    print("Test storage cleaned up")
def test_user_isolation():
    """Test that different users don't see each other's sessions.

    Two SessionManager instances are created for two random user IDs and
    each creates two sessions. The alias "my_em_sim" is deliberately used
    by both users, so isolation is checked by exact session ID, not only
    by alias.
    """
    print("\n=== Test: User Isolation ===")
    from session_manager import SessionManager

    # Random IDs keep this test independent of sessions from earlier runs.
    user1_id = str(uuid.uuid4())
    user2_id = str(uuid.uuid4())
    sm1 = SessionManager(user1_id)
    sm2 = SessionManager(user2_id)

    # User 1 creates sessions
    sid1, _ = sm1.create_session("my_em_sim", "EM", "User 1's EM simulation")
    sid2, _ = sm1.create_session("my_fluid_sim", "QLBM", "User 1's fluid simulation")
    # User 2 creates sessions (same "my_em_sim" alias, different user)
    sid3, _ = sm2.create_session("my_em_sim", "EM", "User 2's EM simulation")
    sid4, _ = sm2.create_session("test_fluid", "QLBM", "User 2's fluid simulation")

    user1_sessions = sm1.list_all_sessions()
    user2_sessions = sm2.list_all_sessions()
    print(f"User 1 sessions: {len(user1_sessions)}")
    print(f"User 2 sessions: {len(user2_sessions)}")
    assert len(user1_sessions) == 2, "User 1 should have 2 sessions"
    assert len(user2_sessions) == 2, "User 2 should have 2 sessions"

    # Aliases overlap between users, so compare the exact session IDs.
    user1_ids = {s.session_id for s in user1_sessions}
    user2_ids = {s.session_id for s in user2_sessions}
    assert user1_ids == {sid1, sid2}, "User 1 should see exactly their own sessions"
    assert user2_ids == {sid3, sid4}, "User 2 should see exactly their own sessions"
    assert not (user1_ids & user2_ids), "Users must not share any session"

    # Verify users only see their own aliases
    user1_aliases = {s.alias for s in user1_sessions}
    user2_aliases = {s.alias for s in user2_sessions}
    assert "my_em_sim" in user1_aliases, "User 1 should have 'my_em_sim'"
    assert "my_fluid_sim" in user1_aliases, "User 1 should have 'my_fluid_sim'"
    assert "my_em_sim" in user2_aliases, "User 2 should have 'my_em_sim'"
    assert "test_fluid" in user2_aliases, "User 2 should have 'test_fluid'"
    print("✓ User isolation working correctly")
def test_alias_collision_resolution():
    """Test handling of duplicate aliases within a user's sessions.

    Three sessions share the alias "simulation_v1" (two EM, one QLBM).
    Each create call must yield a distinct session, all three must be found
    by alias, and the most recently created one must win the
    "most recent" lookup.
    """
    print("\n=== Test: Alias Collision Resolution ===")
    from session_manager import SessionManager

    sm = SessionManager(str(uuid.uuid4()))

    sid1, _ = sm.create_session("simulation_v1", "EM")
    sid2, _ = sm.create_session("simulation_v1", "EM")    # same alias, different session
    sid3, _ = sm.create_session("simulation_v1", "QLBM")  # same alias, different app
    assert len({sid1, sid2, sid3}) == 3, "Each session should get a distinct ID"

    # Get by alias (should return sorted by recency)
    matches = sm.get_by_alias("simulation_v1")
    print(f"Found {len(matches)} matches for alias 'simulation_v1'")
    assert len(matches) == 3, "Should find 3 sessions with same alias"

    # Index 1 of the result is compared to a session ID here.
    # NOTE(review): the three sessions are created back-to-back; if recency is
    # derived from a coarse timestamp, ties could make this assertion flaky.
    most_recent = sm.get_most_recent_by_alias("simulation_v1")
    assert most_recent is not None, "Should find most recent session"
    assert most_recent[1] == sid3, "Most recent should be the last created session"
    print("✓ Alias collision resolution working correctly")
def test_state_persistence():
    """Test that session state is correctly saved and restored.

    Writes three values into a new session's ``state_data``, saves it, then
    reloads the session and checks every value round-tripped unchanged.
    """
    print("\n=== Test: State Persistence ===")
    from session_manager import SessionManager

    sm = SessionManager(str(uuid.uuid4()))
    sid, _ = sm.create_session("persistence_test", "EM")

    # Load and modify state
    meta, state = sm.load_session(sid)
    state.state_data["grid_size"] = 32
    state.state_data["frequency"] = 2.4e9
    state.state_data["backend"] = "qiskit_ibm"

    save_success = sm.save_session(meta, state)
    assert save_success, "Save should succeed"

    # Reload from storage and verify each value survived
    _, reloaded = sm.load_session(sid)
    assert reloaded.state_data["grid_size"] == 32, "Grid size should persist"
    assert reloaded.state_data["frequency"] == 2.4e9, "Frequency should persist"
    assert reloaded.state_data["backend"] == "qiskit_ibm", "Backend should persist"
    print("✓ State persistence working correctly")
def test_job_tracking():
    """Test job submission tracking within sessions.

    Adds two jobs to a new session and checks both are recorded. Then
    marks one job "completed" with a result payload and checks that the
    status and result are persisted.
    """
    print("\n=== Test: Job Tracking ===")
    from session_manager import SessionManager

    sm = SessionManager(str(uuid.uuid4()))
    sid, _ = sm.create_session("job_tracking_test", "EM")

    # NOTE(review): the third argument is presumably the backend name — confirm
    # against SessionManager.add_job_to_session.
    success1 = sm.add_job_to_session(sid, "job_ibm_001", "qiskit_ibm")
    success2 = sm.add_job_to_session(sid, "job_ionq_001", "ionq")
    assert success1 and success2, "Job additions should succeed"

    # Reload and verify both jobs were recorded
    _, state = sm.load_session(sid)
    assert len(state.submitted_jobs) == 2, "Should have 2 jobs"
    job_ids = {job.job_id for job in state.submitted_jobs}
    assert "job_ibm_001" in job_ids, "Should have IBM job"
    assert "job_ionq_001" in job_ids, "Should have IonQ job"

    # Update job status
    updated = sm.update_job_status(sid, "job_ibm_001", "completed", {"result": "success"})
    assert updated, "Job update should succeed"

    # Reload again so the check reads persisted data, not an in-memory copy
    _, state = sm.load_session(sid)
    ibm_job = next((job for job in state.submitted_jobs if job.job_id == "job_ibm_001"), None)
    assert ibm_job is not None, "IBM job should exist"
    assert ibm_job.status == "completed", "Job status should be updated"
    assert ibm_job.result_data["result"] == "success", "Result should be stored"
    print("✓ Job tracking working correctly")
def test_session_deletion():
    """Test session deletion and cleanup.

    Creates two sessions, deletes one, and checks that exactly the other
    one remains in the user's session list.
    """
    print("\n=== Test: Session Deletion ===")
    from session_manager import SessionManager

    sm = SessionManager(str(uuid.uuid4()))

    sid_deleted, _ = sm.create_session("to_delete", "EM")
    sid_kept, _ = sm.create_session("to_keep", "EM")
    assert len(sm.list_all_sessions()) == 2, "Should have 2 sessions"

    delete_success = sm.delete_session(sid_deleted)
    assert delete_success, "Deletion should succeed"

    sessions_after = sm.list_all_sessions()
    assert len(sessions_after) == 1, "Should have 1 session after deletion"
    assert sessions_after[0].session_id == sid_kept, "Remaining session should be the one we kept"
    print("✓ Session deletion working correctly")
def test_concurrent_access():
    """Test that concurrent access from multiple users doesn't cause conflicts.

    Runs one thread per simulated user. Each thread creates two sessions,
    loads/modifies/saves both, then lists its sessions. The test requires:

    * every thread reports a result (a thread that dies silently would
      otherwise go unnoticed);
    * every workflow succeeds;
    * every user sees exactly their own two sessions.
    """
    print("\n=== Test: Concurrent Access ===")
    import threading
    from session_manager import SessionManager

    num_users = 5
    # Per-run tag so user IDs are unique to this run. Fixed IDs like "user_0"
    # would pick up sessions left behind by earlier runs if storage persists.
    run_tag = uuid.uuid4().hex[:8]
    # (user_idx, session_count, success, error); list.append is atomic in CPython.
    results = []

    def user_workflow(user_idx: int):
        try:
            # Constructed inside the try so a constructor failure is recorded too.
            sm = SessionManager(f"user_{run_tag}_{user_idx}")
            sid1, _ = sm.create_session(f"session_{user_idx}_1", "EM")
            sid2, _ = sm.create_session(f"session_{user_idx}_2", "QLBM")
            # Load and modify
            for sid in (sid1, sid2):
                meta, state = sm.load_session(sid)
                state.state_data[f"user_{user_idx}_data"] = f"data_{user_idx}"
                if not sm.save_session(meta, state):
                    raise RuntimeError(f"save_session failed for {sid}")
            sessions = sm.list_all_sessions()
            results.append((user_idx, len(sessions), True, None))
        except Exception as e:
            results.append((user_idx, 0, False, str(e)))

    threads = [threading.Thread(target=user_workflow, args=(i,)) for i in range(num_users)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    print(f"Concurrent user workflows: {len(results)}")
    # Sort by user index so the report order doesn't depend on thread scheduling.
    for user_idx, session_count, success, error in sorted(results, key=lambda r: r[0]):
        if success:
            print(f" User {user_idx}: ✓ ({session_count} sessions)")
        else:
            print(f" User {user_idx}: ✗ ({error})")

    assert len(results) == num_users, "Every user thread should report a result"
    assert all(r[2] for r in results), "All concurrent operations should succeed"
    assert all(r[1] == 2 for r in results), "Each user should see exactly their 2 sessions"
    print("✓ Concurrent access working correctly")
def run_all_tests():
    """Run all tests in order and report a single pass/fail result.

    A fresh temporary storage directory is set up before the tests. It is
    torn down afterwards on every path, including failures.

    Returns:
        True if every test passed. False if a test raised AssertionError,
        or if any other Exception occurred (a traceback is printed in that
        case).
    """
    print("=" * 60)
    print("SESSION MANAGEMENT TEST SUITE")
    print("=" * 60)
    tests = (
        test_user_isolation,
        test_alias_collision_resolution,
        test_state_persistence,
        test_job_tracking,
        test_session_deletion,
        test_concurrent_access,
    )
    try:
        setup_test_storage()
        for test in tests:
            test()
    except AssertionError as e:
        print(f"\n✗ TEST FAILED: {e}")
        return False
    except Exception as e:
        print(f"\n✗ UNEXPECTED ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
    else:
        # The success return lives here, not in `finally`: a `return` inside
        # `finally` would override the `return False` above and swallow exceptions.
        print("\n" + "=" * 60)
        print("ALL TESTS PASSED ✓")
        print("=" * 60)
        return True
    finally:
        teardown_test_storage()
if __name__ == "__main__":
    # Exit status: 0 when every test passed, 1 otherwise.
    # SystemExit instead of the site-injected exit() helper, which is missing
    # under `python -S` and in some embedded interpreters.
    success = run_all_tests()
    raise SystemExit(0 if success else 1)