dummyQuantum / session_integration.py
Apurva Tiwari
feature: sessions, init
ca961b4
"""
Session Integration Module
Provides session management hooks for EM and QLBM modules.
Handles session loading, state restoration, and auto-saving.
"""
import functools
import json
from pathlib import Path
from typing import Optional, Dict, Any, Callable

from session_manager import SessionManager
from session_models import SessionMetadata, SessionState
class SessionIntegration:
    """
    Manages session state synchronization with EM/QLBM modules.

    Wraps a ``SessionManager`` for a single user and remembers the currently
    active session (id, metadata and state) so that app state can be loaded
    into, and saved back from, the app modules.
    """

    # Names of the Trame state variables that are copied into the session's
    # ``state_data`` on capture. Kept at class level so the collection is
    # built once instead of on every capture.
    _CAPTURED_KEYS = (
        # EM-specific
        "grid_size", "frequency", "excitation_type", "excitation_frequency",
        "geometry_type", "outer_boundary_condition",
        "num_qubits", "t_final", "dt", "backend_type", "selected_simulator",
        "selected_qpu", "aqc_enabled", "measurement_type",
        # QLBM-specific
        "grid_size_qlbm", "num_steps", "initial_condition", "viscosity",
        "flow_field_type",
    )

    def __init__(self, user_id: str):
        """Initialize session integration for a user.

        Args:
            user_id: Identifier passed to ``SessionManager``.
        """
        self.user_id = user_id
        self.session_manager = SessionManager(user_id)
        # No session is active until one is created or loaded.
        self.current_session_id: Optional[str] = None
        self.current_metadata: Optional[SessionMetadata] = None
        self.current_state: Optional[SessionState] = None
        self.auto_save_enabled = True

    def create_new_session(
        self,
        alias: str,
        app_type: str,
        description: str = "",
    ) -> "tuple[str, SessionMetadata]":
        """Create a new session and make it the current one.

        Args:
            alias: Alias forwarded to ``SessionManager.create_session``.
            app_type: Application type, stored on both the created session
                and the fresh ``SessionState``.
            description: Optional free-text description.

        Returns:
            The ``(session_id, metadata)`` pair from the session manager.
        """
        session_id, metadata = self.session_manager.create_session(
            alias=alias,
            app_type=app_type,
            description=description,
        )
        self.current_session_id = session_id
        self.current_metadata = metadata
        self.current_state = SessionState(
            session_id=session_id,
            app_type=app_type,
        )
        return session_id, metadata

    def load_session(
        self, session_id: str
    ) -> "tuple[SessionMetadata, SessionState]":
        """Load a session by ID and make it the current one.

        Returns:
            The ``(metadata, state)`` pair from the session manager.
        """
        metadata, state = self.session_manager.load_session(session_id)
        self.current_session_id = session_id
        self.current_metadata = metadata
        self.current_state = state
        return metadata, state

    def load_by_alias(
        self, alias: str
    ) -> "Optional[tuple[SessionMetadata, SessionState]]":
        """Load the most recent session matching an alias.

        Returns:
            The ``(metadata, state)`` pair of the loaded session, or ``None``
            when the session manager reports no match for ``alias``.
        """
        result = self.session_manager.get_most_recent_by_alias(alias)
        if not result:
            return None
        # Only the id is needed; load_session() re-reads the metadata.
        _metadata, session_id = result
        return self.load_session(session_id)

    def get_state_dict(self) -> Dict[str, Any]:
        """Get the current session state as a dict (e.g., for saving to disk).

        Returns:
            ``current_state.to_dict()``, or an empty dict when there is no
            current state.
        """
        if not self.current_state:
            return {}
        return self.current_state.to_dict()

    def restore_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore session state from a dict.

        Does nothing when no session is active or ``state_dict`` is empty.
        """
        if self.current_session_id and state_dict:
            self.current_state = SessionState.from_dict(state_dict)

    def save_current_session(self, trame_state: Optional[Any] = None) -> bool:
        """
        Save the current session.

        Args:
            trame_state: Optional Trame state object to capture app state
                from before saving.

        Returns:
            ``False`` when there is no current session id, metadata or state;
            otherwise the result of ``SessionManager.save_session``.
        """
        if (
            not self.current_session_id
            or not self.current_metadata
            or not self.current_state
        ):
            return False

        # Compare against None rather than relying on truthiness: a state
        # object that defines __len__/__bool__ may be falsy while still
        # carrying values that must be captured.
        if trame_state is not None:
            self._capture_trame_state(trame_state)

        return self.session_manager.save_session(
            self.current_metadata, self.current_state
        )

    def _capture_trame_state(self, trame_state: Any) -> None:
        """
        Capture Trame server state into session ``state_data``.

        Copies every attribute listed in ``_CAPTURED_KEYS`` that exists on
        ``trame_state`` and is JSON-serializable; everything else is skipped.
        Does nothing when there is no current state.
        """
        if not self.current_state:
            return

        captured: Dict[str, Any] = {}
        for key in self._CAPTURED_KEYS:
            try:
                value = getattr(trame_state, key)
            except AttributeError:
                # Variable not present on this state object.
                continue
            try:
                # Probe only: the encoded string itself is discarded.
                json.dumps(value)
            except (TypeError, ValueError):
                # Best effort: silently skip non-serializable values.
                continue
            captured[key] = value

        self.current_state.state_data.update(captured)

    def restore_to_trame_state(self, trame_state: Any) -> None:
        """
        Restore session ``state_data`` back into Trame state.

        Only keys that already exist as attributes on ``trame_state`` are
        written. Does nothing when there is no current state.

        Args:
            trame_state: Trame state object to restore to.
        """
        if not self.current_state:
            return

        for key, value in self.current_state.state_data.items():
            if hasattr(trame_state, key):
                setattr(trame_state, key, value)

    def add_job(self, job_id: str, service_type: str) -> bool:
        """Track a submitted job in the session.

        Returns:
            ``False`` when no session is active; otherwise the result of
            ``SessionManager.add_job_to_session``.
        """
        if not self.current_session_id:
            return False
        return self.session_manager.add_job_to_session(
            self.current_session_id,
            job_id,
            service_type,
        )

    def update_job_status(
        self,
        job_id: str,
        status: str,
        result: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Update job status in the session.

        Returns:
            ``False`` when no session is active; otherwise the result of
            ``SessionManager.update_job_status``.
        """
        if not self.current_session_id:
            return False
        return self.session_manager.update_job_status(
            self.current_session_id,
            job_id,
            status,
            result,
        )

    def get_current_session_info(self) -> Optional[Dict[str, Any]]:
        """Get current session metadata as a dict, or ``None`` if unset."""
        if not self.current_metadata:
            return None
        return self.current_metadata.to_dict()
