"""
Session Integration Module

Provides session management hooks for EM and QLBM modules.
Handles session loading, state restoration, and auto-saving.
"""

import functools
import json
from pathlib import Path
from typing import Any, Callable, Dict, Optional

from session_manager import SessionManager
from session_models import SessionMetadata, SessionState


class SessionIntegration:
    """
    Manages session state synchronization with EM/QLBM modules.

    Handles loading session state into app modules and saving changes back.
    Persistence is delegated to a per-user ``SessionManager``. This class
    tracks which session is "current" and copies values between that
    session's ``SessionState.state_data`` and a Trame state object.
    """

    # Trame state variables copied into ``SessionState.state_data`` by
    # ``_capture_trame_state``. The tuple is defined once at class level
    # instead of being rebuilt on every capture. Subclasses can extend it,
    # or it can be overridden on an instance.
    CAPTURED_STATE_KEYS: tuple[str, ...] = (
        # EM-specific
        "grid_size",
        "frequency",
        "excitation_type",
        "excitation_frequency",
        "geometry_type",
        "outer_boundary_condition",
        "num_qubits",
        "t_final",
        "dt",
        "backend_type",
        "selected_simulator",
        "selected_qpu",
        "aqc_enabled",
        "measurement_type",
        # QLBM-specific
        "grid_size_qlbm",
        "num_steps",
        "initial_condition",
        "viscosity",
        "flow_field_type",
    )

    def __init__(self, user_id: str):
        """
        Initialize session integration for a user.

        Args:
            user_id: Identifier forwarded to ``SessionManager``.
        """
        self.user_id = user_id
        self.session_manager = SessionManager(user_id)
        # No session is active until create_new_session / load_session runs.
        self.current_session_id: Optional[str] = None
        self.current_metadata: Optional[SessionMetadata] = None
        self.current_state: Optional[SessionState] = None
        # Checked by AutoSaveContext and SessionAutoSaveHook before saving.
        self.auto_save_enabled = True

    def create_new_session(
        self, alias: str, app_type: str, description: str = ""
    ) -> tuple[str, SessionMetadata]:
        """
        Create a new session and make it the current one.

        Args:
            alias: Human-readable session alias.
            app_type: Application type stored on the session.
            description: Optional free-text description.

        Returns:
            The ``(session_id, metadata)`` pair produced by
            ``SessionManager.create_session``.
        """
        session_id, metadata = self.session_manager.create_session(
            alias=alias,
            app_type=app_type,
            description=description,
        )
        self.current_session_id = session_id
        self.current_metadata = metadata
        # A fresh session starts with an empty state for this app type.
        self.current_state = SessionState(
            session_id=session_id,
            app_type=app_type,
        )
        return session_id, metadata

    def load_session(self, session_id: str) -> tuple[SessionMetadata, SessionState]:
        """
        Load a session by ID and make it the current one.

        Exceptions raised by ``SessionManager.load_session`` propagate
        unchanged. In that case the current session is left untouched.

        Returns:
            The ``(metadata, state)`` pair for the loaded session.
        """
        metadata, state = self.session_manager.load_session(session_id)
        self.current_session_id = session_id
        self.current_metadata = metadata
        self.current_state = state
        return metadata, state

    def load_by_alias(self, alias: str) -> Optional[tuple[SessionMetadata, SessionState]]:
        """
        Load the most recent session matching an alias.

        Returns:
            The ``(metadata, state)`` pair from ``load_session``, or ``None``
            when no session matches ``alias``.
        """
        result = self.session_manager.get_most_recent_by_alias(alias)
        if not result:
            return None
        # Only the ID is needed here; load_session re-reads the metadata.
        _metadata, session_id = result
        return self.load_session(session_id)

    def get_state_dict(self) -> Dict[str, Any]:
        """Get the current session state as a dict (e.g., for saving to disk).

        Returns an empty dict when no state is loaded.
        """
        if not self.current_state:
            return {}
        return self.current_state.to_dict()

    def restore_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Restore session state from a dict.

        This is a no-op when no session is active or ``state_dict`` is empty.
        """
        if self.current_session_id and state_dict:
            self.current_state = SessionState.from_dict(state_dict)

    def save_current_session(self, trame_state: Optional[Any] = None) -> bool:
        """
        Save the current session.

        Args:
            trame_state: Optional Trame state object. When given, its values
                are captured into the session state before saving.

        Returns:
            ``False`` if no session is fully loaded. Otherwise the result of
            ``SessionManager.save_session``.
        """
        if not self.current_session_id or not self.current_metadata or not self.current_state:
            return False

        # Compare against None explicitly. A state object may define its own
        # truthiness, which must not decide whether it gets captured.
        if trame_state is not None:
            self._capture_trame_state(trame_state)

        return self.session_manager.save_session(self.current_metadata, self.current_state)

    @staticmethod
    def _is_json_serializable(value: Any) -> bool:
        """Return True if ``json.dumps`` accepts ``value``."""
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return False
        return True

    def _capture_trame_state(self, trame_state: Any) -> None:
        """
        Capture Trame server state into session state_data.

        Only keys listed in ``CAPTURED_STATE_KEYS`` that exist on
        ``trame_state`` and whose values ``json.dumps`` accepts are copied.
        Other values are skipped silently. Captured values are merged into
        the existing ``state_data``; keys not captured this time keep their
        previous values.
        """
        if not self.current_state:
            return

        state_data: Dict[str, Any] = {}
        for key in self.CAPTURED_STATE_KEYS:
            if not hasattr(trame_state, key):
                continue
            value = getattr(trame_state, key)
            # json.dumps is only a probe. The original object is stored, not
            # its JSON round-trip.
            if self._is_json_serializable(value):
                state_data[key] = value

        self.current_state.state_data.update(state_data)

    def restore_to_trame_state(self, trame_state: Any) -> None:
        """
        Restore session state_data back into Trame state.

        Only keys that already exist as attributes on ``trame_state`` are
        written. Other stored keys are ignored.

        Args:
            trame_state: Trame state object to restore to.
        """
        if not self.current_state:
            return

        for key, value in self.current_state.state_data.items():
            if hasattr(trame_state, key):
                setattr(trame_state, key, value)

    def add_job(self, job_id: str, service_type: str) -> bool:
        """
        Track a submitted job in the current session.

        Returns:
            ``False`` if no session is active. Otherwise the result of
            ``SessionManager.add_job_to_session``.
        """
        if not self.current_session_id:
            return False

        return self.session_manager.add_job_to_session(
            self.current_session_id,
            job_id,
            service_type,
        )

    def update_job_status(
        self, job_id: str, status: str, result: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update job status in the current session.

        Returns:
            ``False`` if no session is active. Otherwise the result of
            ``SessionManager.update_job_status``.
        """
        if not self.current_session_id:
            return False

        return self.session_manager.update_job_status(
            self.current_session_id,
            job_id,
            status,
            result,
        )

    def get_current_session_info(self) -> Optional[Dict[str, Any]]:
        """Get current session metadata as a dict, or ``None`` if none is loaded."""
        if not self.current_metadata:
            return None
        return self.current_metadata.to_dict()


