# Spaces: Runtime error | File size: 7,948 Bytes | ca961b4
"""
Session Integration Module
Provides session management hooks for EM and QLBM modules.
Handles session loading, state restoration, and auto-saving.
"""
import functools
import json
from pathlib import Path
from typing import Optional, Dict, Any, Callable

from session_manager import SessionManager
from session_models import SessionMetadata, SessionState
class SessionIntegration:
    """
    Manages session state synchronization with EM/QLBM modules.

    Wraps a ``SessionManager`` for one user and tracks the currently active
    session (ID, metadata and state) so that app modules can load state from
    it and save changes back.
    """

    # Names of the Trame state variables copied into ``state_data`` on capture.
    _CAPTURED_STATE_KEYS = (
        # EM-specific
        "grid_size", "frequency", "excitation_type", "excitation_frequency",
        "geometry_type", "outer_boundary_condition",
        "num_qubits", "t_final", "dt", "backend_type", "selected_simulator",
        "selected_qpu", "aqc_enabled", "measurement_type",
        # QLBM-specific
        "grid_size_qlbm", "num_steps", "initial_condition", "viscosity",
        "flow_field_type",
    )

    def __init__(self, user_id: str):
        """Initialize session integration for a user (no session is active yet)."""
        self.user_id = user_id
        self.session_manager = SessionManager(user_id)
        self.current_session_id: Optional[str] = None
        self.current_metadata: Optional[SessionMetadata] = None
        self.current_state: Optional[SessionState] = None
        self.auto_save_enabled = True

    def create_new_session(
        self,
        alias: str,
        app_type: str,
        description: str = ""
    ) -> tuple[str, SessionMetadata]:
        """
        Create a new session through the session manager and make it current.

        Args:
            alias: Alias passed to the session manager.
            app_type: App type stored on both the session and its fresh state.
            description: Optional description passed to the session manager.

        Returns:
            The ``(session_id, metadata)`` pair returned by the session manager.
        """
        session_id, metadata = self.session_manager.create_session(
            alias=alias,
            app_type=app_type,
            description=description,
        )
        self.current_session_id = session_id
        self.current_metadata = metadata
        # A new session starts from a fresh state object.
        self.current_state = SessionState(
            session_id=session_id,
            app_type=app_type,
        )
        return session_id, metadata

    def load_session(self, session_id: str) -> tuple[SessionMetadata, SessionState]:
        """
        Load a session by ID and make it the current session.

        Returns:
            The ``(metadata, state)`` pair returned by the session manager.
        """
        metadata, state = self.session_manager.load_session(session_id)
        self.current_session_id = session_id
        self.current_metadata = metadata
        self.current_state = state
        return metadata, state

    def load_by_alias(self, alias: str) -> Optional[tuple[SessionMetadata, SessionState]]:
        """
        Load the most recent session matching an alias.

        Returns:
            The ``(metadata, state)`` pair from ``load_session``, or ``None``
            when the session manager reports no match for ``alias``.
        """
        result = self.session_manager.get_most_recent_by_alias(alias)
        if not result:
            return None
        # Only the ID is needed here; load_session fetches the metadata itself.
        _, session_id = result
        return self.load_session(session_id)

    def get_state_dict(self) -> Dict[str, Any]:
        """Return the current session state as a dict, or ``{}`` if there is none."""
        if not self.current_state:
            return {}
        return self.current_state.to_dict()

    def restore_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Replace the current state with one rebuilt from ``state_dict``.

        Does nothing when no session is active or ``state_dict`` is empty.
        """
        if self.current_session_id and state_dict:
            self.current_state = SessionState.from_dict(state_dict)

    def save_current_session(self, trame_state: Optional[Any] = None) -> bool:
        """
        Save the current session.

        Args:
            trame_state: Optional Trame state object to capture app state from
                before saving.

        Returns:
            ``False`` when there is no complete current session (ID, metadata
            and state); otherwise the result of ``SessionManager.save_session``.
        """
        if not self.current_session_id or not self.current_metadata or not self.current_state:
            return False
        # Compare against None explicitly: a provided state object may be
        # falsy (e.g. define __len__/__bool__) and must still be captured.
        if trame_state is not None:
            self._capture_trame_state(trame_state)
        return self.session_manager.save_session(self.current_metadata, self.current_state)

    def _capture_trame_state(self, trame_state: Any) -> None:
        """
        Copy the known app variables from ``trame_state`` into ``state_data``.

        Only attributes listed in ``_CAPTURED_STATE_KEYS`` are read, and values
        that ``json.dumps`` rejects are skipped. Does nothing when there is no
        current state.
        """
        if not self.current_state:
            return
        # NOTE(review): if the state object answers hasattr() for every name
        # (returning None for unset keys), unset variables would be captured
        # as None -- confirm against the Trame State semantics.
        captured = {}
        for key in self._CAPTURED_STATE_KEYS:
            if not hasattr(trame_state, key):
                continue
            value = getattr(trame_state, key)
            try:
                json.dumps(value)  # Probe only; the encoded text is discarded.
            except (TypeError, ValueError):
                # Best effort: non-serializable values are left out.
                continue
            captured[key] = value
        self.current_state.state_data.update(captured)

    def restore_to_trame_state(self, trame_state: Any) -> None:
        """
        Restore session ``state_data`` back into Trame state.

        Only keys that ``trame_state`` already exposes as attributes are set.

        Args:
            trame_state: Trame state object to restore to.
        """
        if not self.current_state:
            return
        for key, value in self.current_state.state_data.items():
            if hasattr(trame_state, key):
                setattr(trame_state, key, value)

    def add_job(self, job_id: str, service_type: str) -> bool:
        """
        Track a submitted job in the current session.

        Returns:
            ``False`` when no session is active; otherwise the result of
            ``SessionManager.add_job_to_session``.
        """
        if not self.current_session_id:
            return False
        return self.session_manager.add_job_to_session(
            self.current_session_id,
            job_id,
            service_type,
        )

    def update_job_status(
        self,
        job_id: str,
        status: str,
        result: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update a job's status in the current session.

        Returns:
            ``False`` when no session is active; otherwise the result of
            ``SessionManager.update_job_status``.
        """
        if not self.current_session_id:
            return False
        return self.session_manager.update_job_status(
            self.current_session_id,
            job_id,
            status,
            result,
        )

    def get_current_session_info(self) -> Optional[Dict[str, Any]]:
        """Return the current session metadata as a dict, or ``None`` if there is none."""
        if not self.current_metadata:
            return None
        return self.current_metadata.to_dict()
