Spaces:
Runtime error
Runtime error
| """ | |
| Session Integration Module | |
| Provides session management hooks for EM and QLBM modules. | |
| Handles session loading, state restoration, and auto-saving. | |
| """ | |
| from typing import Optional, Dict, Any, Callable | |
| import json | |
| from pathlib import Path | |
| from session_manager import SessionManager | |
| from session_models import SessionMetadata, SessionState | |
class SessionIntegration:
    """
    Manages session state synchronization with EM/QLBM modules.

    Handles loading session state into app modules and saving changes back.
    Persistence is delegated to a per-user ``SessionManager``; this class
    tracks which session is "current" and copies values between that
    session's ``state_data`` and a Trame state object.

    Attributes:
        user_id: User identifier forwarded to ``SessionManager``.
        session_manager: Backend used to create/load/save sessions and jobs.
        current_session_id: ID of the active session, or ``None``.
        current_metadata: Metadata of the active session, or ``None``.
        current_state: State of the active session, or ``None``.
        auto_save_enabled: Flag consulted by the auto-save helpers in this
            module before they save.
    """

    # Trame state variables copied into ``SessionState.state_data`` by
    # ``_capture_trame_state``. Kept as a class constant so it is built once
    # (not on every capture) and can be extended by subclasses.
    CAPTURED_STATE_KEYS: tuple[str, ...] = (
        # EM-specific
        "grid_size", "frequency", "excitation_type", "excitation_frequency",
        "geometry_type", "outer_boundary_condition",
        "num_qubits", "t_final", "dt", "backend_type", "selected_simulator",
        "selected_qpu", "aqc_enabled", "measurement_type",
        # QLBM-specific
        "grid_size_qlbm", "num_steps", "initial_condition", "viscosity",
        "flow_field_type",
    )

    def __init__(self, user_id: str):
        """
        Initialize session integration for a user.

        Args:
            user_id: User whose sessions are managed; passed to
                ``SessionManager``.
        """
        self.user_id = user_id
        self.session_manager = SessionManager(user_id)
        self.current_session_id: Optional[str] = None
        self.current_metadata: Optional[SessionMetadata] = None
        self.current_state: Optional[SessionState] = None
        self.auto_save_enabled = True

    def create_new_session(
        self,
        alias: str,
        app_type: str,
        description: str = ""
    ) -> tuple[str, SessionMetadata]:
        """
        Create a new session and make it the current one.

        The current state is reset to a fresh ``SessionState`` for the new
        session (nothing from a previously current session is carried over).

        Args:
            alias: Human-readable name for the session.
            app_type: Application the session belongs to.
            description: Optional free-text description.

        Returns:
            ``(session_id, metadata)`` as returned by the session manager.
        """
        session_id, metadata = self.session_manager.create_session(
            alias=alias,
            app_type=app_type,
            description=description,
        )
        self.current_session_id = session_id
        self.current_metadata = metadata
        self.current_state = SessionState(
            session_id=session_id,
            app_type=app_type,
        )
        return session_id, metadata

    def load_session(self, session_id: str) -> tuple[SessionMetadata, SessionState]:
        """
        Load a session by ID and make it the current one.

        Args:
            session_id: ID of the session to load.

        Returns:
            ``(metadata, state)`` as returned by the session manager. Any
            exception raised by the manager propagates and leaves the current
            session unchanged.
        """
        metadata, state = self.session_manager.load_session(session_id)
        self.current_session_id = session_id
        self.current_metadata = metadata
        self.current_state = state
        return metadata, state

    def load_by_alias(self, alias: str) -> Optional[tuple[SessionMetadata, SessionState]]:
        """
        Load the most recent session matching an alias.

        Args:
            alias: Alias to look up.

        Returns:
            ``(metadata, state)`` of the loaded session, or ``None`` when the
            manager reports no match (any falsy result).
        """
        result = self.session_manager.get_most_recent_by_alias(alias)
        if not result:
            return None
        # NOTE(review): assumes the manager returns (metadata, session_id);
        # only the ID is used — the session is re-read via load_session.
        _metadata, session_id = result
        return self.load_session(session_id)

    def get_state_dict(self) -> Dict[str, Any]:
        """
        Get the current session state as a dict (e.g., for saving to disk).

        Returns:
            ``current_state.to_dict()``, or ``{}`` if there is no current state.
        """
        if not self.current_state:
            return {}
        return self.current_state.to_dict()

    def restore_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Replace the current state with one rebuilt from ``state_dict``.

        Does nothing when no session is current or ``state_dict`` is empty.

        Args:
            state_dict: Dict in the form accepted by ``SessionState.from_dict``.
        """
        if self.current_session_id and state_dict:
            self.current_state = SessionState.from_dict(state_dict)

    def save_current_session(self, trame_state: Optional[Any] = None) -> bool:
        """
        Save the current session.

        Args:
            trame_state: Optional Trame state object; when given, its values
                for ``CAPTURED_STATE_KEYS`` are captured before saving.

        Returns:
            ``False`` if no session is fully loaded (ID, metadata and state);
            otherwise the result of ``SessionManager.save_session``.
        """
        if not self.current_session_id or not self.current_metadata or not self.current_state:
            return False
        # Compare against None rather than relying on truthiness: a state
        # object that happens to be falsy must still be captured.
        if trame_state is not None:
            self._capture_trame_state(trame_state)
        return self.session_manager.save_session(self.current_metadata, self.current_state)

    def _capture_trame_state(self, trame_state: Any) -> None:
        """
        Copy selected Trame state attributes into the session's state_data.

        Only attributes named in ``CAPTURED_STATE_KEYS`` that exist on
        ``trame_state`` and can be encoded by ``json.dumps`` are copied;
        other values are skipped, not converted. Existing ``state_data``
        entries for keys not captured this time are left as they are.

        Args:
            trame_state: Object whose attributes are read via ``getattr``.
        """
        if not self.current_state:
            return
        missing = object()
        captured: Dict[str, Any] = {}
        for key in self.CAPTURED_STATE_KEYS:
            # NOTE(review): if the state object resolves unknown attributes
            # to None instead of raising AttributeError, every key will be
            # captured — confirm against the Trame State API.
            value = getattr(trame_state, key, missing)
            if value is missing:
                continue
            try:
                json.dumps(value)  # serializability probe; output discarded
            except (TypeError, ValueError):
                # Best-effort by design: skip values JSON cannot encode.
                continue
            captured[key] = value
        self.current_state.state_data.update(captured)

    def restore_to_trame_state(self, trame_state: Any) -> None:
        """
        Restore session state_data back into Trame state.

        Only keys for which ``trame_state`` already has an attribute are
        written; others are ignored.

        Args:
            trame_state: Trame state object to restore to.
        """
        if not self.current_state:
            return
        for key, value in self.current_state.state_data.items():
            if hasattr(trame_state, key):
                setattr(trame_state, key, value)

    def add_job(self, job_id: str, service_type: str) -> bool:
        """
        Track a submitted job in the current session.

        Returns:
            ``False`` if no session is current; otherwise the session
            manager's result.
        """
        if not self.current_session_id:
            return False
        return self.session_manager.add_job_to_session(
            self.current_session_id,
            job_id,
            service_type,
        )

    def update_job_status(
        self,
        job_id: str,
        status: str,
        result: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update a job's status (and optional result) in the current session.

        Returns:
            ``False`` if no session is current; otherwise the session
            manager's result.
        """
        if not self.current_session_id:
            return False
        return self.session_manager.update_job_status(
            self.current_session_id,
            job_id,
            status,
            result,
        )

    def get_current_session_info(self) -> Optional[Dict[str, Any]]:
        """
        Get current session metadata as a dict.

        Returns:
            ``current_metadata.to_dict()``, or ``None`` if nothing is loaded.
        """
        if not self.current_metadata:
            return None
        return self.current_metadata.to_dict()