class AutoSaveContext:
    """
    Context manager for auto-saving session changes.

    On exit, if ``session_integration.auto_save_enabled`` is set, the Trame
    state is optionally captured and the current session is saved. This
    happens even when the ``with`` body raised. Exceptions from the body are
    never suppressed.

    Usage:
        with AutoSaveContext(session_integration, trame_state):
            # Make changes
            pass
        # Auto-saves on exit
    """

    def __init__(
        self,
        session_integration: SessionIntegration,
        trame_state: Any,
        capture_state: bool = True
    ):
        self.session_integration = session_integration
        self.trame_state = trame_state
        # When False, only the already-held session state is saved.
        self.capture_state = capture_state

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.session_integration.auto_save_enabled:
            if self.capture_state:
                self.session_integration._capture_trame_state(self.trame_state)
            self.session_integration.save_current_session()
        # Returning False lets any exception from the with-body propagate.
        return False


class SessionAutoSaveHook:
    """
    Decorator for automatically saving sessions after specific operations.

    After the decorated function returns or raises, if
    ``session_integration.auto_save_enabled`` is set, the Trame state is
    captured and the current session is saved. The wrapper keeps the wrapped
    function's name and docstring (via ``functools.wraps``).

    NOTE: the save runs in a ``finally`` clause. If the save itself raises
    while the wrapped function is already raising, the save error replaces
    the original one (the original remains available as ``__context__``).

    Usage:
        @SessionAutoSaveHook(session_integration, trame_state)
        def some_operation():
            # Make changes
            pass
        # Auto-saves after execution
    """

    def __init__(self, session_integration: SessionIntegration, trame_state: Any):
        self.session_integration = session_integration
        self.trame_state = trame_state

    def __call__(self, func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            finally:
                if self.session_integration.auto_save_enabled:
                    self.session_integration._capture_trame_state(self.trame_state)
                    self.session_integration.save_current_session()

        return wrapper