class AutoSaveContext:
    """
    Context manager that saves the current session when the block exits.

    The save runs on every exit, including when the body raised, provided
    the integration's ``auto_save_enabled`` flag is set. Exceptions from the
    body are never suppressed.

    Usage:
        with AutoSaveContext(session_integration, trame_state):
            # Make changes
            pass  # Auto-saves on exit
    """

    def __init__(
        self,
        session_integration: SessionIntegration,
        trame_state: Any,
        capture_state: bool = True
    ):
        """Remember the integration, the Trame state and the capture flag.

        Args:
            session_integration: Integration whose current session is saved.
            trame_state: Trame state object captured before saving.
            capture_state: When false, the session is saved without first
                capturing ``trame_state``.
        """
        self.session_integration = session_integration
        self.trame_state = trame_state
        self.capture_state = capture_state

    def __enter__(self):
        """Enter the context; nothing is set up here."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Capture (optionally) and save, then let any exception propagate."""
        integration = self.session_integration
        if not integration.auto_save_enabled:
            return False
        if self.capture_state:
            integration._capture_trame_state(self.trame_state)
        integration.save_current_session()
        # Returning False tells Python not to swallow the body's exception.
        return False
class SessionAutoSaveHook:
    """
    Decorator for automatically saving sessions after specific operations.

    After the decorated function finishes (whether it returns or raises),
    the Trame state is captured and the current session is saved, provided
    the integration's ``auto_save_enabled`` flag is set.

    Usage:
        @SessionAutoSaveHook(session_integration, trame_state)
        def some_operation():
            # Make changes
            pass  # Auto-saves after execution
    """

    def __init__(self, session_integration: "SessionIntegration", trame_state: Any):
        """Remember the integration to save through and the state to capture.

        Args:
            session_integration: Integration whose current session is saved.
            trame_state: Trame state object captured before each save.
        """
        self.session_integration = session_integration
        self.trame_state = trame_state

    def __call__(self, func: Callable) -> Callable:
        """Wrap ``func`` so a capture-and-save runs after every call.

        Args:
            func: The operation to decorate.

        Returns:
            A wrapper that forwards all arguments to ``func`` and returns its
            result. ``functools.wraps`` keeps ``func``'s name, docstring and
            other metadata on the wrapper.
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            finally:
                # Runs on both normal return and exception, so the session
                # is saved even when the operation fails part-way.
                if self.session_integration.auto_save_enabled:
                    self.session_integration._capture_trame_state(self.trame_state)
                    self.session_integration.save_current_session()

        return wrapper