class AutoSaveContext:
    """
    Context manager that saves the session when the ``with`` block ends.

    Usage:
        with AutoSaveContext(session_integration, trame_state):
            # Make changes
            pass  # Saved on exit
    """

    def __init__(
        self,
        session_integration: SessionIntegration,
        trame_state: Any,
        capture_state: bool = True
    ):
        """
        Args:
            session_integration: Integration object whose session is saved.
            trame_state: State object to capture from on exit.
            capture_state: When False, save without capturing ``trame_state``.
        """
        self.session_integration = session_integration
        self.trame_state = trame_state
        self.capture_state = capture_state

    def __enter__(self):
        """Return this context manager; nothing is set up on entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Capture (optionally) and save, unless auto-save is disabled.

        Runs whether or not the block raised, and always returns False so an
        exception from the block keeps propagating.
        """
        integration = self.session_integration
        if not integration.auto_save_enabled:
            return False
        if self.capture_state:
            integration._capture_trame_state(self.trame_state)
        integration.save_current_session()
        return False
class SessionAutoSaveHook:
    """
    Decorator for automatically saving sessions after specific operations.

    Usage:
        @SessionAutoSaveHook(session_integration, trame_state)
        def some_operation():
            # Make changes
            pass  # Auto-saves after execution
    """

    def __init__(self, session_integration: SessionIntegration, trame_state: Any):
        """
        Args:
            session_integration: Integration object whose session is saved.
            trame_state: State object captured before each save.
        """
        self.session_integration = session_integration
        self.trame_state = trame_state

    def __call__(self, func: Callable) -> Callable:
        """
        Wrap ``func`` so the session is captured and saved after every call.

        The save runs in a ``finally`` clause, so it also happens when
        ``func`` raises; the exception then continues to propagate. Nothing
        is captured or saved while ``auto_save_enabled`` is false.

        Returns:
            The wrapper, carrying ``func``'s name, docstring and other
            metadata via ``functools.wraps``.
        """
        @functools.wraps(func)  # Keep __name__/__doc__/__wrapped__ of func.
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            finally:
                if self.session_integration.auto_save_enabled:
                    self.session_integration._capture_trame_state(self.trame_state)
                    self.session_integration.save_current_session()
        return wrapper
|