class AutoSaveContext:
    """
    Context manager that saves the current session when its block exits.

    On exit — normal or through an exception — the session is saved when
    ``session_integration.auto_save_enabled`` is true, optionally capturing
    ``trame_state`` first. ``__exit__`` always returns ``False``, so an
    exception raised inside the block is not suppressed.

    Usage:
        with AutoSaveContext(session_integration, trame_state):
            # Make changes
            pass  # Auto-saves on exit
    """

    def __init__(
        self,
        session_integration: SessionIntegration,
        trame_state: Any,
        capture_state: bool = True
    ):
        """
        Args:
            session_integration: Integration whose current session is saved.
            trame_state: Object whose attributes are captured before saving.
            capture_state: When false, save without capturing ``trame_state``.
        """
        self.session_integration = session_integration
        self.trame_state = trame_state
        self.capture_state = capture_state

    def __enter__(self):
        """Return this context object; entering has no other effect."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Capture (if requested) and save; never swallow exceptions."""
        integration = self.session_integration
        if not integration.auto_save_enabled:
            return False
        if self.capture_state:
            integration._capture_trame_state(self.trame_state)
        integration.save_current_session()
        return False
class SessionAutoSaveHook:
    """
    Decorator for automatically saving sessions after specific operations.

    After the decorated function returns *or raises*, the Trame state is
    captured and the current session saved, provided
    ``session_integration.auto_save_enabled`` is true at that moment. The
    function's return value and exceptions pass through unchanged.

    Usage:
        @SessionAutoSaveHook(session_integration, trame_state)
        def some_operation():
            # Make changes
            pass  # Auto-saves after execution
    """

    def __init__(self, session_integration: "SessionIntegration", trame_state: Any):
        """
        Args:
            session_integration: Integration whose current session is saved.
            trame_state: Object whose attributes are captured before saving.
        """
        self.session_integration = session_integration
        self.trame_state = trame_state

    def __call__(self, func: Callable) -> Callable:
        """
        Wrap ``func`` so an auto-save runs after every call.

        ``functools.wraps`` copies ``func``'s ``__name__``, ``__doc__``,
        ``__module__`` etc. and sets ``__wrapped__``, so the decorated
        callable stays identifiable by name and introspection.
        """
        from functools import wraps

        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            finally:
                # The flag is read per call, so toggling it also affects
                # functions that were decorated earlier.
                if self.session_integration.auto_save_enabled:
                    self.session_integration._capture_trame_state(self.trame_state)
                    self.session_integration.save_current_session()

        return wrapper